diff options
26 files changed, 178 insertions, 98 deletions
diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc index 9b6574b612..7f4627fa77 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -170,7 +170,12 @@ grpc_chttp2_transport::~grpc_chttp2_transport() { grpc_slice_buffer_destroy_internal(&outbuf); grpc_chttp2_hpack_compressor_destroy(&hpack_compressor); - grpc_core::ContextList::Execute(cl, nullptr, GRPC_ERROR_NONE); + grpc_error* error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed"); + // ContextList::Execute follows semantics of a callback function and does not + // take a ref on error + grpc_core::ContextList::Execute(cl, nullptr, error); + GRPC_ERROR_UNREF(error); cl = nullptr; grpc_slice_buffer_destroy_internal(&read_buffer); diff --git a/src/core/ext/transport/chttp2/transport/context_list.cc b/src/core/ext/transport/chttp2/transport/context_list.cc index f30d41c332..df09809067 100644 --- a/src/core/ext/transport/chttp2/transport/context_list.cc +++ b/src/core/ext/transport/chttp2/transport/context_list.cc @@ -21,31 +21,47 @@ #include "src/core/ext/transport/chttp2/transport/context_list.h" namespace { -void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*) = nullptr; -} +void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*, + grpc_error* error) = nullptr; +void* (*get_copied_context_fn_g)(void*) = nullptr; +} // namespace namespace grpc_core { +void ContextList::Append(ContextList** head, grpc_chttp2_stream* s) { + if (get_copied_context_fn_g == nullptr || + write_timestamps_callback_g == nullptr) { + return; + } + /* Create a new element in the list and add it at the front */ + ContextList* elem = grpc_core::New<ContextList>(); + elem->trace_context_ = get_copied_context_fn_g(s->context); + elem->byte_offset_ = s->byte_counter; + elem->next_ = *head; + *head = elem; +} + void ContextList::Execute(void* arg, grpc_core::Timestamps* ts, grpc_error* error) { ContextList* head = static_cast<ContextList*>(arg); ContextList* to_be_freed; while (head != nullptr) { - if (error == GRPC_ERROR_NONE && ts != nullptr) { - if (write_timestamps_callback_g) { - ts->byte_offset = static_cast<uint32_t>(head->byte_offset_); - write_timestamps_callback_g(head->s_->context, ts); - } + if (write_timestamps_callback_g) { + ts->byte_offset = static_cast<uint32_t>(head->byte_offset_); + write_timestamps_callback_g(head->trace_context_, ts, error); } - GRPC_CHTTP2_STREAM_UNREF(static_cast<grpc_chttp2_stream*>(head->s_), - "timestamp"); to_be_freed = head; head = head->next_; grpc_core::Delete(to_be_freed); } } -void grpc_http2_set_write_timestamps_callback( - void (*fn)(void*, grpc_core::Timestamps*)) { +void grpc_http2_set_write_timestamps_callback(void (*fn)(void*, + grpc_core::Timestamps*, + grpc_error* error)) { write_timestamps_callback_g = fn; } + +void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)) { + get_copied_context_fn_g = fn; +} } /* namespace grpc_core */ diff --git a/src/core/ext/transport/chttp2/transport/context_list.h b/src/core/ext/transport/chttp2/transport/context_list.h index d870107749..5b9d2ab378 100644 --- a/src/core/ext/transport/chttp2/transport/context_list.h +++ b/src/core/ext/transport/chttp2/transport/context_list.h @@ -31,42 +31,23 @@ class ContextList { public: /* Creates a new element with \a context as the value and appends it to the * list. */ - static void Append(ContextList** head, grpc_chttp2_stream* s) { - /* Make sure context is not already present */ - GRPC_CHTTP2_STREAM_REF(s, "timestamp"); - -#ifndef NDEBUG - ContextList* ptr = *head; - while (ptr != nullptr) { - if (ptr->s_ == s) { - GPR_ASSERT( - false && - "Trying to append a stream that is already present in the list"); - } - ptr = ptr->next_; - } -#endif - - /* Create a new element in the list and add it at the front */ - ContextList* elem = grpc_core::New<ContextList>(); - elem->s_ = s; - elem->byte_offset_ = s->byte_counter; - elem->next_ = *head; - *head = elem; - } + static void Append(ContextList** head, grpc_chttp2_stream* s); /* Executes a function \a fn with each context in the list and \a ts. It also - * frees up the entire list after this operation. */ + * frees up the entire list after this operation. It is intended as a callback + * and hence does not take a ref on \a error */ static void Execute(void* arg, grpc_core::Timestamps* ts, grpc_error* error); private: - grpc_chttp2_stream* s_ = nullptr; + void* trace_context_ = nullptr; ContextList* next_ = nullptr; size_t byte_offset_ = 0; }; -void grpc_http2_set_write_timestamps_callback( - void (*fn)(void*, grpc_core::Timestamps*)); +void grpc_http2_set_write_timestamps_callback(void (*fn)(void*, + grpc_core::Timestamps*, + grpc_error* error)); +void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)); } /* namespace grpc_core */ #endif /* GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_CONTEXT_LIST_H */ diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc index 50ee9d871f..214d72f853 100644 --- a/src/cpp/common/channel_arguments.cc +++ b/src/cpp/common/channel_arguments.cc @@ -106,7 +106,9 @@ void ChannelArguments::SetSocketMutator(grpc_socket_mutator* mutator) { } if (!replaced) { + strings_.push_back(grpc::string(mutator_arg.key)); args_.push_back(mutator_arg); + args_.back().key = const_cast<char*>(strings_.back().c_str()); } } diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index 496bcfbcbf..d5327711d3 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -133,6 +133,7 @@ class TestGevent(setuptools.Command): # This test will stuck while running higher version of gevent 'unit._auth_context_test.AuthContextTest.testSessionResumption', # TODO(https://github.com/grpc/grpc/issues/15411) enable these tests + 'unit._metadata_flags_test', 'unit._exit_test.ExitTest.test_in_flight_unary_unary_call', 'unit._exit_test.ExitTest.test_in_flight_unary_stream_call', 'unit._exit_test.ExitTest.test_in_flight_stream_unary_call', diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py index 350b5eebe5..c1d9436c2f 100644 --- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py +++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py @@ -39,8 +39,12 @@ class HealthServicerTest(unittest.TestCase): health_pb2_grpc.add_HealthServicer_to_server(servicer, self._server) self._server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - self._stub = health_pb2_grpc.HealthStub(channel) + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = health_pb2_grpc.HealthStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() def test_empty_service(self): request = health_pb2.HealthCheckRequest() diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py index bcd9e14a38..560f6d3ddb 100644 --- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py +++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py @@ -56,8 +56,12 @@ class ReflectionServicerTest(unittest.TestCase): port = self._server.add_insecure_port('[::]:0') self._server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - self._stub = reflection_pb2_grpc.ServerReflectionStub(channel) + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() def testFileByName(self): requests = ( diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index 427894bfe9..0dc6a8718c 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -101,6 +101,7 @@ class ChannelTest(unittest.TestCase): def test_secure_channel(self): channel_credentials = grpc.ssl_channel_credentials() channel = grpc.secure_channel('google.com:443', channel_credentials) + channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index b1b5bbdcab..96c4e9ec76 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -71,8 +71,8 @@ class AuthContextTest(unittest.TestCase): port = server.add_insecure_port('[::]:0') server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + with grpc.insecure_channel('localhost:%d' % port) as channel: + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) server.stop(None) auth_data = pickle.loads(response) @@ -98,6 +98,7 @@ class AuthContextTest(unittest.TestCase): channel_creds, options=_PROPERTY_OPTIONS) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() server.stop(None) auth_data = pickle.loads(response) @@ -132,6 +133,7 @@ class AuthContextTest(unittest.TestCase): options=_PROPERTY_OPTIONS) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() server.stop(None) auth_data = pickle.loads(response) diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py index 727fb7d65f..565bd39b3a 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py @@ -75,6 +75,8 @@ class ChannelConnectivityTest(unittest.TestCase): channel.unsubscribe(callback.update) fifth_connectivities = callback.connectivities() + channel.close() + self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), first_connectivities) self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities) @@ -108,7 +110,8 @@ class ChannelConnectivityTest(unittest.TestCase): _ready_in_connectivities) second_callback.block_until_connectivities_satisfy( _ready_in_connectivities) - del channel + channel.close() + server.stop(None) self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), first_connectivities) @@ -139,6 +142,7 @@ class ChannelConnectivityTest(unittest.TestCase): callback.block_until_connectivities_satisfy( _last_connectivity_is_not_ready) channel.unsubscribe(callback.update) + channel.close() self.assertFalse(thread_pool.was_used()) diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py index 345460ef40..46a4eb9bb6 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py @@ -60,6 +60,8 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertTrue(ready_future.done()) self.assertFalse(ready_future.running()) + channel.close() + def test_immediately_connectable_channel_connectivity(self): thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) server = grpc.server(thread_pool, options=(('grpc.so_reuseport', 0),)) @@ -84,6 +86,9 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertFalse(ready_future.running()) self.assertFalse(thread_pool.was_used()) + channel.close() + server.stop(None) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py index 876d8e827e..87884a19dc 100644 --- a/src/python/grpcio_tests/tests/unit/_compression_test.py +++ b/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -77,6 +77,9 @@ class CompressionTest(unittest.TestCase): self._port = self._server.add_insecure_port('[::]:0') self._server.start() + def tearDown(self): + self._server.stop(None) + def testUnary(self): request = b'\x00' * 100 @@ -102,6 +105,7 @@ class CompressionTest(unittest.TestCase): response = multi_callable( request, metadata=[('grpc-internal-encoding-request', 'gzip')]) self.assertEqual(request, response) + compressed_channel.close() def testStreaming(self): request = b'\x00' * 100 @@ -115,6 +119,7 @@ class CompressionTest(unittest.TestCase): call = multi_callable(iter([request] * test_constants.STREAM_LENGTH)) for response in call: self.assertEqual(request, response) + compressed_channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py index 3e8393b53c..f27ea422d0 100644 --- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py +++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py @@ -96,6 +96,7 @@ class EmptyMessageTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST) diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py index 6c551df3ec..81de1dae1d 100644 --- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py +++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py @@ -71,6 +71,7 @@ class ErrorMessageEncodingTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testMessageEncoding(self): for message in _UNICODE_ERROR_MESSAGES: diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index 99db0ac58b..a647e5e720 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -337,6 +337,7 @@ class InterceptorTest(unittest.TestCase): def tearDown(self): self._server.stop(None) self._server_pool.shutdown(wait=True) + self._channel.close() def testTripleRequestMessagesClientInterceptor(self): diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index 0ff49490d5..7ed7c83893 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -62,6 +62,9 @@ class InvalidMetadataTest(unittest.TestCase): self._stream_unary = _stream_unary_multi_callable(self._channel) self._stream_stream = _stream_stream_multi_callable(self._channel) + def tearDown(self): + self._channel.close() + def testUnaryRequestBlockingUnaryResponse(self): request = b'\x07\x08' metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),) diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py index 00949e2236..e89b521cc5 100644 --- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py +++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py @@ -215,6 +215,7 @@ class InvocationDefectsTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testIterableStreamRequestBlockingUnaryResponse(self): requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)] diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 0dafab827a..a63664ac5d 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -198,8 +198,8 @@ class MetadataCodeDetailsTest(unittest.TestCase): port = self._server.add_insecure_port('[::]:0') self._server.start() - channel = grpc.insecure_channel('localhost:{}'.format(port)) - self._unary_unary = channel.unary_unary( + self._channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._unary_unary = self._channel.unary_unary( '/'.join(( '', _SERVICE, @@ -208,17 +208,17 @@ class MetadataCodeDetailsTest(unittest.TestCase): request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, ) - self._unary_stream = channel.unary_stream('/'.join(( + self._unary_stream = self._channel.unary_stream('/'.join(( '', _SERVICE, _UNARY_STREAM, )),) - self._stream_unary = channel.stream_unary('/'.join(( + self._stream_unary = self._channel.stream_unary('/'.join(( '', _SERVICE, _STREAM_UNARY, )),) - self._stream_stream = channel.stream_stream( + self._stream_stream = self._channel.stream_stream( '/'.join(( '', _SERVICE, @@ -228,6 +228,10 @@ class MetadataCodeDetailsTest(unittest.TestCase): response_deserializer=_RESPONSE_DESERIALIZER, ) + def tearDown(self): + self._server.stop(None) + self._channel.close() + def testSuccessfulUnaryUnary(self): self._servicer.set_details(_DETAILS) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index 2d352e99d4..7b32b5b5f3 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -187,13 +187,14 @@ class MetadataFlagsTest(unittest.TestCase): def test_call_wait_for_ready_default(self): for perform_call in _ALL_CALL_CASES: - self.check_connection_does_failfast(perform_call, - create_dummy_channel()) + with create_dummy_channel() as channel: + self.check_connection_does_failfast(perform_call, channel) def test_call_wait_for_ready_disabled(self): for perform_call in _ALL_CALL_CASES: - self.check_connection_does_failfast( - perform_call, create_dummy_channel(), wait_for_ready=False) + with create_dummy_channel() as channel: + self.check_connection_does_failfast( + perform_call, channel, wait_for_ready=False) def test_call_wait_for_ready_enabled(self): # To test the wait mechanism, Python thread is required to make @@ -210,16 +211,16 @@ class MetadataFlagsTest(unittest.TestCase): wg.done() def test_call(perform_call): - try: - channel = grpc.insecure_channel(addr) - channel.subscribe(wait_for_transient_failure) - perform_call(channel, wait_for_ready=True) - except BaseException as e: # pylint: disable=broad-except - # If the call failed, the thread would be destroyed. The channel - # object can be collected before calling the callback, which - # will result in a deadlock. - wg.done() - unhandled_exceptions.put(e, True) + with grpc.insecure_channel(addr) as channel: + try: + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The + # channel object can be collected before calling the + # callback, which will result in a deadlock. + wg.done() + unhandled_exceptions.put(e, True) test_threads = [] for perform_call in _ALL_CALL_CASES: diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index 777ab683e3..892df3df08 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -186,6 +186,7 @@ class MetadataTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): multi_callable = self._channel.unary_unary(_UNARY_UNARY) diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py index f6d4fcbd0a..d4ea126e2b 100644 --- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py +++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -98,6 +98,8 @@ class ReconnectTest(unittest.TestCase): server.add_insecure_port('[::]:{}'.format(port)) server.start() self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) + server.stop(None) + channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py index 4fead8fcd5..517c2d2f97 100644 --- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py +++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py @@ -148,6 +148,7 @@ class ResourceExhaustedTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): multi_callable = self._channel.unary_unary(_UNARY_UNARY) diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test.py b/src/python/grpcio_tests/tests/unit/_rpc_test.py index a768d6c7c1..a99121cee5 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_test.py @@ -193,6 +193,7 @@ class RPCTest(unittest.TestCase): def tearDown(self): self._server.stop(None) + self._channel.close() def testUnrecognizedMethod(self): request = b'abc' diff --git a/test/core/transport/chttp2/context_list_test.cc b/test/core/transport/chttp2/context_list_test.cc index edbe658a89..0379eaaee4 100644 --- a/test/core/transport/chttp2/context_list_test.cc +++ b/test/core/transport/chttp2/context_list_test.cc @@ -36,8 +36,12 @@ namespace { const uint32_t kByteOffset = 123; -void TestExecuteFlushesListVerifier(void* arg, grpc_core::Timestamps* ts) { +void* DummyArgsCopier(void* arg) { return arg; } + +void TestExecuteFlushesListVerifier(void* arg, grpc_core::Timestamps* ts, + grpc_error* error) { ASSERT_NE(arg, nullptr); + EXPECT_EQ(error, GRPC_ERROR_NONE); EXPECT_EQ(ts->byte_offset, kByteOffset); gpr_atm* done = reinterpret_cast<gpr_atm*>(arg); gpr_atm_rel_store(done, static_cast<gpr_atm>(1)); @@ -52,6 +56,7 @@ void discard_write(grpc_slice slice) {} TEST(ContextList, ExecuteFlushesList) { grpc_core::ContextList* list = nullptr; grpc_http2_set_write_timestamps_callback(TestExecuteFlushesListVerifier); + grpc_http2_set_fn_get_copied_context(DummyArgsCopier); const int kNumElems = 5; grpc_core::ExecCtx exec_ctx; grpc_stream_refcount ref; diff --git a/test/cpp/common/channel_arguments_test.cc b/test/cpp/common/channel_arguments_test.cc index 183d2afa78..12fd9784f4 100644 --- a/test/cpp/common/channel_arguments_test.cc +++ b/test/cpp/common/channel_arguments_test.cc @@ -209,6 +209,9 @@ TEST_F(ChannelArgumentsTest, SetSocketMutator) { channel_args_.SetSocketMutator(mutator0); EXPECT_TRUE(HasArg(arg0)); + // Exercise the copy constructor because we ran some sanity checks in it. + grpc::ChannelArguments new_args{channel_args_}; + channel_args_.SetSocketMutator(mutator1); EXPECT_TRUE(HasArg(arg1)); // arg0 is replaced by arg1 diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index 668d941916..73f91eed2d 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -429,13 +429,7 @@ class ClientImpl : public Client { config.server_targets(i % config.server_targets_size()), config, create_stub_, i); } - std::vector<std::unique_ptr<std::thread>> connecting_threads; - for (auto& c : channels_) { - connecting_threads.emplace_back(c.WaitForReady()); - } - for (auto& t : connecting_threads) { - t->join(); - } + WaitForChannelsToConnect(); median_latency_collection_interval_seconds_ = config.median_latency_collection_interval_millis() / 1e3; ClientRequestCreator<RequestType> create_req(&request_, @@ -443,6 +437,61 @@ class ClientImpl : public Client { } virtual ~ClientImpl() {} + void WaitForChannelsToConnect() { + int connect_deadline_seconds = 10; + /* Allow optionally overriding connect_deadline in order + * to deal with benchmark environments in which the server + * can take a long time to become ready. */ + char* channel_connect_timeout_str = + gpr_getenv("QPS_WORKER_CHANNEL_CONNECT_TIMEOUT"); + if (channel_connect_timeout_str != nullptr && + strcmp(channel_connect_timeout_str, "") != 0) { + connect_deadline_seconds = atoi(channel_connect_timeout_str); + } + gpr_log(GPR_INFO, + "Waiting for up to %d seconds for all channels to connect", + connect_deadline_seconds); + gpr_free(channel_connect_timeout_str); + gpr_timespec connect_deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(connect_deadline_seconds, GPR_TIMESPAN)); + CompletionQueue cq; + size_t num_remaining = 0; + for (auto& c : channels_) { + if (!c.is_inproc()) { + Channel* channel = c.get_channel(); + grpc_connectivity_state last_observed = channel->GetState(true); + if (last_observed == GRPC_CHANNEL_READY) { + gpr_log(GPR_INFO, "Channel %p connected!", channel); + } else { + num_remaining++; + channel->NotifyOnStateChange(last_observed, connect_deadline, &cq, + channel); + } + } + } + while (num_remaining > 0) { + bool ok = false; + void* tag = nullptr; + cq.Next(&tag, &ok); + Channel* channel = static_cast<Channel*>(tag); + if (!ok) { + gpr_log(GPR_ERROR, "Channel %p failed to connect within the deadline", + channel); + abort(); + } else { + grpc_connectivity_state last_observed = channel->GetState(true); + if (last_observed == GRPC_CHANNEL_READY) { + gpr_log(GPR_INFO, "Channel %p connected!", channel); + num_remaining--; + } else { + channel->NotifyOnStateChange(last_observed, connect_deadline, &cq, + channel); + } + } + } + } + protected: const int cores_; RequestType request_; @@ -485,31 +534,7 @@ class ClientImpl : public Client { } Channel* get_channel() { return channel_.get(); } StubType* get_stub() { return stub_.get(); } - - std::unique_ptr<std::thread> WaitForReady() { - return std::unique_ptr<std::thread>(new std::thread([this]() { - if (!is_inproc_) { - int connect_deadline = 10; - /* Allow optionally overriding connect_deadline in order - * to deal with benchmark environments in which the server - * can take a long time to become ready. */ - char* channel_connect_timeout_str = - gpr_getenv("QPS_WORKER_CHANNEL_CONNECT_TIMEOUT"); - if (channel_connect_timeout_str != nullptr && - strcmp(channel_connect_timeout_str, "") != 0) { - connect_deadline = atoi(channel_connect_timeout_str); - } - gpr_log(GPR_INFO, - "Waiting for up to %d seconds for the channel %p to connect", - connect_deadline, channel_.get()); - gpr_free(channel_connect_timeout_str); - GPR_ASSERT(channel_->WaitForConnected(gpr_time_add( - gpr_now(GPR_CLOCK_REALTIME), - gpr_time_from_seconds(connect_deadline, GPR_TIMESPAN)))); - gpr_log(GPR_INFO, "Channel %p connected!", channel_.get()); - } - })); - } + bool is_inproc() { return is_inproc_; } private: void set_channel_args(const ClientConfig& config, ChannelArguments* args) { |