aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rpc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 00:07:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 00:13:08 -0700
commitb19d6657070bbf1df5706195a0bf3a92cbf371fc (patch)
tree64279a06acd61c7028226eba46a05dd1127acee4 /tensorflow/contrib/rpc
parent2952f5134905af795ba90ae1eb97e39091ba9843 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 213944932
Diffstat (limited to 'tensorflow/contrib/rpc')
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 1c23c28860..0d615923e0 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -49,7 +49,7 @@ class RpcOpTestBase(object):
return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)
def testScalarHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc(
@@ -63,7 +63,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc(
@@ -83,7 +83,7 @@ class RpcOpTestBase(object):
self.assertEqual(b'', status_message_values)
def testEmptyHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = []
response_tensors = self.rpc(
method=self.get_method_name('Increment'),
@@ -98,7 +98,7 @@ class RpcOpTestBase(object):
'/InvalidService.Increment',
self.get_method_name('InvalidMethodName')
]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(self.invalid_method_string):
sess.run(self.rpc(method=method, address=self._address, request=''))
@@ -111,7 +111,7 @@ class RpcOpTestBase(object):
def testInvalidAddress(self):
# This covers the case of address='' and address='localhost:293874293874'
address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
@@ -128,7 +128,7 @@ class RpcOpTestBase(object):
self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors = self.rpc(
method=self.get_method_name('AlwaysFailWithInvalidArgument'),
address=self._address,
@@ -150,7 +150,7 @@ class RpcOpTestBase(object):
self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
def testSometimesFailingMethodWithManyRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Fail hard by default.
response_tensors = self.rpc(
method=self.get_method_name('SometimesFailWithInvalidArgument'),
@@ -179,7 +179,7 @@ class RpcOpTestBase(object):
self.assertAllEqual(expected_message_values, status_message_values)
def testVecHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -197,7 +197,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -219,7 +219,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase',
field_names=['values'],
@@ -241,7 +241,7 @@ class RpcOpTestBase(object):
for i in range(20)], response_shape_values)
def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -254,7 +254,7 @@ class RpcOpTestBase(object):
sess.run(response_tensors, options=options)
def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -265,7 +265,7 @@ class RpcOpTestBase(object):
sess.run(response_tensors)
def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors, status_code, status_message = self.try_rpc(
method=self.get_method_name('SometimesSleepForever'),
timeout_in_ms=1000,
@@ -281,7 +281,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleAddressesSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
@@ -301,7 +301,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleMethodsSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
methods = flatten(
[[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)])
@@ -319,7 +319,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleAddressesAndRequests(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])