diff options
author | Mehrdad Afshari <mmx@google.com> | 2017-12-11 11:32:11 -0800 |
---|---|---|
committer | Mehrdad Afshari <mmx@google.com> | 2017-12-11 11:32:11 -0800 |
commit | c5ba665b2259c0cff18666260bc5ea34d4ff753a (patch) | |
tree | 542d2cb4c69104c010b696c2edacab955439fc53 | |
parent | 57dc5453db6080e67f18f63a8c3aaec24e0aa0dc (diff) | |
parent | 4f8ffd852b21a35f37f5b110fc5d8ca864d04642 (diff) |
Merge remote-tracking branch 'origin/v1.8.x'
46 files changed, 2487 insertions, 626 deletions
@@ -27,13 +27,13 @@ Libraries in different languages may be in different states of development. We a | Language | Source | Status | |-------------------------|-------------------------------------|---------| -| Shared C [core library] | [src/core](src/core) | 1.6 | -| C++ | [src/cpp](src/cpp) | 1.6 | -| Ruby | [src/ruby](src/ruby) | 1.6 | -| Python | [src/python](src/python) | 1.6 | -| PHP | [src/php](src/php) | 1.6 | -| C# | [src/csharp](src/csharp) | 1.6 | -| Objective-C | [src/objective-c](src/objective-c) | 1.6 | +| Shared C [core library] | [src/core](src/core) | 1.8 | +| C++ | [src/cpp](src/cpp) | 1.8 | +| Ruby | [src/ruby](src/ruby) | 1.8 | +| Python | [src/python](src/python) | 1.8 | +| PHP | [src/php](src/php) | 1.8 | +| C# | [src/csharp](src/csharp) | 1.8 | +| Objective-C | [src/objective-c](src/objective-c) | 1.8 | Java source code is in the [grpc-java](http://github.com/grpc/grpc-java) repository. Go source code is in the diff --git a/examples/python/interceptors/default_value/default_value_client_interceptor.py b/examples/python/interceptors/default_value/default_value_client_interceptor.py new file mode 100644 index 0000000000..c549f2b861 --- /dev/null +++ b/examples/python/interceptors/default_value/default_value_client_interceptor.py @@ -0,0 +1,68 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interceptor that adds headers to outgoing requests.""" + +import collections + +import grpc + + +class _ConcreteValue(grpc.Future): + + def __init__(self, result): + self._result = result + + def cancel(self): + return False + + def cancelled(self): + return False + + def running(self): + return False + + def done(self): + return True + + def result(self, timeout=None): + return self._result + + def exception(self, timeout=None): + return None + + def traceback(self, timeout=None): + return None + + def add_done_callback(self, fn): + fn(self._result) + + +class DefaultValueClientInterceptor(grpc.UnaryUnaryClientInterceptor, + grpc.StreamUnaryClientInterceptor): + + def __init__(self, value): + self._default = _ConcreteValue(value) + + def _intercept_call(self, continuation, client_call_details, + request_or_iterator): + response = continuation(client_call_details, request_or_iterator) + return self._default if response.exception() else response + + def intercept_unary_unary(self, continuation, client_call_details, request): + return self._intercept_call(continuation, client_call_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + return self._intercept_call(continuation, client_call_details, + request_iterator) diff --git a/examples/python/interceptors/default_value/greeter_client.py b/examples/python/interceptors/default_value/greeter_client.py new file mode 100644 index 0000000000..aba7571d83 --- /dev/null +++ b/examples/python/interceptors/default_value/greeter_client.py @@ -0,0 +1,38 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the gRPC helloworld.Greeter client.""" + +from __future__ import print_function + +import grpc + +import helloworld_pb2 +import helloworld_pb2_grpc +import default_value_client_interceptor + + +def run(): + default_value = helloworld_pb2.HelloReply( + message='Hello from your local interceptor!') + default_value_interceptor = default_value_client_interceptor.DefaultValueClientInterceptor( + default_value) + channel = grpc.insecure_channel('localhost:50051') + channel = grpc.intercept_channel(channel, default_value_interceptor) + stub = helloworld_pb2_grpc.GreeterStub(channel) + response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + print("Greeter client received: " + response.message) + + +if __name__ == '__main__': + run() diff --git a/examples/python/interceptors/default_value/helloworld_pb2.py b/examples/python/interceptors/default_value/helloworld_pb2.py new file mode 100644 index 0000000000..e18ab9acc7 --- /dev/null +++ b/examples/python/interceptors/default_value/helloworld_pb2.py @@ -0,0 +1,134 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: helloworld.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='helloworld.proto', + package='helloworld', + syntax='proto3', + serialized_pb=_b('\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3') +) + + + + +_HELLOREQUEST = _descriptor.Descriptor( + name='HelloRequest', + full_name='helloworld.HelloRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='helloworld.HelloRequest.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=32, + serialized_end=60, +) + + +_HELLOREPLY = _descriptor.Descriptor( + name='HelloReply', + full_name='helloworld.HelloReply', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='message', full_name='helloworld.HelloReply.message', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=62, + serialized_end=91, +) + +DESCRIPTOR.message_types_by_name['HelloRequest'] = _HELLOREQUEST +DESCRIPTOR.message_types_by_name['HelloReply'] = _HELLOREPLY +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +HelloRequest = _reflection.GeneratedProtocolMessageType('HelloRequest', (_message.Message,), dict( + DESCRIPTOR = _HELLOREQUEST, + __module__ = 'helloworld_pb2' + # @@protoc_insertion_point(class_scope:helloworld.HelloRequest) + )) +_sym_db.RegisterMessage(HelloRequest) + +HelloReply = _reflection.GeneratedProtocolMessageType('HelloReply', (_message.Message,), dict( + DESCRIPTOR = _HELLOREPLY, + __module__ = 'helloworld_pb2' + # @@protoc_insertion_point(class_scope:helloworld.HelloReply) + )) +_sym_db.RegisterMessage(HelloReply) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW')) + +_GREETER = _descriptor.ServiceDescriptor( + name='Greeter', + full_name='helloworld.Greeter', + file=DESCRIPTOR, + index=0, + options=None, + serialized_start=93, + serialized_end=166, + methods=[ + _descriptor.MethodDescriptor( + name='SayHello', + full_name='helloworld.Greeter.SayHello', + index=0, + containing_service=None, + input_type=_HELLOREQUEST, + output_type=_HELLOREPLY, + options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_GREETER) + +DESCRIPTOR.services_by_name['Greeter'] = _GREETER + +# @@protoc_insertion_point(module_scope) diff --git a/examples/python/interceptors/default_value/helloworld_pb2_grpc.py b/examples/python/interceptors/default_value/helloworld_pb2_grpc.py new file mode 100644 index 0000000000..18e07d1679 --- /dev/null +++ b/examples/python/interceptors/default_value/helloworld_pb2_grpc.py @@ -0,0 +1,46 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +import helloworld_pb2 as helloworld__pb2 + + +class GreeterStub(object): + """The greeting service definition. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SayHello = channel.unary_unary( + '/helloworld.Greeter/SayHello', + request_serializer=helloworld__pb2.HelloRequest.SerializeToString, + response_deserializer=helloworld__pb2.HelloReply.FromString, + ) + + +class GreeterServicer(object): + """The greeting service definition. + """ + + def SayHello(self, request, context): + """Sends a greeting + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GreeterServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SayHello': grpc.unary_unary_rpc_method_handler( + servicer.SayHello, + request_deserializer=helloworld__pb2.HelloRequest.FromString, + response_serializer=helloworld__pb2.HelloReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'helloworld.Greeter', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/examples/python/interceptors/headers/generic_client_interceptor.py b/examples/python/interceptors/headers/generic_client_interceptor.py new file mode 100644 index 0000000000..30b0755aaf --- /dev/null +++ b/examples/python/interceptors/headers/generic_client_interceptor.py @@ -0,0 +1,55 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for interceptors that operate on all RPC types.""" + +import grpc + + +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): + + def __init__(self, interceptor_function): + self._fn = interceptor_function + + def intercept_unary_unary(self, continuation, client_call_details, request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_unary_stream(self, continuation, client_call_details, + request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + +def create(intercept_call): + return _GenericClientInterceptor(intercept_call) diff --git a/examples/python/interceptors/headers/greeter_client.py b/examples/python/interceptors/headers/greeter_client.py new file mode 100644 index 0000000000..2b0dd3e177 --- /dev/null +++ b/examples/python/interceptors/headers/greeter_client.py @@ -0,0 +1,36 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC helloworld.Greeter client.""" + +from __future__ import print_function + +import grpc + +import helloworld_pb2 +import helloworld_pb2_grpc +import header_manipulator_client_interceptor + + +def run(): + header_adder_interceptor = header_manipulator_client_interceptor.header_adder_interceptor( + 'one-time-password', '42') + channel = grpc.insecure_channel('localhost:50051') + channel = grpc.intercept_channel(channel, header_adder_interceptor) + stub = helloworld_pb2_grpc.GreeterStub(channel) + response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + print("Greeter client received: " + response.message) + + +if __name__ == '__main__': + run() diff --git a/examples/python/interceptors/headers/greeter_server.py b/examples/python/interceptors/headers/greeter_server.py new file mode 100644 index 0000000000..01968609b5 --- /dev/null +++ b/examples/python/interceptors/headers/greeter_server.py @@ -0,0 +1,52 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC helloworld.Greeter server.""" + +from concurrent import futures +import time + +import grpc + +import helloworld_pb2 +import helloworld_pb2_grpc +from request_header_validator_interceptor import RequestHeaderValidatorInterceptor + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + + +class Greeter(helloworld_pb2_grpc.GreeterServicer): + + def SayHello(self, request, context): + return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + + +def serve(): + header_validator = RequestHeaderValidatorInterceptor( + 'one-time-password', '42', grpc.StatusCode.UNAUTHENTICATED, + 'Access denied!') + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10), + interceptors=(header_validator,)) + helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) + server.add_insecure_port('[::]:50051') + server.start() + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == '__main__': + serve() diff --git a/examples/python/interceptors/headers/header_manipulator_client_interceptor.py b/examples/python/interceptors/headers/header_manipulator_client_interceptor.py new file mode 100644 index 0000000000..ac7c605144 --- /dev/null +++ b/examples/python/interceptors/headers/header_manipulator_client_interceptor.py @@ -0,0 +1,42 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interceptor that adds headers to outgoing requests.""" + +import collections + +import grpc +import generic_client_interceptor + + +class _ClientCallDetails( + collections.namedtuple('_ClientCallDetails', + ('method', 'timeout', 'metadata', + 'credentials')), grpc.ClientCallDetails): + pass + + +def header_adder_interceptor(header, value): + + def intercept_call(client_call_details, request_iterator, request_streaming, + response_streaming): + metadata = [] + if client_call_details.metadata is not None: + metadata = list(client_call_details.metadata) + metadata.append((header, value,)) + client_call_details = _ClientCallDetails( + client_call_details.method, client_call_details.timeout, metadata, + client_call_details.credentials) + return client_call_details, request_iterator, None + + return generic_client_interceptor.create(intercept_call) diff --git a/examples/python/interceptors/headers/helloworld_pb2.py b/examples/python/interceptors/headers/helloworld_pb2.py new file mode 100644 index 0000000000..e18ab9acc7 --- /dev/null +++ b/examples/python/interceptors/headers/helloworld_pb2.py @@ -0,0 +1,134 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: helloworld.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='helloworld.proto', + package='helloworld', + syntax='proto3', + serialized_pb=_b('\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3') +) + + + + +_HELLOREQUEST = _descriptor.Descriptor( + name='HelloRequest', + full_name='helloworld.HelloRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='helloworld.HelloRequest.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=32, + serialized_end=60, +) + + +_HELLOREPLY = _descriptor.Descriptor( + name='HelloReply', + full_name='helloworld.HelloReply', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='message', full_name='helloworld.HelloReply.message', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=62, + serialized_end=91, +) + +DESCRIPTOR.message_types_by_name['HelloRequest'] = _HELLOREQUEST +DESCRIPTOR.message_types_by_name['HelloReply'] = _HELLOREPLY +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +HelloRequest = _reflection.GeneratedProtocolMessageType('HelloRequest', (_message.Message,), dict( + DESCRIPTOR = _HELLOREQUEST, + __module__ = 'helloworld_pb2' + # @@protoc_insertion_point(class_scope:helloworld.HelloRequest) + )) +_sym_db.RegisterMessage(HelloRequest) + +HelloReply = _reflection.GeneratedProtocolMessageType('HelloReply', (_message.Message,), dict( + DESCRIPTOR = _HELLOREPLY, + __module__ = 'helloworld_pb2' + # @@protoc_insertion_point(class_scope:helloworld.HelloReply) + )) +_sym_db.RegisterMessage(HelloReply) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW')) + +_GREETER = _descriptor.ServiceDescriptor( + name='Greeter', + full_name='helloworld.Greeter', + file=DESCRIPTOR, + index=0, + options=None, + serialized_start=93, + serialized_end=166, + methods=[ + _descriptor.MethodDescriptor( + name='SayHello', + full_name='helloworld.Greeter.SayHello', + index=0, + containing_service=None, + input_type=_HELLOREQUEST, + output_type=_HELLOREPLY, + options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_GREETER) + +DESCRIPTOR.services_by_name['Greeter'] = _GREETER + +# @@protoc_insertion_point(module_scope) diff --git a/examples/python/interceptors/headers/helloworld_pb2_grpc.py b/examples/python/interceptors/headers/helloworld_pb2_grpc.py new file mode 100644 index 0000000000..18e07d1679 --- /dev/null +++ b/examples/python/interceptors/headers/helloworld_pb2_grpc.py @@ -0,0 +1,46 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +import helloworld_pb2 as helloworld__pb2 + + +class GreeterStub(object): + """The greeting service definition. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SayHello = channel.unary_unary( + '/helloworld.Greeter/SayHello', + request_serializer=helloworld__pb2.HelloRequest.SerializeToString, + response_deserializer=helloworld__pb2.HelloReply.FromString, + ) + + +class GreeterServicer(object): + """The greeting service definition. + """ + + def SayHello(self, request, context): + """Sends a greeting + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GreeterServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SayHello': grpc.unary_unary_rpc_method_handler( + servicer.SayHello, + request_deserializer=helloworld__pb2.HelloRequest.FromString, + response_serializer=helloworld__pb2.HelloReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'helloworld.Greeter', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/examples/python/interceptors/headers/request_header_validator_interceptor.py b/examples/python/interceptors/headers/request_header_validator_interceptor.py new file mode 100644 index 0000000000..95af4177ba --- /dev/null +++ b/examples/python/interceptors/headers/request_header_validator_interceptor.py @@ -0,0 +1,39 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interceptor that ensures a specific header is present.""" + +import grpc + + +def _unary_unary_rpc_terminator(code, details): + + def terminate(ignored_request, context): + context.abort(code, details) + + return grpc.unary_unary_rpc_method_handler(terminate) + + +class RequestHeaderValidatorInterceptor(grpc.ServerInterceptor): + + def __init__(self, header, value, code, details): + self._header = header + self._value = value + self._terminator = _unary_unary_rpc_terminator(code, details) + + def intercept_service(self, continuation, handler_call_details): + if (self._header, + self._value) in handler_call_details.invocation_metadata: + return continuation(handler_call_details) + else: + return self._terminator diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 564772527e..8b913ac949 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -342,6 +342,170 @@ class Call(six.with_metaclass(abc.ABCMeta, RpcContext)): raise NotImplementedError() +############## Invocation-Side Interceptor Interfaces & Classes ############## + + +class ClientCallDetails(six.with_metaclass(abc.ABCMeta)): + """Describes an RPC to be invoked. + + This is an EXPERIMENTAL API. + + Attributes: + method: The method name of the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: Optional :term:`metadata` to be transmitted to + the service-side of the RPC. + credentials: An optional CallCredentials for the RPC. + """ + + +class UnaryUnaryClientInterceptor(six.with_metaclass(abc.ABCMeta)): + """Affords intercepting unary-unary invocations. + + This is an EXPERIMENTAL API. + """ + + @abc.abstractmethod + def intercept_unary_unary(self, continuation, client_call_details, request): + """Intercepts a unary-unary invocation asynchronously. + + Args: + continuation: A function that proceeds with the invocation by + executing the next interceptor in chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `response_future = continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns an object that is + both a Call for the RPC and a Future. In the event of RPC + completion, the return Call-Future's result value will be + the response message of the RPC. Should the event terminate + with non-OK status, the returned Call-Future's exception value + will be an RpcError. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + + Returns: + An object that is both a Call for the RPC and a Future. + In the event of RPC completion, the return Call-Future's + result value will be the response message of the RPC. + Should the event terminate with non-OK status, the returned + Call-Future's exception value will be an RpcError. + """ + raise NotImplementedError() + + +class UnaryStreamClientInterceptor(six.with_metaclass(abc.ABCMeta)): + """Affords intercepting unary-stream invocations. + + This is an EXPERIMENTAL API. + """ + + @abc.abstractmethod + def intercept_unary_stream(self, continuation, client_call_details, + request): + """Intercepts a unary-stream invocation. + + Args: + continuation: A function that proceeds with the invocation by + executing the next interceptor in chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `response_iterator = continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns an object that is + both a Call for the RPC and an iterator for response values. + Drawing response values from the returned Call-iterator may + raise RpcError indicating termination of the RPC with non-OK + status. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + + Returns: + An object that is both a Call for the RPC and an iterator of + response values. Drawing response values from the returned + Call-iterator may raise RpcError indicating termination of + the RPC with non-OK status. + """ + raise NotImplementedError() + + +class StreamUnaryClientInterceptor(six.with_metaclass(abc.ABCMeta)): + """Affords intercepting stream-unary invocations. + + This is an EXPERIMENTAL API. + """ + + @abc.abstractmethod + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + """Intercepts a stream-unary invocation asynchronously. + + Args: + continuation: A function that proceeds with the invocation by + executing the next interceptor in chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `response_future = continuation(client_call_details, + request_iterator)` + to continue with the RPC. `continuation` returns an object that is + both a Call for the RPC and a Future. In the event of RPC completion, + the return Call-Future's result value will be the response message + of the RPC. Should the event terminate with non-OK status, the + returned Call-Future's exception value will be an RpcError. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: An iterator that yields request values for the RPC. + + Returns: + An object that is both a Call for the RPC and a Future. + In the event of RPC completion, the return Call-Future's + result value will be the response message of the RPC. + Should the event terminate with non-OK status, the returned + Call-Future's exception value will be an RpcError. + """ + raise NotImplementedError() + + +class StreamStreamClientInterceptor(six.with_metaclass(abc.ABCMeta)): + """Affords intercepting stream-stream invocations. + + This is an EXPERIMENTAL API. + """ + + @abc.abstractmethod + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + """Intercepts a stream-stream invocation. + + continuation: A function that proceeds with the invocation by + executing the next interceptor in chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `response_iterator = continuation(client_call_details, + request_iterator)` + to continue with the RPC. `continuation` returns an object that is + both a Call for the RPC and an iterator for response values. + Drawing response values from the returned Call-iterator may + raise RpcError indicating termination of the RPC with non-OK + status. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: An iterator that yields request values for the RPC. + + Returns: + An object that is both a Call for the RPC and an iterator of + response values. Drawing response values from the returned + Call-iterator may raise RpcError indicating termination of + the RPC with non-OK status. + """ + raise NotImplementedError() + + ############ Authentication & Authorization Interfaces & Classes ############# @@ -835,27 +999,47 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): raise NotImplementedError() @abc.abstractmethod + def abort(self, code, details): + """Raises an exception to terminate the RPC with a non-OK status. + + The code and details passed as arguments will supercede any existing + ones. + + Args: + code: A StatusCode object to be sent to the client. + It must not be StatusCode.OK. + details: An ASCII-encodable string to be sent to the client upon + termination of the RPC. + + Raises: + Exception: An exception is always raised to signal the abortion the + RPC to the gRPC runtime. + """ + raise NotImplementedError() + + @abc.abstractmethod def set_code(self, code): """Sets the value to be used as status code upon RPC completion. - This method need not be called by method implementations if they wish the - gRPC runtime to determine the status code of the RPC. + This method need not be called by method implementations if they wish + the gRPC runtime to determine the status code of the RPC. - Args: - code: A StatusCode object to be sent to the client. - """ + Args: + code: A StatusCode object to be sent to the client. + """ raise NotImplementedError() @abc.abstractmethod def set_details(self, details): """Sets the value to be used as detail string upon RPC completion. - This method need not be called by method implementations if they have no - details to transmit. + This method need not be called by method implementations if they have + no details to transmit. - Args: - details: An arbitrary string to be sent to the client upon completion. - """ + Args: + details: An ASCII-encodable string to be sent to the client upon + termination of the RPC. + """ raise NotImplementedError() @@ -942,6 +1126,34 @@ class ServiceRpcHandler(six.with_metaclass(abc.ABCMeta, GenericRpcHandler)): raise NotImplementedError() +#################### Service-Side Interceptor Interfaces ##################### + + +class ServerInterceptor(six.with_metaclass(abc.ABCMeta)): + """Affords intercepting incoming RPCs on the service-side. + + This is an EXPERIMENTAL API. + """ + + @abc.abstractmethod + def intercept_service(self, continuation, handler_call_details): + """Intercepts incoming RPCs before handing them over to a handler. + + Args: + continuation: A function that takes a HandlerCallDetails and + proceeds to invoke the next interceptor in the chain, if any, + or the RPC handler lookup logic, with the call details passed + as an argument, and returns an RpcMethodHandler instance if + the RPC is considered serviced, or None otherwise. + handler_call_details: A HandlerCallDetails describing the RPC. + + Returns: + An RpcMethodHandler with which the RPC may be serviced if the + interceptor chooses to service this RPC, or None otherwise. + """ + raise NotImplementedError() + + ############################# Server Interface ############################### @@ -1356,53 +1568,88 @@ def secure_channel(target, credentials, options=None): credentials._credentials) +def intercept_channel(channel, *interceptors): + """Intercepts a channel through a set of interceptors. + + This is an EXPERIMENTAL API. + + Args: + channel: A Channel. + interceptors: Zero or more objects of type + UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor, or + StreamStreamClientInterceptor. + Interceptors are given control in the order they are listed. + + Returns: + A Channel that intercepts each invocation via the provided interceptors. + + Raises: + TypeError: If interceptor does not derive from any of + UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor, or + StreamStreamClientInterceptor. + """ + from grpc import _interceptor # pylint: disable=cyclic-import + return _interceptor.intercept_channel(channel, *interceptors) + + def server(thread_pool, handlers=None, + interceptors=None, options=None, maximum_concurrent_rpcs=None): """Creates a Server with which RPCs can be serviced. - Args: - thread_pool: A futures.ThreadPoolExecutor to be used by the Server - to execute RPC handlers. - handlers: An optional list of GenericRpcHandlers used for executing RPCs. - More handlers may be added by calling add_generic_rpc_handlers any time - before the server is started. - options: An optional list of key-value pairs (channel args in gRPC runtime) - to configure the channel. - maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server - will service before returning RESOURCE_EXHAUSTED status, or None to - indicate no limit. + Args: + thread_pool: A futures.ThreadPoolExecutor to be used by the Server + to execute RPC handlers. + handlers: An optional list of GenericRpcHandlers used for executing RPCs. + More handlers may be added by calling add_generic_rpc_handlers any time + before the server is started. + interceptors: An optional list of ServerInterceptor objects that observe + and optionally manipulate the incoming RPCs before handing them over to + handlers. The interceptors are given control in the order they are + specified. This is an EXPERIMENTAL API. + options: An optional list of key-value pairs (channel args in gRPC runtime) + to configure the channel. + maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server + will service before returning RESOURCE_EXHAUSTED status, or None to + indicate no limit. - Returns: - A Server object. - """ + Returns: + A Server object. + """ from grpc import _server # pylint: disable=cyclic-import return _server.Server(thread_pool, () if handlers is None else handlers, () - if options is None else options, - maximum_concurrent_rpcs) + if interceptors is None else interceptors, () if + options is None else options, maximum_concurrent_rpcs) ################################### __all__ ################################# -__all__ = ('FutureTimeoutError', 'FutureCancelledError', 'Future', - 'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext', - 'Call', 'ChannelCredentials', 'CallCredentials', - 'AuthMetadataContext', 'AuthMetadataPluginCallback', - 'AuthMetadataPlugin', 'ServerCertificateConfiguration', - 'ServerCredentials', 'UnaryUnaryMultiCallable', - 'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable', - 'StreamStreamMultiCallable', 'Channel', 'ServicerContext', - 'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler', - 'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler', - 'unary_stream_rpc_method_handler', 'stream_unary_rpc_method_handler', - 'stream_stream_rpc_method_handler', - 'method_handlers_generic_handler', 'ssl_channel_credentials', - 'metadata_call_credentials', 'access_token_call_credentials', - 'composite_call_credentials', 'composite_channel_credentials', - 'ssl_server_credentials', 'ssl_server_certificate_configuration', - 'dynamic_ssl_server_credentials', 'channel_ready_future', - 'insecure_channel', 'secure_channel', 'server',) +__all__ = ( + 'FutureTimeoutError', 'FutureCancelledError', 'Future', + 'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext', 'Call', + 'ChannelCredentials', 'CallCredentials', 'AuthMetadataContext', + 'AuthMetadataPluginCallback', 'AuthMetadataPlugin', 'ClientCallDetails', + 'ServerCertificateConfiguration', 'ServerCredentials', + 'UnaryUnaryMultiCallable', 'UnaryStreamMultiCallable', + 'StreamUnaryMultiCallable', 'StreamStreamMultiCallable', + 'UnaryUnaryClientInterceptor', 'UnaryStreamClientInterceptor', + 'StreamUnaryClientInterceptor', 'StreamStreamClientInterceptor', 'Channel', + 'ServicerContext', 'RpcMethodHandler', 'HandlerCallDetails', + 'GenericRpcHandler', 'ServiceRpcHandler', 'Server', 'ServerInterceptor', + 'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler', + 'stream_unary_rpc_method_handler', 'stream_stream_rpc_method_handler', + 'method_handlers_generic_handler', 'ssl_channel_credentials', + 'metadata_call_credentials', 'access_token_call_credentials', + 'composite_call_credentials', 'composite_channel_credentials', + 'ssl_server_credentials', 'ssl_server_certificate_configuration', + 'dynamic_ssl_server_credentials', 'channel_ready_future', + 'insecure_channel', 'secure_channel', 'intercept_channel', 'server',) ############################### Extension Shims ################################ diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index cf4ce0941b..d7456a3dd1 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -122,8 +122,8 @@ def _abort(state, code, details): state.code = code state.details = details if state.initial_metadata is None: - state.initial_metadata = _common.EMPTY_METADATA - state.trailing_metadata = _common.EMPTY_METADATA + state.initial_metadata = () + state.trailing_metadata = () def _handle_event(event, state, response_deserializer): @@ -202,8 +202,7 @@ def _consume_request_iterator(request_iterator, state, call, else: operations = (cygrpc.operation_send_message( serialized_request, _EMPTY_FLAGS),) - call.start_client_batch( - cygrpc.Operations(operations), event_handler) + call.start_client_batch(operations, event_handler) state.due.add(cygrpc.OperationType.send_message) while True: state.condition.wait() @@ -218,8 +217,7 @@ def _consume_request_iterator(request_iterator, state, call, if state.code is None: operations = ( cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),) - call.start_client_batch( - cygrpc.Operations(operations), event_handler) + call.start_client_batch(operations, event_handler) state.due.add(cygrpc.OperationType.send_close_from_client) def stop_consumption_thread(timeout): # pylint: disable=unused-argument @@ -321,8 +319,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): event_handler = _event_handler(self._state, self._call, self._response_deserializer) self._call.start_client_batch( - cygrpc.Operations( - (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_message(_EMPTY_FLAGS),), event_handler) self._state.due.add(cygrpc.OperationType.receive_message) elif self._state.code is grpc.StatusCode.OK: @@ -372,14 +369,13 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): with self._state.condition: while self._state.initial_metadata is None: self._state.condition.wait() - return _common.to_application_metadata(self._state.initial_metadata) + return self._state.initial_metadata def trailing_metadata(self): with self._state.condition: while self._state.trailing_metadata is None: self._state.condition.wait() - return _common.to_application_metadata( - self._state.trailing_metadata) + return self._state.trailing_metadata def code(self): with self._state.condition: @@ -420,8 +416,7 @@ def _start_unary_request(request, timeout, request_serializer): deadline, deadline_timespec = _deadline(timeout) serialized_request = _common.serialize(request, request_serializer) if serialized_request is None: - state = _RPCState((), _common.EMPTY_METADATA, _common.EMPTY_METADATA, - grpc.StatusCode.INTERNAL, + state = _RPCState((), (), (), grpc.StatusCode.INTERNAL, 'Exception serializing request!') rendezvous = _Rendezvous(state, None, None, deadline) return deadline, deadline_timespec, None, rendezvous @@ -458,8 +453,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): else: state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) operations = ( - cygrpc.operation_send_initial_metadata( - _common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS), cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS), cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), @@ -479,8 +473,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): deadline_timespec) if credentials is not None: call.set_credentials(credentials._credentials) - call_error = call.start_client_batch( - cygrpc.Operations(operations), None) + call_error = call.start_client_batch(operations, None) _check_call_error(call_error, metadata) _handle_event(completion_queue.poll(), state, self._response_deserializer) @@ -509,8 +502,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: - call_error = call.start_client_batch( - cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(operations, event_handler) if call_error != cygrpc.CallError.ok: _call_error_set_RPCstate(state, call_error, metadata) return _Rendezvous(state, None, None, deadline) @@ -544,18 +536,15 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._response_deserializer) with state.condition: call.start_client_batch( - cygrpc.Operations(( - cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), - )), event_handler) + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),), + event_handler) operations = ( cygrpc.operation_send_initial_metadata( - _common.to_cygrpc_metadata(metadata), - _EMPTY_FLAGS), cygrpc.operation_send_message( + metadata, _EMPTY_FLAGS), cygrpc.operation_send_message( serialized_request, _EMPTY_FLAGS), cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),) - call_error = call.start_client_batch( - cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(operations, event_handler) if call_error != cygrpc.CallError.ok: _call_error_set_RPCstate(state, call_error, metadata) return _Rendezvous(state, None, None, deadline) @@ -584,16 +573,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): call.set_credentials(credentials._credentials) with state.condition: call.start_client_batch( - cygrpc.Operations( - (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),), None) operations = ( - cygrpc.operation_send_initial_metadata( - _common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),) - call_error = call.start_client_batch( - cygrpc.Operations(operations), None) + call_error = call.start_client_batch(operations, None) _check_call_error(call_error, metadata) _consume_request_iterator(request_iterator, state, call, self._request_serializer) @@ -638,16 +624,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: call.start_client_batch( - cygrpc.Operations( - (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),), event_handler) operations = ( - cygrpc.operation_send_initial_metadata( - _common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),) - call_error = call.start_client_batch( - cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(operations, event_handler) if call_error != cygrpc.CallError.ok: _call_error_set_RPCstate(state, call_error, metadata) return _Rendezvous(state, None, None, deadline) @@ -681,15 +664,12 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: call.start_client_batch( - cygrpc.Operations( - (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),), event_handler) operations = ( - cygrpc.operation_send_initial_metadata( - _common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),) - call_error = call.start_client_batch( - cygrpc.Operations(operations), event_handler) + call_error = call.start_client_batch(operations, event_handler) if call_error != cygrpc.CallError.ok: _call_error_set_RPCstate(state, call_error, metadata) return _Rendezvous(state, None, None, deadline) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 740d4639db..130fc42630 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -22,8 +22,6 @@ import six import grpc from grpc._cython import cygrpc -EMPTY_METADATA = cygrpc.Metadata(()) - CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = { cygrpc.ConnectivityState.idle: grpc.ChannelConnectivity.IDLE, @@ -91,21 +89,6 @@ def channel_args(options): return cygrpc.ChannelArgs(cygrpc_args) -def to_cygrpc_metadata(application_metadata): - return EMPTY_METADATA if application_metadata is None else cygrpc.Metadata( - cygrpc.Metadatum(encode(key), encode(value)) - for key, value in application_metadata) - - -def to_application_metadata(cygrpc_metadata): - if cygrpc_metadata is None: - return () - else: - return tuple((decode(key), value - if key[-4:] == b'-bin' else decode(value)) - for key, value in cygrpc_metadata) - - def _transform(message, transformer, exception_message): if transformer is None: return message diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi index 6b3a276097..6361669757 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi @@ -26,20 +26,16 @@ cdef class Call: def _start_batch(self, operations, tag, retain_self): if not self.is_valid: raise ValueError("invalid call object cannot be used from Python") - cdef grpc_call_error result - cdef Operations cy_operations = Operations(operations) - cdef OperationTag operation_tag = OperationTag(tag) + cdef OperationTag operation_tag = OperationTag(tag, operations) if retain_self: operation_tag.operation_call = self else: operation_tag.operation_call = None - operation_tag.batch_operations = cy_operations + operation_tag.store_ops() cpython.Py_INCREF(operation_tag) - with nogil: - result = grpc_call_start_batch( - self.c_call, cy_operations.c_ops, cy_operations.c_nops, + return grpc_call_start_batch( + self.c_call, operation_tag.c_ops, operation_tag.c_nops, <cpython.PyObject *>operation_tag, NULL) - return result def start_client_batch(self, operations, tag): # We don't reference this call in the operations tag because diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index 4c397f8f64..644df674cc 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -76,7 +76,7 @@ cdef class Channel: def watch_connectivity_state( self, grpc_connectivity_state last_observed_state, Timespec deadline not None, CompletionQueue queue not None, tag): - cdef OperationTag operation_tag = OperationTag(tag) + cdef OperationTag operation_tag = OperationTag(tag, None) cpython.Py_INCREF(operation_tag) with nogil: grpc_channel_watch_connectivity_state( diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index 237f430799..140fc357b9 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -42,7 +42,7 @@ cdef class CompletionQueue: cdef Call operation_call = None cdef CallDetails request_call_details = None cdef object request_metadata = None - cdef Operations batch_operations = None + cdef object batch_operations = None if event.type == GRPC_QUEUE_TIMEOUT: return Event( event.type, False, None, None, None, None, False, None) @@ -61,9 +61,10 @@ cdef class CompletionQueue: user_tag = tag.user_tag operation_call = tag.operation_call request_call_details = tag.request_call_details - if tag.request_metadata is not None: - request_metadata = tuple(tag.request_metadata) - batch_operations = tag.batch_operations + if tag.is_new_request: + request_metadata = _metadata(&tag._c_request_metadata) + grpc_metadata_array_destroy(&tag._c_request_metadata) + batch_operations = tag.release_ops() if tag.is_new_request: # Stuff in the tag not explicitly handled by us needs to live through # the life of the call diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index 246a271893..500086f6cb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -30,9 +30,13 @@ cdef int _get_metadata( grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], size_t *num_creds_md, grpc_status_code *status, const char **error_details) with gil: - def callback(Metadata metadata, grpc_status_code status, bytes error_details): + cdef size_t metadata_count + cdef grpc_metadata *c_metadata + def callback(metadata, grpc_status_code status, bytes error_details): if status is StatusCode.ok: - cb(user_data, metadata.c_metadata, metadata.c_count, status, NULL) + _store_c_metadata(metadata, &c_metadata, &metadata_count) + cb(user_data, c_metadata, metadata_count, status, NULL) + _release_c_metadata(c_metadata, metadata_count) else: cb(user_data, NULL, 0, status, error_details) args = context.service_url, context.method_name, callback, diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi index c8f11f8e19..e3cad9acb3 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + # This function will ascii encode unicode string inputs if neccesary. # In Python3, unicode strings are the default str type. @@ -22,3 +24,25 @@ cdef bytes str_to_bytes(object s): return s.encode('ascii') else: raise TypeError('Expected bytes, str, or unicode, not {}'.format(type(s))) + + +cdef bytes _encode(str native_string_or_none): + if native_string_or_none is None: + return b'' + elif isinstance(native_string_or_none, (bytes,)): + return <bytes>native_string_or_none + elif isinstance(native_string_or_none, (unicode,)): + return native_string_or_none.encode('ascii') + else: + raise TypeError('Expected str, not {}'.format(type(native_string_or_none))) + + +cdef str _decode(bytes bytestring): + if isinstance(bytestring, (str,)): + return <str>bytestring + else: + try: + return bytestring.decode('utf8') + except UnicodeDecodeError: + logging.exception('Invalid encoding on %s', bytestring) + return bytestring.decode('latin1') diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi new file mode 100644 index 0000000000..a18c365807 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi @@ -0,0 +1,26 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +cdef void _store_c_metadata( + metadata, grpc_metadata **c_metadata, size_t *c_count) + + +cdef void _release_c_metadata(grpc_metadata *c_metadata, int count) + + +cdef tuple _metadatum(grpc_slice key_slice, grpc_slice value_slice) + + +cdef tuple _metadata(grpc_metadata_array *c_metadata_array) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi new file mode 100644 index 0000000000..c39fef08fa --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi @@ -0,0 +1,62 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + + +_Metadatum = collections.namedtuple('_Metadatum', ('key', 'value',)) + + +cdef void _store_c_metadata( + metadata, grpc_metadata **c_metadata, size_t *c_count): + if metadata is None: + c_count[0] = 0 + c_metadata[0] = NULL + else: + metadatum_count = len(metadata) + if metadatum_count == 0: + c_count[0] = 0 + c_metadata[0] = NULL + else: + c_count[0] = metadatum_count + c_metadata[0] = <grpc_metadata *>gpr_malloc( + metadatum_count * sizeof(grpc_metadata)) + for index, (key, value) in enumerate(metadata): + encoded_key = _encode(key) + encoded_value = value if encoded_key[-4:] == b'-bin' else _encode(value) + c_metadata[0][index].key = _slice_from_bytes(encoded_key) + c_metadata[0][index].value = _slice_from_bytes(encoded_value) + + +cdef void _release_c_metadata(grpc_metadata *c_metadata, int count): + if 0 < count: + for index in range(count): + grpc_slice_unref(c_metadata[index].key) + grpc_slice_unref(c_metadata[index].value) + gpr_free(c_metadata) + + +cdef tuple _metadatum(grpc_slice key_slice, grpc_slice value_slice): + cdef bytes key = _slice_bytes(key_slice) + cdef bytes value = _slice_bytes(value_slice) + return <tuple>_Metadatum( + _decode(key), value if key[-4:] == b'-bin' else _decode(value)) + + +cdef tuple _metadata(grpc_metadata_array *c_metadata_array): + return tuple( + _metadatum( + c_metadata_array.metadata[index].key, + c_metadata_array.metadata[index].value) + for index in range(c_metadata_array.count)) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi index 9c40ebf0c2..594fdb1a8b 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi @@ -37,10 +37,15 @@ cdef class OperationTag: cdef Server shutting_down_server cdef Call operation_call cdef CallDetails request_call_details - cdef MetadataArray request_metadata - cdef Operations batch_operations + cdef grpc_metadata_array _c_request_metadata + cdef grpc_op *c_ops + cdef size_t c_nops + cdef readonly object _operations cdef bint is_new_request + cdef void store_ops(self) + cdef object release_ops(self) + cdef class Event: @@ -57,7 +62,7 @@ cdef class Event: cdef readonly Call operation_call # For Call.start_batch - cdef readonly Operations batch_operations + cdef readonly object batch_operations cdef class ByteBuffer: @@ -84,28 +89,15 @@ cdef class ChannelArgs: cdef list args -cdef class Metadatum: - - cdef grpc_metadata c_metadata - cdef void _copy_metadatum(self, grpc_metadata *destination) nogil - - -cdef class Metadata: - - cdef grpc_metadata *c_metadata - cdef readonly size_t c_count - - -cdef class MetadataArray: - - cdef grpc_metadata_array c_metadata_array - - cdef class Operation: cdef grpc_op c_op + cdef bint _c_metadata_needs_release + cdef size_t _c_metadata_count + cdef grpc_metadata *_c_metadata cdef ByteBuffer _received_message - cdef MetadataArray _received_metadata + cdef bint _c_metadata_array_needs_destruction + cdef grpc_metadata_array _c_metadata_array cdef grpc_status_code _received_status_code cdef grpc_slice _status_details cdef int _received_cancelled @@ -113,13 +105,6 @@ cdef class Operation: cdef object references -cdef class Operations: - - cdef grpc_op *c_ops - cdef size_t c_nops - cdef list operations - - cdef class CompressionOptions: cdef grpc_compression_options c_options diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index 03fb226190..26eaf50eb4 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -220,9 +220,26 @@ cdef class CallDetails: cdef class OperationTag: - def __cinit__(self, user_tag): + def __cinit__(self, user_tag, operations): self.user_tag = user_tag self.references = [] + self._operations = operations + + cdef void store_ops(self): + self.c_nops = 0 if self._operations is None else len(self._operations) + if 0 < self.c_nops: + self.c_ops = <grpc_op *>gpr_malloc(sizeof(grpc_op) * self.c_nops) + for index in range(self.c_nops): + self.c_ops[index] = (<Operation>(self._operations[index])).c_op + + cdef object release_ops(self): + if 0 < self.c_nops: + for index, operation in enumerate(self._operations): + (<Operation>operation).c_op = self.c_ops[index] + gpr_free(self.c_ops) + return self._operations + else: + return () cdef class Event: @@ -232,7 +249,7 @@ cdef class Event: CallDetails request_call_details, object request_metadata, bint is_new_request, - Operations batch_operations): + object batch_operations): self.type = type self.success = success self.tag = tag @@ -390,140 +407,13 @@ cdef class ChannelArgs: return self.args[i] -cdef class Metadatum: - - def __cinit__(self, bytes key, bytes value): - self.c_metadata.key = _slice_from_bytes(key) - self.c_metadata.value = _slice_from_bytes(value) - - cdef void _copy_metadatum(self, grpc_metadata *destination) nogil: - destination[0].key = _copy_slice(self.c_metadata.key) - destination[0].value = _copy_slice(self.c_metadata.value) - - @property - def key(self): - return _slice_bytes(self.c_metadata.key) - - @property - def value(self): - return _slice_bytes(self.c_metadata.value) - - def __len__(self): - return 2 - - def __getitem__(self, size_t i): - if i == 0: - return self.key - elif i == 1: - return self.value - else: - raise IndexError("index must be 0 (key) or 1 (value)") - - def __iter__(self): - return iter((self.key, self.value)) - - def __dealloc__(self): - grpc_slice_unref(self.c_metadata.key) - grpc_slice_unref(self.c_metadata.value) - -cdef class _MetadataIterator: - - cdef size_t i - cdef size_t _length - cdef object _metadatum_indexable - - def __cinit__(self, length, metadatum_indexable): - self._length = length - self._metadatum_indexable = metadatum_indexable - self.i = 0 - - def __iter__(self): - return self - - def __next__(self): - if self.i < self._length: - result = self._metadatum_indexable[self.i] - self.i = self.i + 1 - return result - else: - raise StopIteration() - - -# TODO(https://github.com/grpc/grpc/issues/7950): Eliminate this; just use an -# ordinary sequence of pairs of bytestrings all the way down to the -# grpc_call_start_batch call. -cdef class Metadata: - """Metadata being passed from application to core.""" - - def __cinit__(self, metadata_iterable): - metadata_sequence = tuple(metadata_iterable) - cdef size_t count = len(metadata_sequence) - with nogil: - grpc_init() - self.c_metadata = <grpc_metadata *>gpr_malloc( - count * sizeof(grpc_metadata)) - self.c_count = count - for index, metadatum in enumerate(metadata_sequence): - self.c_metadata[index].key = grpc_slice_copy( - (<Metadatum>metadatum).c_metadata.key) - self.c_metadata[index].value = grpc_slice_copy( - (<Metadatum>metadatum).c_metadata.value) - - def __dealloc__(self): - with nogil: - for index in range(self.c_count): - grpc_slice_unref(self.c_metadata[index].key) - grpc_slice_unref(self.c_metadata[index].value) - gpr_free(self.c_metadata) - grpc_shutdown() - - def __len__(self): - return self.c_count - - def __getitem__(self, size_t index): - if index < self.c_count: - key = _slice_bytes(self.c_metadata[index].key) - value = _slice_bytes(self.c_metadata[index].value) - return Metadatum(key, value) - else: - raise IndexError() - - def __iter__(self): - return _MetadataIterator(self.c_count, self) - - -cdef class MetadataArray: - """Metadata being passed from core to application.""" - - def __cinit__(self): - with nogil: - grpc_init() - grpc_metadata_array_init(&self.c_metadata_array) - - def __dealloc__(self): - with nogil: - grpc_metadata_array_destroy(&self.c_metadata_array) - grpc_shutdown() - - def __len__(self): - return self.c_metadata_array.count - - def __getitem__(self, size_t i): - if i >= self.c_metadata_array.count: - raise IndexError() - key = _slice_bytes(self.c_metadata_array.metadata[i].key) - value = _slice_bytes(self.c_metadata_array.metadata[i].value) - return Metadatum(key=key, value=value) - - def __iter__(self): - return _MetadataIterator(self.c_metadata_array.count, self) - - cdef class Operation: def __cinit__(self): grpc_init() self.references = [] + self._c_metadata_needs_release = False + self._c_metadata_array_needs_destruction = False self._status_details = grpc_empty_slice() self.is_valid = False @@ -556,13 +446,7 @@ cdef class Operation: if (self.c_op.type != GRPC_OP_RECV_INITIAL_METADATA and self.c_op.type != GRPC_OP_RECV_STATUS_ON_CLIENT): raise TypeError("self must be an operation receiving metadata") - # TODO(https://github.com/grpc/grpc/issues/7950): Drop the "all Cython - # objects must be legitimate for use from Python at any time" policy in - # place today, shift the policy toward "Operation objects are only usable - # while their calls are active", and move this making-a-copy-because-this- - # data-needs-to-live-much-longer-than-the-call-from-which-it-arose to the - # lowest Python layer. - return tuple(self._received_metadata) + return _metadata(&self._c_metadata_array) @property def received_status_code(self): @@ -602,16 +486,21 @@ cdef class Operation: return False if self._received_cancelled == 0 else True def __dealloc__(self): + if self._c_metadata_needs_release: + _release_c_metadata(self._c_metadata, self._c_metadata_count) + if self._c_metadata_array_needs_destruction: + grpc_metadata_array_destroy(&self._c_metadata_array) grpc_slice_unref(self._status_details) grpc_shutdown() -def operation_send_initial_metadata(Metadata metadata, int flags): +def operation_send_initial_metadata(metadata, int flags): cdef Operation op = Operation() op.c_op.type = GRPC_OP_SEND_INITIAL_METADATA op.c_op.flags = flags - op.c_op.data.send_initial_metadata.count = metadata.c_count - op.c_op.data.send_initial_metadata.metadata = metadata.c_metadata - op.references.append(metadata) + _store_c_metadata(metadata, &op._c_metadata, &op._c_metadata_count) + op._c_metadata_needs_release = True + op.c_op.data.send_initial_metadata.count = op._c_metadata_count + op.c_op.data.send_initial_metadata.metadata = op._c_metadata op.is_valid = True return op @@ -633,18 +522,19 @@ def operation_send_close_from_client(int flags): return op def operation_send_status_from_server( - Metadata metadata, grpc_status_code code, bytes details, int flags): + metadata, grpc_status_code code, bytes details, int flags): cdef Operation op = Operation() op.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER op.c_op.flags = flags + _store_c_metadata(metadata, &op._c_metadata, &op._c_metadata_count) + op._c_metadata_needs_release = True op.c_op.data.send_status_from_server.trailing_metadata_count = ( - metadata.c_count) - op.c_op.data.send_status_from_server.trailing_metadata = metadata.c_metadata + op._c_metadata_count) + op.c_op.data.send_status_from_server.trailing_metadata = op._c_metadata op.c_op.data.send_status_from_server.status = code grpc_slice_unref(op._status_details) op._status_details = _slice_from_bytes(details) op.c_op.data.send_status_from_server.status_details = &op._status_details - op.references.append(metadata) op.is_valid = True return op @@ -652,9 +542,10 @@ def operation_receive_initial_metadata(int flags): cdef Operation op = Operation() op.c_op.type = GRPC_OP_RECV_INITIAL_METADATA op.c_op.flags = flags - op._received_metadata = MetadataArray() + grpc_metadata_array_init(&op._c_metadata_array) op.c_op.data.receive_initial_metadata.receive_initial_metadata = ( - &op._received_metadata.c_metadata_array) + &op._c_metadata_array) + op._c_metadata_array_needs_destruction = True op.is_valid = True return op @@ -675,9 +566,10 @@ def operation_receive_status_on_client(int flags): cdef Operation op = Operation() op.c_op.type = GRPC_OP_RECV_STATUS_ON_CLIENT op.c_op.flags = flags - op._received_metadata = MetadataArray() + grpc_metadata_array_init(&op._c_metadata_array) op.c_op.data.receive_status_on_client.trailing_metadata = ( - &op._received_metadata.c_metadata_array) + &op._c_metadata_array) + op._c_metadata_array_needs_destruction = True op.c_op.data.receive_status_on_client.status = ( &op._received_status_code) op.c_op.data.receive_status_on_client.status_details = ( @@ -694,59 +586,6 @@ def operation_receive_close_on_server(int flags): return op -cdef class _OperationsIterator: - - cdef size_t i - cdef Operations operations - - def __cinit__(self, Operations operations not None): - self.i = 0 - self.operations = operations - - def __iter__(self): - return self - - def __next__(self): - if self.i < len(self.operations): - result = self.operations[self.i] - self.i = self.i + 1 - return result - else: - raise StopIteration() - - -cdef class Operations: - - def __cinit__(self, operations): - grpc_init() - self.operations = list(operations) # normalize iterable - self.c_ops = NULL - self.c_nops = 0 - for operation in self.operations: - if not isinstance(operation, Operation): - raise TypeError("expected operations to be iterable of Operation") - self.c_nops = len(self.operations) - with nogil: - self.c_ops = <grpc_op *>gpr_malloc(sizeof(grpc_op)*self.c_nops) - for i in range(self.c_nops): - self.c_ops[i] = (<Operation>(self.operations[i])).c_op - - def __len__(self): - return self.c_nops - - def __getitem__(self, size_t i): - # self.operations is never stale; it's only updated from this file - return self.operations[i] - - def __dealloc__(self): - with nogil: - gpr_free(self.c_ops) - grpc_shutdown() - - def __iter__(self): - return _OperationsIterator(self) - - cdef class CompressionOptions: def __cinit__(self): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index 5f3405936c..f8d7892858 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -78,23 +78,19 @@ cdef class Server: raise ValueError("server must be started and not shutting down") if server_queue not in self.registered_completion_queues: raise ValueError("server_queue must be a registered completion queue") - cdef grpc_call_error result - cdef OperationTag operation_tag = OperationTag(tag) + cdef OperationTag operation_tag = OperationTag(tag, None) operation_tag.operation_call = Call() operation_tag.request_call_details = CallDetails() - operation_tag.request_metadata = MetadataArray() + grpc_metadata_array_init(&operation_tag._c_request_metadata) operation_tag.references.extend([self, call_queue, server_queue]) operation_tag.is_new_request = True - operation_tag.batch_operations = Operations([]) cpython.Py_INCREF(operation_tag) - with nogil: - result = grpc_server_request_call( - self.c_server, &operation_tag.operation_call.c_call, - &operation_tag.request_call_details.c_details, - &operation_tag.request_metadata.c_metadata_array, - call_queue.c_completion_queue, server_queue.c_completion_queue, - <cpython.PyObject *>operation_tag) - return result + return grpc_server_request_call( + self.c_server, &operation_tag.operation_call.c_call, + &operation_tag.request_call_details.c_details, + &operation_tag._c_request_metadata, + call_queue.c_completion_queue, server_queue.c_completion_queue, + <cpython.PyObject *>operation_tag) def register_completion_queue( self, CompletionQueue queue not None): @@ -135,7 +131,7 @@ cdef class Server: cdef _c_shutdown(self, CompletionQueue queue, tag): self.is_shutting_down = True - operation_tag = OperationTag(tag) + operation_tag = OperationTag(tag, None) operation_tag.shutting_down_server = self cpython.Py_INCREF(operation_tag) with nogil: diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd index fc6cc5fb9f..6fc5638d5d 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pxd +++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd @@ -18,6 +18,7 @@ include "_cygrpc/call.pxd.pxi" include "_cygrpc/channel.pxd.pxi" include "_cygrpc/credentials.pxd.pxi" include "_cygrpc/completion_queue.pxd.pxi" +include "_cygrpc/metadata.pxd.pxi" include "_cygrpc/records.pxd.pxi" include "_cygrpc/security.pxd.pxi" include "_cygrpc/server.pxd.pxi" diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index 57165d5f5a..d605229822 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -25,6 +25,7 @@ include "_cygrpc/call.pyx.pxi" include "_cygrpc/channel.pyx.pxi" include "_cygrpc/credentials.pyx.pxi" include "_cygrpc/completion_queue.pyx.pxi" +include "_cygrpc/metadata.pyx.pxi" include "_cygrpc/records.pyx.pxi" include "_cygrpc/security.pyx.pxi" include "_cygrpc/server.pyx.pxi" diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py new file mode 100644 index 0000000000..fffb269845 --- /dev/null +++ b/src/python/grpcio/grpc/_interceptor.py @@ -0,0 +1,318 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of gRPC Python interceptors.""" + +import collections +import sys + +import grpc + + +class _ServicePipeline(object): + + def __init__(self, interceptors): + self.interceptors = tuple(interceptors) + + def _continuation(self, thunk, index): + return lambda context: self._intercept_at(thunk, index, context) + + def _intercept_at(self, thunk, index, context): + if index < len(self.interceptors): + interceptor = self.interceptors[index] + thunk = self._continuation(thunk, index + 1) + return interceptor.intercept_service(thunk, context) + else: + return thunk(context) + + def execute(self, thunk, context): + return self._intercept_at(thunk, 0, context) + + +def service_pipeline(interceptors): + return _ServicePipeline(interceptors) if interceptors else None + + +class _ClientCallDetails( + collections.namedtuple('_ClientCallDetails', + ('method', 'timeout', 'metadata', + 'credentials')), grpc.ClientCallDetails): + pass + + +class _LocalFailure(grpc.RpcError, grpc.Future, grpc.Call): + + def __init__(self, exception, traceback): + super(_LocalFailure, self).__init__() + self._exception = exception + self._traceback = traceback + + def initial_metadata(self): + return None + + def trailing_metadata(self): + return None + + def code(self): + return grpc.StatusCode.INTERNAL + + def details(self): + return 'Exception raised while intercepting the RPC' + + def cancel(self): + return False + + def cancelled(self): + return False + + def running(self): + return False + + def done(self): + return True + + def result(self, ignored_timeout=None): + raise self._exception + + def exception(self, ignored_timeout=None): + return self._exception + + def traceback(self, ignored_timeout=None): + return self._traceback + + def add_done_callback(self, fn): + fn(self) + + def __iter__(self): + return self + + def next(self): + raise self._exception + + +class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + + def __init__(self, thunk, method, interceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, request, timeout=None, metadata=None, credentials=None): + call_future = self.future( + request, + timeout=timeout, + metadata=metadata, + credentials=credentials) + return call_future.result() + + def with_call(self, request, timeout=None, metadata=None, credentials=None): + call_future = self.future( + request, + timeout=timeout, + metadata=metadata, + credentials=credentials) + return call_future.result(), call_future + + def future(self, request, timeout=None, metadata=None, credentials=None): + + def continuation(client_call_details, request): + return self._thunk(client_call_details.method).future( + request, + timeout=client_call_details.timeout, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials) + try: + return self._interceptor.intercept_unary_unary( + continuation, client_call_details, request) + except Exception as exception: # pylint:disable=broad-except + return _LocalFailure(exception, sys.exc_info()[2]) + + +class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + + def __init__(self, thunk, method, interceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, request, timeout=None, metadata=None, credentials=None): + + def continuation(client_call_details, request): + return self._thunk(client_call_details.method)( + request, + timeout=client_call_details.timeout, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials) + try: + return self._interceptor.intercept_unary_stream( + continuation, client_call_details, request) + except Exception as exception: # pylint:disable=broad-except + return _LocalFailure(exception, sys.exc_info()[2]) + + +class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): + + def __init__(self, thunk, method, interceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, + request_iterator, + timeout=None, + metadata=None, + credentials=None): + call_future = self.future( + request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials) + return call_future.result() + + def with_call(self, + request_iterator, + timeout=None, + metadata=None, + credentials=None): + call_future = self.future( + request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials) + return call_future.result(), call_future + + def future(self, + request_iterator, + timeout=None, + metadata=None, + credentials=None): + + def continuation(client_call_details, request_iterator): + return self._thunk(client_call_details.method).future( + request_iterator, + timeout=client_call_details.timeout, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials) + + try: + return self._interceptor.intercept_stream_unary( + continuation, client_call_details, request_iterator) + except Exception as exception: # pylint:disable=broad-except + return _LocalFailure(exception, sys.exc_info()[2]) + + +class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): + + def __init__(self, thunk, method, interceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, + request_iterator, + timeout=None, + metadata=None, + credentials=None): + + def continuation(client_call_details, request_iterator): + return self._thunk(client_call_details.method)( + request_iterator, + timeout=client_call_details.timeout, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials) + + try: + return self._interceptor.intercept_stream_stream( + continuation, client_call_details, request_iterator) + except Exception as exception: # pylint:disable=broad-except + return _LocalFailure(exception, sys.exc_info()[2]) + + +class _Channel(grpc.Channel): + + def __init__(self, channel, interceptor): + self._channel = channel + self._interceptor = interceptor + + def subscribe(self, *args, **kwargs): + self._channel.subscribe(*args, **kwargs) + + def unsubscribe(self, *args, **kwargs): + self._channel.unsubscribe(*args, **kwargs) + + def unary_unary(self, + method, + request_serializer=None, + response_deserializer=None): + thunk = lambda m: self._channel.unary_unary(m, request_serializer, response_deserializer) + if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): + return _UnaryUnaryMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + def unary_stream(self, + method, + request_serializer=None, + response_deserializer=None): + thunk = lambda m: self._channel.unary_stream(m, request_serializer, response_deserializer) + if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): + return _UnaryStreamMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + def stream_unary(self, + method, + request_serializer=None, + response_deserializer=None): + thunk = lambda m: self._channel.stream_unary(m, request_serializer, response_deserializer) + if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): + return _StreamUnaryMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + def stream_stream(self, + method, + request_serializer=None, + response_deserializer=None): + thunk = lambda m: self._channel.stream_stream(m, request_serializer, response_deserializer) + if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): + return _StreamStreamMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + +def intercept_channel(channel, *interceptors): + for interceptor in reversed(list(interceptors)): + if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \ + not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \ + not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) and \ + not isinstance(interceptor, grpc.StreamStreamClientInterceptor): + raise TypeError('interceptor must be ' + 'grpc.UnaryUnaryClientInterceptor or ' + 'grpc.UnaryStreamClientInterceptor or ' + 'grpc.StreamUnaryClientInterceptor or ' + 'grpc.StreamStreamClientInterceptor or ') + channel = _Channel(channel, interceptor) + return channel diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py index cd17f4a049..f7287956dc 100644 --- a/src/python/grpcio/grpc/_plugin_wrapping.py +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -54,9 +54,7 @@ class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback): 'AuthMetadataPluginCallback raised exception "{}"!'.format( self._state.exception)) if error is None: - self._callback( - _common.to_cygrpc_metadata(metadata), cygrpc.StatusCode.ok, - None) + self._callback(metadata, cygrpc.StatusCode.ok, None) else: self._callback(None, cygrpc.StatusCode.internal, _common.encode(str(error))) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 5b4812bffe..8857bd3eb5 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -23,6 +23,7 @@ import six import grpc from grpc import _common +from grpc import _interceptor from grpc._cython import cygrpc from grpc.framework.foundation import callable_util @@ -96,6 +97,7 @@ class _RPCState(object): self.statused = False self.rpc_errors = [] self.callbacks = [] + self.abortion = None def _raise_rpc_error(state): @@ -129,19 +131,17 @@ def _abort(state, call, code, details): effective_details = details if state.details is None else state.details if state.initial_metadata_allowed: operations = (cygrpc.operation_send_initial_metadata( - _common.EMPTY_METADATA, - _EMPTY_FLAGS), cygrpc.operation_send_status_from_server( - _common.to_cygrpc_metadata(state.trailing_metadata), - effective_code, effective_details, _EMPTY_FLAGS),) + (), _EMPTY_FLAGS), cygrpc.operation_send_status_from_server( + state.trailing_metadata, effective_code, effective_details, + _EMPTY_FLAGS),) token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN else: operations = (cygrpc.operation_send_status_from_server( - _common.to_cygrpc_metadata(state.trailing_metadata), - effective_code, effective_details, _EMPTY_FLAGS),) + state.trailing_metadata, effective_code, effective_details, + _EMPTY_FLAGS),) token = _SEND_STATUS_FROM_SERVER_TOKEN - call.start_server_batch( - cygrpc.Operations(operations), - _send_status_from_server(state, token)) + call.start_server_batch(operations, + _send_status_from_server(state, token)) state.statused = True state.due.add(token) @@ -237,7 +237,7 @@ class _Context(grpc.ServicerContext): self._state.disable_next_compression = True def invocation_metadata(self): - return _common.to_application_metadata(self._rpc_event.request_metadata) + return self._rpc_event.request_metadata def peer(self): return _common.decode(self._rpc_event.operation_call.peer()) @@ -263,11 +263,9 @@ class _Context(grpc.ServicerContext): else: if self._state.initial_metadata_allowed: operation = cygrpc.operation_send_initial_metadata( - _common.to_cygrpc_metadata(initial_metadata), - _EMPTY_FLAGS) + initial_metadata, _EMPTY_FLAGS) self._rpc_event.operation_call.start_server_batch( - cygrpc.Operations((operation,)), - _send_initial_metadata(self._state)) + (operation,), _send_initial_metadata(self._state)) self._state.initial_metadata_allowed = False self._state.due.add(_SEND_INITIAL_METADATA_TOKEN) else: @@ -275,8 +273,14 @@ class _Context(grpc.ServicerContext): def set_trailing_metadata(self, trailing_metadata): with self._state.condition: - self._state.trailing_metadata = _common.to_cygrpc_metadata( - trailing_metadata) + self._state.trailing_metadata = trailing_metadata + + def abort(self, code, details): + with self._state.condition: + self._state.code = code + self._state.details = _common.encode(details) + self._state.abortion = Exception() + raise self._state.abortion def set_code(self, code): with self._state.condition: @@ -301,8 +305,7 @@ class _RequestIterator(object): raise StopIteration() else: self._call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_message(_EMPTY_FLAGS),), _receive_message(self._state, self._call, self._request_deserializer)) self._state.due.add(_RECEIVE_MESSAGE_TOKEN) @@ -345,8 +348,7 @@ def _unary_request(rpc_event, state, request_deserializer): return None else: rpc_event.operation_call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_message(_EMPTY_FLAGS),), _receive_message(state, rpc_event.operation_call, request_deserializer)) state.due.add(_RECEIVE_MESSAGE_TOKEN) @@ -376,7 +378,10 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): return behavior(argument, context), True except Exception as exception: # pylint: disable=broad-except with state.condition: - if exception not in state.rpc_errors: + if exception is state.abortion: + _abort(state, rpc_event.operation_call, + cygrpc.StatusCode.unknown, b'RPC Aborted') + elif exception not in state.rpc_errors: details = 'Exception calling application: {}'.format(exception) logging.exception(details) _abort(state, rpc_event.operation_call, @@ -391,7 +396,10 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator): return None, True except Exception as exception: # pylint: disable=broad-except with state.condition: - if exception not in state.rpc_errors: + if exception is state.abortion: + _abort(state, rpc_event.operation_call, + cygrpc.StatusCode.unknown, b'RPC Aborted') + elif exception not in state.rpc_errors: details = 'Exception iterating responses: {}'.format(exception) logging.exception(details) _abort(state, rpc_event.operation_call, @@ -417,9 +425,8 @@ def _send_response(rpc_event, state, serialized_response): else: if state.initial_metadata_allowed: operations = (cygrpc.operation_send_initial_metadata( - _common.EMPTY_METADATA, _EMPTY_FLAGS), - cygrpc.operation_send_message(serialized_response, - _EMPTY_FLAGS),) + (), _EMPTY_FLAGS), cygrpc.operation_send_message( + serialized_response, _EMPTY_FLAGS),) state.initial_metadata_allowed = False token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN else: @@ -427,7 +434,7 @@ def _send_response(rpc_event, state, serialized_response): _EMPTY_FLAGS),) token = _SEND_MESSAGE_TOKEN rpc_event.operation_call.start_server_batch( - cygrpc.Operations(operations), _send_message(state, token)) + operations, _send_message(state, token)) state.due.add(token) while True: state.condition.wait() @@ -438,24 +445,21 @@ def _send_response(rpc_event, state, serialized_response): def _status(rpc_event, state, serialized_response): with state.condition: if state.client is not _CANCELLED: - trailing_metadata = _common.to_cygrpc_metadata( - state.trailing_metadata) code = _completion_code(state) details = _details(state) operations = [ cygrpc.operation_send_status_from_server( - trailing_metadata, code, details, _EMPTY_FLAGS), + state.trailing_metadata, code, details, _EMPTY_FLAGS), ] if state.initial_metadata_allowed: operations.append( - cygrpc.operation_send_initial_metadata( - _common.EMPTY_METADATA, _EMPTY_FLAGS)) + cygrpc.operation_send_initial_metadata((), _EMPTY_FLAGS)) if serialized_response is not None: operations.append( cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS)) rpc_event.operation_call.start_server_batch( - cygrpc.Operations(operations), + operations, _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) state.statused = True state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) @@ -538,24 +542,31 @@ def _handle_stream_stream(rpc_event, state, method_handler, thread_pool): method_handler.request_deserializer, method_handler.response_serializer) -def _find_method_handler(rpc_event, generic_handlers): - for generic_handler in generic_handlers: - method_handler = generic_handler.service( - _HandlerCallDetails( - _common.decode(rpc_event.request_call_details.method), - rpc_event.request_metadata)) - if method_handler is not None: - return method_handler - else: +def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline): + + def query_handlers(handler_call_details): + for generic_handler in generic_handlers: + method_handler = generic_handler.service(handler_call_details) + if method_handler is not None: + return method_handler return None + handler_call_details = _HandlerCallDetails( + _common.decode(rpc_event.request_call_details.method), + rpc_event.request_metadata) + + if interceptor_pipeline is not None: + return interceptor_pipeline.execute(query_handlers, + handler_call_details) + else: + return query_handlers(handler_call_details) + def _reject_rpc(rpc_event, status, details): - operations = (cygrpc.operation_send_initial_metadata(_common.EMPTY_METADATA, - _EMPTY_FLAGS), + operations = (cygrpc.operation_send_initial_metadata((), _EMPTY_FLAGS), cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), - cygrpc.operation_send_status_from_server( - _common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),) + cygrpc.operation_send_status_from_server((), status, details, + _EMPTY_FLAGS),) rpc_state = _RPCState() rpc_event.operation_call.start_server_batch( operations, lambda ignored_event: (rpc_state, (),)) @@ -566,8 +577,7 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool): state = _RPCState() with state.condition: rpc_event.operation_call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),), _receive_close_on_server(state)) state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) if method_handler.request_streaming: @@ -586,13 +596,14 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool): method_handler, thread_pool) -def _handle_call(rpc_event, generic_handlers, thread_pool, +def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool, concurrency_exceeded): if not rpc_event.success: return None, None if rpc_event.request_call_details.method is not None: try: - method_handler = _find_method_handler(rpc_event, generic_handlers) + method_handler = _find_method_handler(rpc_event, generic_handlers, + interceptor_pipeline) except Exception as exception: # pylint: disable=broad-except details = 'Exception servicing handler: {}'.format(exception) logging.exception(details) @@ -620,12 +631,14 @@ class _ServerStage(enum.Enum): class _ServerState(object): - def __init__(self, completion_queue, server, generic_handlers, thread_pool, - maximum_concurrent_rpcs): + # pylint: disable=too-many-arguments + def __init__(self, completion_queue, server, generic_handlers, + interceptor_pipeline, thread_pool, maximum_concurrent_rpcs): self.lock = threading.Lock() self.completion_queue = completion_queue self.server = server self.generic_handlers = list(generic_handlers) + self.interceptor_pipeline = interceptor_pipeline self.thread_pool = thread_pool self.stage = _ServerStage.STOPPED self.shutdown_events = None @@ -690,8 +703,8 @@ def _serve(state): state.maximum_concurrent_rpcs is not None and state.active_rpc_count >= state.maximum_concurrent_rpcs) rpc_state, rpc_future = _handle_call( - event, state.generic_handlers, state.thread_pool, - concurrency_exceeded) + event, state.generic_handlers, state.interceptor_pipeline, + state.thread_pool, concurrency_exceeded) if rpc_state is not None: state.rpc_states.add(rpc_state) if rpc_future is not None: @@ -779,12 +792,14 @@ def _start(state): class Server(grpc.Server): - def __init__(self, thread_pool, generic_handlers, options, + # pylint: disable=too-many-arguments + def __init__(self, thread_pool, generic_handlers, interceptors, options, maximum_concurrent_rpcs): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(_common.channel_args(options)) server.register_completion_queue(completion_queue) self._state = _ServerState(completion_queue, server, generic_handlers, + _interceptor.service_pipeline(interceptors), thread_pool, maximum_concurrent_rpcs) def add_generic_rpc_handlers(self, generic_rpc_handlers): diff --git a/src/python/grpcio/grpc/beta/_client_adaptations.py b/src/python/grpcio/grpc/beta/_client_adaptations.py index 73ce22fa98..dcaa0eeaa2 100644 --- a/src/python/grpcio/grpc/beta/_client_adaptations.py +++ b/src/python/grpcio/grpc/beta/_client_adaptations.py @@ -15,6 +15,7 @@ import grpc from grpc import _common +from grpc.beta import _metadata from grpc.beta import interfaces from grpc.framework.common import cardinality from grpc.framework.foundation import future @@ -157,10 +158,10 @@ class _Rendezvous(future.Future, face.Call): return _InvocationProtocolContext() def initial_metadata(self): - return self._call.initial_metadata() + return _metadata.beta(self._call.initial_metadata()) def terminal_metadata(self): - return self._call.terminal_metadata() + return _metadata.beta(self._call.terminal_metadata()) def code(self): return self._call.code() @@ -182,14 +183,14 @@ def _blocking_unary_unary(channel, group, method, timeout, with_call, response, call = multi_callable.with_call( request, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) return response, _Rendezvous(None, None, call) else: return multi_callable( request, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) except grpc.RpcError as rpc_error_call: raise _abortion_error(rpc_error_call) @@ -206,7 +207,7 @@ def _future_unary_unary(channel, group, method, timeout, protocol_options, response_future = multi_callable.future( request, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) return _Rendezvous(response_future, None, response_future) @@ -222,7 +223,7 @@ def _unary_stream(channel, group, method, timeout, protocol_options, metadata, response_iterator = multi_callable( request, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) return _Rendezvous(None, response_iterator, response_iterator) @@ -241,14 +242,14 @@ def _blocking_stream_unary(channel, group, method, timeout, with_call, response, call = multi_callable.with_call( request_iterator, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) return response, _Rendezvous(None, None, call) else: return multi_callable( request_iterator, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) except grpc.RpcError as rpc_error_call: raise _abortion_error(rpc_error_call) @@ -265,7 +266,7 @@ def _future_stream_unary(channel, group, method, timeout, protocol_options, response_future = multi_callable.future( request_iterator, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) return _Rendezvous(response_future, None, response_future) @@ -281,7 +282,7 @@ def _stream_stream(channel, group, method, timeout, protocol_options, metadata, response_iterator = multi_callable( request_iterator, timeout=timeout, - metadata=effective_metadata, + metadata=_metadata.unbeta(effective_metadata), credentials=_credentials(protocol_options)) return _Rendezvous(None, response_iterator, response_iterator) diff --git a/src/python/grpcio/grpc/beta/_metadata.py b/src/python/grpcio/grpc/beta/_metadata.py new file mode 100644 index 0000000000..e135f4dff4 --- /dev/null +++ b/src/python/grpcio/grpc/beta/_metadata.py @@ -0,0 +1,49 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""API metadata conversion utilities.""" + +import collections + +_Metadatum = collections.namedtuple('_Metadatum', ('key', 'value',)) + + +def _beta_metadatum(key, value): + beta_key = key if isinstance(key, (bytes,)) else key.encode('ascii') + beta_value = value if isinstance(value, (bytes,)) else value.encode('ascii') + return _Metadatum(beta_key, beta_value) + + +def _metadatum(beta_key, beta_value): + key = beta_key if isinstance(beta_key, (str,)) else beta_key.decode('utf8') + if isinstance(beta_value, (str,)) or key[-4:] == '-bin': + value = beta_value + else: + value = beta_value.decode('utf8') + return _Metadatum(key, value) + + +def beta(metadata): + if metadata is None: + return () + else: + return tuple(_beta_metadatum(key, value) for key, value in metadata) + + +def unbeta(beta_metadata): + if beta_metadata is None: + return () + else: + return tuple( + _metadatum(beta_key, beta_value) + for beta_key, beta_value in beta_metadata) diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py index ec363e9bc9..1c22dbe3bb 100644 --- a/src/python/grpcio/grpc/beta/_server_adaptations.py +++ b/src/python/grpcio/grpc/beta/_server_adaptations.py @@ -18,6 +18,7 @@ import threading import grpc from grpc import _common +from grpc.beta import _metadata from grpc.beta import interfaces from grpc.framework.common import cardinality from grpc.framework.common import style @@ -65,14 +66,15 @@ class _FaceServicerContext(face.ServicerContext): return _ServerProtocolContext(self._servicer_context) def invocation_metadata(self): - return _common.to_cygrpc_metadata( - self._servicer_context.invocation_metadata()) + return _metadata.beta(self._servicer_context.invocation_metadata()) def initial_metadata(self, initial_metadata): - self._servicer_context.send_initial_metadata(initial_metadata) + self._servicer_context.send_initial_metadata( + _metadata.unbeta(initial_metadata)) def terminal_metadata(self, terminal_metadata): - self._servicer_context.set_terminal_metadata(terminal_metadata) + self._servicer_context.set_terminal_metadata( + _metadata.unbeta(terminal_metadata)) def code(self, code): self._servicer_context.set_code(code) diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py index e52ce764b5..312daf033e 100644 --- a/src/python/grpcio/grpc/beta/implementations.py +++ b/src/python/grpcio/grpc/beta/implementations.py @@ -21,6 +21,7 @@ import threading # pylint: disable=unused-import import grpc from grpc import _auth from grpc.beta import _client_adaptations +from grpc.beta import _metadata from grpc.beta import _server_adaptations from grpc.beta import interfaces # pylint: disable=unused-import from grpc.framework.common import cardinality # pylint: disable=unused-import @@ -31,7 +32,18 @@ from grpc.framework.interfaces.face import face # pylint: disable=unused-import ChannelCredentials = grpc.ChannelCredentials ssl_channel_credentials = grpc.ssl_channel_credentials CallCredentials = grpc.CallCredentials -metadata_call_credentials = grpc.metadata_call_credentials + + +def metadata_call_credentials(metadata_plugin, name=None): + + def plugin(context, callback): + + def wrapped_callback(beta_metadata, error): + callback(_metadata.unbeta(beta_metadata), error) + + metadata_plugin(context, wrapped_callback) + + return grpc.metadata_call_credentials(plugin, name=name) def google_call_credentials(credentials): diff --git a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py index 496689ded0..90eeb130d3 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py @@ -67,6 +67,9 @@ class ServicerContext(grpc.ServicerContext): self._rpc.set_trailing_metadata( _common.fuss_with_metadata(trailing_metadata)) + def abort(self, code, details): + raise NotImplementedError() + def set_code(self, code): self._rpc.set_code(code) diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 34cbade92c..3bf5308749 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -39,6 +39,7 @@ "unit._cython.cygrpc_test.TypeSmokeTest", "unit._empty_message_test.EmptyMessageTest", "unit._exit_test.ExitTest", + "unit._interceptor_test.InterceptorTest", "unit._invalid_metadata_test.InvalidMetadataTest", "unit._invocation_defects_test.InvocationDefectsTest", "unit._metadata_code_details_test.MetadataCodeDetailsTest", diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index b14e8d5c75..d6f4447532 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -33,18 +33,21 @@ class AllTest(unittest.TestCase): 'AuthMetadataPlugin', 'ServerCertificateConfiguration', 'ServerCredentials', 'UnaryUnaryMultiCallable', 'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable', - 'StreamStreamMultiCallable', 'Channel', 'ServicerContext', + 'StreamStreamMultiCallable', 'UnaryUnaryClientInterceptor', + 'UnaryStreamClientInterceptor', 'StreamUnaryClientInterceptor', + 'StreamStreamClientInterceptor', 'Channel', 'ServicerContext', 'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler', - 'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler', - 'unary_stream_rpc_method_handler', - 'stream_unary_rpc_method_handler', + 'ServiceRpcHandler', 'Server', 'ServerInterceptor', + 'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler', + 'stream_unary_rpc_method_handler', 'ClientCallDetails', 'stream_stream_rpc_method_handler', 'method_handlers_generic_handler', 'ssl_channel_credentials', 'metadata_call_credentials', 'access_token_call_credentials', 'composite_call_credentials', 'composite_channel_credentials', 'ssl_server_credentials', 'ssl_server_certificate_configuration', 'dynamic_ssl_server_credentials', 'channel_ready_future', - 'insecure_channel', 'secure_channel', 'server',) + 'insecure_channel', 'secure_channel', 'intercept_channel', + 'server',) six.assertCountEqual(self, expected_grpc_code_elements, _from_grpc_import_star.GRPC_ELEMENTS) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 5b97b7b542..a8a7175cc7 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -22,7 +22,7 @@ from tests.unit.framework.common import test_constants _INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) _EMPTY_FLAGS = 0 -_EMPTY_METADATA = cygrpc.Metadata(()) +_EMPTY_METADATA = () _SERVER_SHUTDOWN_TAG = 'server_shutdown' _REQUEST_CALL_TAG = 'request_call' @@ -65,12 +65,10 @@ class _Handler(object): with self._lock: self._call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),), _RECEIVE_CLOSE_ON_SERVER_TAG) self._call.start_server_batch( - cygrpc.Operations( - (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + (cygrpc.operation_receive_message(_EMPTY_FLAGS),), _RECEIVE_MESSAGE_TAG) first_event = self._completion_queue.poll() if _is_cancellation_event(first_event): @@ -84,8 +82,8 @@ class _Handler(object): cygrpc.operation_send_status_from_server( _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', _EMPTY_FLAGS),) - self._call.start_server_batch( - cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG) + self._call.start_server_batch(operations, + _SERVER_COMPLETE_CALL_TAG) self._completion_queue.poll() self._completion_queue.poll() @@ -179,8 +177,7 @@ class CancelManyCallsTest(unittest.TestCase): cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),) tag = 'client_complete_call_{0:04d}_tag'.format(index) - client_call.start_client_batch( - cygrpc.Operations(operations), tag) + client_call.start_client_batch(operations, tag) client_due.add(tag) client_calls.append(client_call) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py index ac66d1db3d..96f0f1589b 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_common.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -23,17 +23,14 @@ RPC_COUNT = 4000 INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) EMPTY_FLAGS = 0 -INVOCATION_METADATA = cygrpc.Metadata( - (cygrpc.Metadatum(b'client-md-key', b'client-md-key'), - cygrpc.Metadatum(b'client-md-key-bin', b'\x00\x01' * 3000),)) +INVOCATION_METADATA = (('client-md-key', 'client-md-key'), + ('client-md-key-bin', b'\x00\x01' * 3000),) -INITIAL_METADATA = cygrpc.Metadata( - (cygrpc.Metadatum(b'server-initial-md-key', b'server-initial-md-value'), - cygrpc.Metadatum(b'server-initial-md-key-bin', b'\x00\x02' * 3000),)) +INITIAL_METADATA = (('server-initial-md-key', 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02' * 3000),) -TRAILING_METADATA = cygrpc.Metadata( - (cygrpc.Metadatum(b'server-trailing-md-key', b'server-trailing-md-value'), - cygrpc.Metadatum(b'server-trailing-md-key-bin', b'\x00\x03' * 3000),)) +TRAILING_METADATA = (('server-trailing-md-key', 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03' * 3000),) class QueueDriver(object): diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py index 14cc66675c..d08003af44 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py @@ -48,20 +48,19 @@ class Test(_common.RpcTest, unittest.TestCase): client_complete_rpc_tag = 'client_complete_rpc_tag' with self.client_condition: client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_initial_metadata( - _common.EMPTY_FLAGS), - ]), client_receive_initial_metadata_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_initial_metadata( + _common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag)) client_complete_rpc_start_batch_result = client_call.start_client_batch( - cygrpc.Operations([ + [ cygrpc.operation_send_initial_metadata( _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), cygrpc.operation_send_close_from_client( _common.EMPTY_FLAGS), cygrpc.operation_receive_status_on_client( _common.EMPTY_FLAGS), - ]), client_complete_rpc_tag) + ], client_complete_rpc_tag) self.client_driver.add_due({ client_receive_initial_metadata_tag, client_complete_rpc_tag, diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py index 1e44bcc4dc..d0166a2b29 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py @@ -43,20 +43,19 @@ class Test(_common.RpcTest, unittest.TestCase): client_complete_rpc_tag = 'client_complete_rpc_tag' with self.client_condition: client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_initial_metadata( - _common.EMPTY_FLAGS), - ]), client_receive_initial_metadata_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_initial_metadata( + _common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag)) client_complete_rpc_start_batch_result = client_call.start_client_batch( - cygrpc.Operations([ + [ cygrpc.operation_send_initial_metadata( _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), cygrpc.operation_send_close_from_client( _common.EMPTY_FLAGS), cygrpc.operation_receive_status_on_client( _common.EMPTY_FLAGS), - ]), client_complete_rpc_tag) + ], client_complete_rpc_tag) self.client_driver.add_due({ client_receive_initial_metadata_tag, client_complete_rpc_tag, diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index 0105612b47..1deb15ba03 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -20,7 +20,7 @@ from grpc._cython import cygrpc _INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) _EMPTY_FLAGS = 0 -_EMPTY_METADATA = cygrpc.Metadata(()) +_EMPTY_METADATA = () class _ServerDriver(object): @@ -157,19 +157,17 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): client_complete_rpc_tag = 'client_complete_rpc_tag' with client_condition: client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), - ]), client_receive_initial_metadata_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + ], client_receive_initial_metadata_tag)) client_due.add(client_receive_initial_metadata_tag) client_complete_rpc_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), - cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), - ]), client_complete_rpc_tag)) + client_call.start_client_batch([ + cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, + _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ], client_complete_rpc_tag)) client_due.add(client_complete_rpc_tag) server_rpc_event = server_driver.first_event() @@ -197,8 +195,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_rpc_event.operation_call.start_server_batch([ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), cygrpc.operation_send_status_from_server( - cygrpc.Metadata(()), cygrpc.StatusCode.ok, - b'test details', _EMPTY_FLAGS), + (), cygrpc.StatusCode.ok, b'test details', + _EMPTY_FLAGS), ], server_complete_rpc_tag)) server_send_second_message_event = server_call_driver.event_with_tag( server_send_second_message_tag) @@ -209,10 +207,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): with client_condition: client_receive_first_message_tag = 'client_receive_first_message_tag' client_receive_first_message_start_batch_result = ( - client_call.start_client_batch( - cygrpc.Operations([ - cygrpc.operation_receive_message(_EMPTY_FLAGS), - ]), client_receive_first_message_tag)) + client_call.start_client_batch([ + cygrpc.operation_receive_message(_EMPTY_FLAGS), + ], client_receive_first_message_tag)) client_due.add(client_receive_first_message_tag) client_receive_first_message_event = client_driver.event_with_tag( client_receive_first_message_tag) diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index da94cf8028..4eda685486 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -29,50 +29,12 @@ _EMPTY_FLAGS = 0 def _metadata_plugin(context, callback): - callback( - cygrpc.Metadata([ - cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY, - _CALL_CREDENTIALS_METADATA_VALUE) - ]), cygrpc.StatusCode.ok, b'') + callback(((_CALL_CREDENTIALS_METADATA_KEY, + _CALL_CREDENTIALS_METADATA_VALUE,),), cygrpc.StatusCode.ok, b'') class TypeSmokeTest(unittest.TestCase): - def testStringsInUtilitiesUpDown(self): - self.assertEqual(0, cygrpc.StatusCode.ok) - metadatum = cygrpc.Metadatum(b'a', b'b') - self.assertEqual(b'a', metadatum.key) - self.assertEqual(b'b', metadatum.value) - metadata = cygrpc.Metadata([metadatum]) - self.assertEqual(1, len(metadata)) - self.assertEqual(metadatum.key, metadata[0].key) - - def testMetadataIteration(self): - metadata = cygrpc.Metadata( - [cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')]) - iterator = iter(metadata) - metadatum = next(iterator) - self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, b'a') - self.assertEqual(metadatum.value, b'b') - metadatum = next(iterator) - self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, b'c') - self.assertEqual(metadatum.value, b'd') - with self.assertRaises(StopIteration): - next(iterator) - - def testOperationsIteration(self): - operations = cygrpc.Operations( - [cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)]) - iterator = iter(operations) - operation = next(iterator) - self.assertIsInstance(operation, cygrpc.Operation) - # `Operation`s are write-only structures; can't directly debug anything out - # of them. Just check that we stop iterating. - with self.assertRaises(StopIteration): - next(iterator) - def testOperationFlags(self): operation = cygrpc.operation_send_message(b'asdf', cygrpc.WriteFlag.no_compress) @@ -182,8 +144,7 @@ class ServerClientMixin(object): def performer(): tag = object() try: - call_result = call.start_client_batch( - cygrpc.Operations(operations), tag) + call_result = call.start_client_batch(operations, tag) self.assertEqual(cygrpc.CallError.ok, call_result) event = queue.poll(deadline) self.assertEqual(cygrpc.CompletionType.operation_complete, @@ -200,14 +161,14 @@ class ServerClientMixin(object): def test_echo(self): DEADLINE = time.time() + 5 DEADLINE_TOLERANCE = 0.25 - CLIENT_METADATA_ASCII_KEY = b'key' - CLIENT_METADATA_ASCII_VALUE = b'val' - CLIENT_METADATA_BIN_KEY = b'key-bin' + CLIENT_METADATA_ASCII_KEY = 'key' + CLIENT_METADATA_ASCII_VALUE = 'val' + CLIENT_METADATA_BIN_KEY = 'key-bin' CLIENT_METADATA_BIN_VALUE = b'\0' * 1000 - SERVER_INITIAL_METADATA_KEY = b'init_me_me_me' - SERVER_INITIAL_METADATA_VALUE = b'whodawha?' - SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought' - SERVER_TRAILING_METADATA_VALUE = b'zomg it is' + SERVER_INITIAL_METADATA_KEY = 'init_me_me_me' + SERVER_INITIAL_METADATA_VALUE = 'whodawha?' + SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought' + SERVER_TRAILING_METADATA_VALUE = 'zomg it is' SERVER_STATUS_CODE = cygrpc.StatusCode.ok SERVER_STATUS_DETAILS = b'our work is never over' REQUEST = b'in death a member of project mayhem has a name' @@ -227,11 +188,9 @@ class ServerClientMixin(object): client_call = self.client_channel.create_call( None, 0, self.client_completion_queue, METHOD, self.host_argument, cygrpc_deadline) - client_initial_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY, - CLIENT_METADATA_ASCII_VALUE), - cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE) - ]) + client_initial_metadata = ( + (CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE,), + (CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE,),) client_start_batch_result = client_call.start_client_batch([ cygrpc.operation_send_initial_metadata(client_initial_metadata, _EMPTY_FLAGS), @@ -263,14 +222,10 @@ class ServerClientMixin(object): server_call_tag = object() server_call = request_event.operation_call - server_initial_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY, - SERVER_INITIAL_METADATA_VALUE) - ]) - server_trailing_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY, - SERVER_TRAILING_METADATA_VALUE) - ]) + server_initial_metadata = ( + (SERVER_INITIAL_METADATA_KEY, SERVER_INITIAL_METADATA_VALUE,),) + server_trailing_metadata = ( + (SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE,),) server_start_batch_result = server_call.start_server_batch([ cygrpc.operation_send_initial_metadata( server_initial_metadata, @@ -347,7 +302,7 @@ class ServerClientMixin(object): METHOD = b'twinkies' cygrpc_deadline = cygrpc.Timespec(DEADLINE) - empty_metadata = cygrpc.Metadata([]) + empty_metadata = () server_request_tag = object() self.server.request_call(self.server_completion_queue, diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py new file mode 100644 index 0000000000..cf875ed7da --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -0,0 +1,571 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of gRPC Python interceptors.""" + +import collections +import itertools +import threading +import unittest +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + self._called = False + + def __call__(self, value): + with self._condition: + self._value = value + self._called = True + self._condition.notify_all() + + def value(self): + with self._condition: + while not self._called: + self._condition.wait() + return self._value + + +class _Handler(object): + + def __init__(self, control): + self._control = control + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + return request + + def handle_unary_stream(self, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) + for request in request_iterator: + self._control.control() + yield request + self._control.control() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, + request_deserializer, response_serializer, unary_unary, + unary_stream, stream_unary, stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False, None, None, + self._handler.handle_unary_unary, None, None, + None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + else: + return None + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class _ClientCallDetails( + collections.namedtuple('_ClientCallDetails', + ('method', 'timeout', 'metadata', + 'credentials')), grpc.ClientCallDetails): + pass + + +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): + + def __init__(self, interceptor_function): + self._fn = interceptor_function + + def intercept_unary_unary(self, continuation, client_call_details, request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_unary_stream(self, continuation, client_call_details, + request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + +class _LoggingInterceptor( + grpc.ServerInterceptor, grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor): + + def __init__(self, tag, record): + self.tag = tag + self.record = record + + def intercept_service(self, continuation, handler_call_details): + self.record.append(self.tag + ':intercept_service') + return continuation(handler_call_details) + + def intercept_unary_unary(self, continuation, client_call_details, request): + self.record.append(self.tag + ':intercept_unary_unary') + return continuation(client_call_details, request) + + def intercept_unary_stream(self, continuation, client_call_details, + request): + self.record.append(self.tag + ':intercept_unary_stream') + return continuation(client_call_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_unary') + return continuation(client_call_details, request_iterator) + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_stream') + return continuation(client_call_details, request_iterator) + + +class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor): + + def intercept_unary_unary(self, ignored_continuation, + ignored_client_call_details, ignored_request): + raise test_control.Defect() + + +def _wrap_request_iterator_stream_interceptor(wrapper): + + def intercept_call(client_call_details, request_iterator, request_streaming, + ignored_response_streaming): + if request_streaming: + return client_call_details, wrapper(request_iterator), None + else: + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +def _append_request_header_interceptor(header, value): + + def intercept_call(client_call_details, request_iterator, + ignored_request_streaming, ignored_response_streaming): + metadata = [] + if client_call_details.metadata: + metadata = list(client_call_details.metadata) + metadata.append((header, value,)) + client_call_details = _ClientCallDetails( + client_call_details.method, client_call_details.timeout, metadata, + client_call_details.credentials) + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +class _GenericServerInterceptor(grpc.ServerInterceptor): + + def __init__(self, fn): + self._fn = fn + + def intercept_service(self, continuation, handler_call_details): + return self._fn(continuation, handler_call_details) + + +def _filter_server_interceptor(condition, interceptor): + + def intercept_service(continuation, handler_call_details): + if condition(handler_call_details): + return interceptor.intercept_service(continuation, + handler_call_details) + return continuation(handler_call_details) + + return _GenericServerInterceptor(intercept_service) + + +class InterceptorTest(unittest.TestCase): + + def setUp(self): + self._control = test_control.PauseFailControl() + self._handler = _Handler(self._control) + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + + self._record = [] + conditional_interceptor = _filter_server_interceptor( + lambda x: ('secret', '42') in x.invocation_metadata, + _LoggingInterceptor('s3', self._record)) + + self._server = grpc.server( + self._server_pool, + interceptors=(_LoggingInterceptor('s1', self._record), + conditional_interceptor, + _LoggingInterceptor('s2', self._record),)) + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._server_pool.shutdown(wait=True) + + def testTripleRequestMessagesClientInterceptor(self): + + def triple(request_iterator): + while True: + try: + item = next(request_iterator) + yield item + yield item + yield item + except StopIteration: + break + + interceptor = _wrap_request_iterator_stream_interceptor(triple) + channel = grpc.intercept_channel(self._channel, interceptor) + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), test_constants.STREAM_LENGTH) + + def testDefectiveClientInterceptor(self): + interceptor = _DefectiveClientInterceptor() + defective_channel = grpc.intercept_channel(self._channel, interceptor) + + request = b'\x07\x08' + + multi_callable = _unary_unary_multi_callable(defective_channel) + call_future = multi_callable.future( + request, + metadata=( + ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertIsNotNone(call_future.exception()) + self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL) + + def testInterceptedHeaderManipulationWithServerSideVerification(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, _append_request_header_interceptor('secret', '42')) + channel = grpc.intercept_channel( + channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's3:intercept_service', + 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable( + request, + metadata=( + ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + response_future = multi_callable.future( + request, + metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestStreamResponse(self): + request = b'\x37\x58' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_stream_multi_callable(channel) + response_iterator = multi_callable( + request, + metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_stream', 'c2:intercept_unary_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable( + request_iterator, + metadata=( + ('test', 'InterceptedStreamRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable.with_call( + request_iterator, + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x08' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestStreamResponse(self): + requests = tuple(b'\x77\x58' + for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, + _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_stream', 'c2:intercept_stream_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 6faab94be6..cb59cd3769 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -56,6 +56,7 @@ class _Servicer(object): def __init__(self): self._lock = threading.Lock() + self._abort_call = False self._code = None self._details = None self._exception = False @@ -67,10 +68,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -81,10 +85,13 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 2): yield _SERIALIZED_RESPONSE if self._exception: @@ -95,14 +102,16 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) if self._exception: raise test_control.Defect() else: @@ -113,19 +122,25 @@ class _Servicer(object): self._received_client_metadata = context.invocation_metadata() context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA) - if self._code is not None: - context.set_code(self._code) - if self._details is not None: - context.set_details(self._details) # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # request iterator. - for ignored_request in request_iterator: - pass + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) for _ in range(test_constants.STREAM_LENGTH // 3): yield object() if self._exception: raise test_control.Defect() + def set_abort_call(self): + with self._lock: + self._abort_call = True + def set_code(self, code): with self._lock: self._code = code @@ -212,11 +227,10 @@ class MetadataCodeDetailsTest(unittest.TestCase): def testSuccessfulUnaryStream(self): self._servicer.set_details(_DETAILS) - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() - for _ in call: - pass + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -225,10 +239,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testSuccessfulStreamUnary(self): self._servicer.set_details(_DETAILS) @@ -252,12 +267,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): def testSuccessfulStreamStream(self): self._servicer.set_details(_DETAILS) - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() - for _ in call: - pass + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -266,10 +280,106 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testAbortedUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testAbortedStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testAbortedStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -296,12 +406,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -310,10 +419,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeStreamUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -342,13 +452,12 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError) as exception_context: - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -390,12 +499,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - call = self._unary_stream( + response_iterator_call = self._unary_stream( _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -404,10 +512,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeExceptionStreamUnary(self): self._servicer.set_code(_NON_OK_CODE) @@ -438,13 +547,12 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - call = self._stream_stream( + response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), metadata=_CLIENT_METADATA) - received_initial_metadata = call.initial_metadata() + received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): - for _ in call: - pass + list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( @@ -453,10 +561,11 @@ class MetadataCodeDetailsTest(unittest.TestCase): test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, received_initial_metadata)) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, call.code()) - self.assertEqual(_DETAILS, call.details()) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) def testCustomCodeReturnNoneUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE) diff --git a/src/ruby/ext/grpc/extconf.rb b/src/ruby/ext/grpc/extconf.rb index 9d2cf2a08a..c1a0c56841 100644 --- a/src/ruby/ext/grpc/extconf.rb +++ b/src/ruby/ext/grpc/extconf.rb @@ -61,7 +61,7 @@ ENV['EMBED_ZLIB'] = 'true' ENV['EMBED_CARES'] = 'true' ENV['ARCH_FLAGS'] = RbConfig::CONFIG['ARCH_FLAG'] ENV['ARCH_FLAGS'] = '-arch i386 -arch x86_64' if RUBY_PLATFORM =~ /darwin/ -ENV['CFLAGS'] = '-DGPR_BACKWARDS_COMPATIBILITY_MODE' +ENV['CPPFLAGS'] = '-DGPR_BACKWARDS_COMPATIBILITY_MODE' output_dir = File.expand_path(RbConfig::CONFIG['topdir']) grpc_lib_dir = File.join(output_dir, 'libs', grpc_config) |