diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-21 00:07:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 00:13:08 -0700 |
commit | b19d6657070bbf1df5706195a0bf3a92cbf371fc (patch) | |
tree | 64279a06acd61c7028226eba46a05dd1127acee4 /tensorflow/contrib/rpc | |
parent | 2952f5134905af795ba90ae1eb97e39091ba9843 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 213944932
Diffstat (limited to 'tensorflow/contrib/rpc')
-rw-r--r-- | tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 1c23c28860..0d615923e0 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -49,7 +49,7 @@ class RpcOpTestBase(object): return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs) def testScalarHostPortRpc(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = ( test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString()) response_tensors = self.rpc( @@ -63,7 +63,7 @@ class RpcOpTestBase(object): self.assertAllEqual([2, 3, 4], response_message.values) def testScalarHostPortTryRpc(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = ( test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString()) response_tensors, status_code, status_message = self.try_rpc( @@ -83,7 +83,7 @@ class RpcOpTestBase(object): self.assertEqual(b'', status_message_values) def testEmptyHostPortRpc(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = [] response_tensors = self.rpc( method=self.get_method_name('Increment'), @@ -98,7 +98,7 @@ class RpcOpTestBase(object): '/InvalidService.Increment', self.get_method_name('InvalidMethodName') ]: - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesOpError(self.invalid_method_string): sess.run(self.rpc(method=method, address=self._address, request='')) @@ -111,7 +111,7 @@ class RpcOpTestBase(object): def testInvalidAddress(self): # This covers the case of address='' and address='localhost:293874293874' address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.UnavailableError): sess.run( self.rpc( @@ -128,7 +128,7 @@ class RpcOpTestBase(object): self.connect_failed_string in status_message_value.decode('ascii')) def testAlwaysFailingMethod(self): - with self.test_session() as sess: + with self.cached_session() as sess: response_tensors = self.rpc( method=self.get_method_name('AlwaysFailWithInvalidArgument'), address=self._address, @@ -150,7 +150,7 @@ class RpcOpTestBase(object): self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii')) def testSometimesFailingMethodWithManyRequests(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Fail hard by default. response_tensors = self.rpc( method=self.get_method_name('SometimesFailWithInvalidArgument'), @@ -179,7 +179,7 @@ class RpcOpTestBase(object): self.assertAllEqual(expected_message_values, status_message_values) def testVecHostPortRpc(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = [ test_example_pb2.TestCase( values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) @@ -197,7 +197,7 @@ class RpcOpTestBase(object): self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) def testVecHostPortManyParallelRpcs(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = [ test_example_pb2.TestCase( values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) @@ -219,7 +219,7 @@ class RpcOpTestBase(object): self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) def testVecHostPortRpcUsingEncodeAndDecodeProto(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = encode_proto_op.encode_proto( message_type='tensorflow.contrib.rpc.TestCase', field_names=['values'], @@ -241,7 +241,7 @@ class RpcOpTestBase(object): for i in range(20)], response_shape_values) def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = [''] * 25 # This will launch 25 RPC requests. response_tensors = self.rpc( method=self.get_method_name('SleepForever'), @@ -254,7 +254,7 @@ class RpcOpTestBase(object): sess.run(response_tensors, options=options) def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self): - with self.test_session() as sess: + with self.cached_session() as sess: request_tensors = [''] * 25 # This will launch 25 RPC requests. response_tensors = self.rpc( method=self.get_method_name('SleepForever'), @@ -265,7 +265,7 @@ class RpcOpTestBase(object): sess.run(response_tensors) def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self): - with self.test_session() as sess: + with self.cached_session() as sess: response_tensors, status_code, status_message = self.try_rpc( method=self.get_method_name('SometimesSleepForever'), timeout_in_ms=1000, @@ -281,7 +281,7 @@ class RpcOpTestBase(object): def testTryRpcWithMultipleAddressesSingleRequest(self): flatten = lambda x: list(itertools.chain.from_iterable(x)) - with self.test_session() as sess: + with self.cached_session() as sess: addresses = flatten([[ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' ] for _ in range(10)]) @@ -301,7 +301,7 @@ class RpcOpTestBase(object): def testTryRpcWithMultipleMethodsSingleRequest(self): flatten = lambda x: list(itertools.chain.from_iterable(x)) - with self.test_session() as sess: + with self.cached_session() as sess: methods = flatten( [[self.get_method_name('Increment'), 'InvalidMethodName'] for _ in range(10)]) @@ -319,7 +319,7 @@ class RpcOpTestBase(object): def testTryRpcWithMultipleAddressesAndRequests(self): flatten = lambda x: list(itertools.chain.from_iterable(x)) - with self.test_session() as sess: + with self.cached_session() as sess: addresses = flatten([[ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' ] for _ in range(10)]) |