aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rpc
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-04-24 13:13:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-24 13:16:00 -0700
commit893aa776009418c841d49c924207f3cdaf1d5174 (patch)
treea11ab86db407e72161d9b4734128604ae3492052 /tensorflow/contrib/rpc
parent33ffc8e7ff5090b92951c7faac150042dd814085 (diff)
Fixing concurrency issues in RPC factory.
PiperOrigin-RevId: 194133903
Diffstat (limited to 'tensorflow/contrib/rpc')
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py1
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py60
3 files changed, 36 insertions, 26 deletions
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
index f3e6731213..2311c15a68 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -28,7 +28,6 @@ py_library(
py_library(
name = "rpc_op_test_base",
srcs = ["rpc_op_test_base.py"],
- tags = ["notsan"],
deps = [
":test_example_proto_py",
"//tensorflow/contrib/proto",
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
index e2e0dbc7a2..3fc6bfbb4d 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
@@ -35,6 +35,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
_protocol = 'grpc'
invalid_method_string = 'Method not found'
+ connect_failed_string = 'Connect Failed'
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
super(RpcOpTest, self).__init__(methodName)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 89f3ee1a1c..27273d16b1 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -93,40 +93,39 @@ class RpcOpTestBase(object):
response_values = sess.run(response_tensors)
self.assertAllEqual(response_values.shape, [0])
- def testInvalidAddresses(self):
- with self.test_session() as sess:
- with self.assertRaisesOpError(self.invalid_method_string):
- sess.run(
- self.rpc(
- method='/InvalidService.IncrementTestShapes',
- address=self._address,
- request=''))
+ def testInvalidMethod(self):
+ for method in [
+ '/InvalidService.IncrementTestShapes',
+ self.get_method_name('InvalidMethodName')
+ ]:
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(self.invalid_method_string):
+ sess.run(self.rpc(method=method, address=self._address, request=''))
- with self.assertRaisesOpError(self.invalid_method_string):
- sess.run(
- self.rpc(
- method=self.get_method_name('InvalidMethodName'),
- address=self._address,
- request=''))
+ _, status_code_value, status_message_value = sess.run(
+ self.try_rpc(method=method, address=self._address, request=''))
+ self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+ self.assertTrue(
+ self.invalid_method_string in status_message_value.decode('ascii'))
- # This also covers the case of address=''
- # and address='localhost:293874293874'
+ def testInvalidAddress(self):
+ # This covers the case of address='' and address='localhost:293874293874'
+ address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+ with self.test_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
method=self.get_method_name('IncrementTestShapes'),
- address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
+ address=address,
request=''))
-
- # Test invalid method with the TryRpc op
_, status_code_value, status_message_value = sess.run(
self.try_rpc(
- method=self.get_method_name('InvalidMethodName'),
- address=self._address,
+ method=self.get_method_name('IncrementTestShapes'),
+ address=address,
request=''))
- self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+ self.assertEqual(errors.UNAVAILABLE, status_code_value)
self.assertTrue(
- self.invalid_method_string in status_message_value.decode('ascii'))
+ self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
with self.test_session() as sess:
@@ -138,6 +137,18 @@ class RpcOpTestBase(object):
with self.assertRaisesOpError(I_WARNED_YOU):
sess.run(response_tensors)
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('AlwaysFailWithInvalidArgument'),
+ address=self._address,
+ request='')
+ self.assertEqual(response_tensors.shape, ())
+ self.assertEqual(status_code.shape, ())
+ self.assertEqual(status_message.shape, ())
+ status_code_value, status_message_value = sess.run((status_code,
+ status_message))
+ self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)
+ self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
+
def testSometimesFailingMethodWithManyRequests(self):
with self.test_session() as sess:
# Fail hard by default.
@@ -197,8 +208,7 @@ class RpcOpTestBase(object):
address=self._address,
request=request_tensors) for _ in range(10)
]
- # Launch parallel 10 calls to the RpcOp, each containing
- # 20 rpc requests.
+ # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests.
many_response_values = sess.run(many_response_tensors)
self.assertEqual(10, len(many_response_values))
for response_values in many_response_values: