diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-04-24 13:13:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-24 13:16:00 -0700 |
commit | 893aa776009418c841d49c924207f3cdaf1d5174 (patch) | |
tree | a11ab86db407e72161d9b4734128604ae3492052 /tensorflow/contrib/rpc | |
parent | 33ffc8e7ff5090b92951c7faac150042dd814085 (diff) |
Fixing concurrency issues in RPC factory.
PiperOrigin-RevId: 194133903
Diffstat (limited to 'tensorflow/contrib/rpc')
3 files changed, 36 insertions, 26 deletions
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD index f3e6731213..2311c15a68 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD +++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD @@ -28,7 +28,6 @@ py_library( py_library( name = "rpc_op_test_base", srcs = ["rpc_op_test_base.py"], - tags = ["notsan"], deps = [ ":test_example_proto_py", "//tensorflow/contrib/proto", diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py index e2e0dbc7a2..3fc6bfbb4d 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py @@ -35,6 +35,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase): _protocol = 'grpc' invalid_method_string = 'Method not found' + connect_failed_string = 'Connect Failed' def __init__(self, methodName='runTest'): # pylint: disable=invalid-name super(RpcOpTest, self).__init__(methodName) diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 89f3ee1a1c..27273d16b1 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -93,40 +93,39 @@ class RpcOpTestBase(object): response_values = sess.run(response_tensors) self.assertAllEqual(response_values.shape, [0]) - def testInvalidAddresses(self): - with self.test_session() as sess: - with self.assertRaisesOpError(self.invalid_method_string): - sess.run( - self.rpc( - method='/InvalidService.IncrementTestShapes', - address=self._address, - request='')) + def testInvalidMethod(self): + for method in [ + '/InvalidService.IncrementTestShapes', + self.get_method_name('InvalidMethodName') + ]: + with self.test_session() as sess: + with self.assertRaisesOpError(self.invalid_method_string): + sess.run(self.rpc(method=method, address=self._address, request='')) - with self.assertRaisesOpError(self.invalid_method_string): - sess.run( - self.rpc( - method=self.get_method_name('InvalidMethodName'), - address=self._address, - request='')) + _, status_code_value, status_message_value = sess.run( + self.try_rpc(method=method, address=self._address, request='')) + self.assertEqual(errors.UNIMPLEMENTED, status_code_value) + self.assertTrue( + self.invalid_method_string in status_message_value.decode('ascii')) - # This also covers the case of address='' - # and address='localhost:293874293874' + def testInvalidAddress(self): + # This covers the case of address='' and address='localhost:293874293874' + address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' + with self.test_session() as sess: with self.assertRaises(errors.UnavailableError): sess.run( self.rpc( method=self.get_method_name('IncrementTestShapes'), - address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@', + address=address, request='')) - - # Test invalid method with the TryRpc op _, status_code_value, status_message_value = sess.run( self.try_rpc( - method=self.get_method_name('InvalidMethodName'), - address=self._address, + method=self.get_method_name('IncrementTestShapes'), + address=address, request='')) - self.assertEqual(errors.UNIMPLEMENTED, status_code_value) + self.assertEqual(errors.UNAVAILABLE, status_code_value) self.assertTrue( - self.invalid_method_string in status_message_value.decode('ascii')) + self.connect_failed_string in status_message_value.decode('ascii')) def testAlwaysFailingMethod(self): with self.test_session() as sess: @@ -138,6 +137,18 @@ class RpcOpTestBase(object): with self.assertRaisesOpError(I_WARNED_YOU): sess.run(response_tensors) + response_tensors, status_code, status_message = self.try_rpc( + method=self.get_method_name('AlwaysFailWithInvalidArgument'), + address=self._address, + request='') + self.assertEqual(response_tensors.shape, ()) + self.assertEqual(status_code.shape, ()) + self.assertEqual(status_message.shape, ()) + status_code_value, status_message_value = sess.run((status_code, + status_message)) + self.assertEqual(errors.INVALID_ARGUMENT, status_code_value) + self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii')) + def testSometimesFailingMethodWithManyRequests(self): with self.test_session() as sess: # Fail hard by default. @@ -197,8 +208,7 @@ class RpcOpTestBase(object): address=self._address, request=request_tensors) for _ in range(10) ] - # Launch parallel 10 calls to the RpcOp, each containing - # 20 rpc requests. + # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests. many_response_values = sess.run(many_response_tensors) self.assertEqual(10, len(many_response_values)) for response_values in many_response_values: |