aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rpc
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-04-12 17:32:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-12 17:35:08 -0700
commitd42e4bde7ace9bb757b0fdf0e2a48c97cabe938b (patch)
treeba3d763a60d774e943f9d7c4c57e81301ecc74b5 /tensorflow/contrib/rpc
parentfffd3ca4fcf1f54f97a7be6f225fe183ad82b0ea (diff)
Porting tests for `rpc_op` to OS.
PiperOrigin-RevId: 192698931
Diffstat (limited to 'tensorflow/contrib/rpc')
-rw-r--r--tensorflow/contrib/rpc/BUILD16
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/BUILD76
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py71
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py337
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py101
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/test_example.proto171
6 files changed, 772 insertions, 0 deletions
diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD
index 597f18c771..dbd311a276 100644
--- a/tensorflow/contrib/rpc/BUILD
+++ b/tensorflow/contrib/rpc/BUILD
@@ -4,6 +4,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+
py_library(
name = "rpc",
srcs = [
@@ -11,3 +13,17 @@ py_library(
],
deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
)
+
+py_library(
+ name = "rpc_pip",
+ data = if_static(
+ [],
+ otherwise = ["//tensorflow/contrib/rpc/python/kernel_tests:libtestexample.so"],
+ ),
+ deps = [
+ ":rpc",
+ "//tensorflow/contrib/rpc/python/kernel_tests:py_test_deps",
+ "//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_base",
+ "//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_servicer",
+ ],
+)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
new file mode 100644
index 0000000000..08ec1e61a4
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -0,0 +1,76 @@
+# TODO(b/76425722): Port everything in here to OS (currently excluded).
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+# Placeholder for loading internal BUILD rule.
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+tf_proto_library(
+ name = "test_example_proto",
+ srcs = ["test_example.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+)
+
+py_library(
+ name = "py_test_deps",
+ deps = [":test_example_proto_py"],
+)
+
+py_library(
+ name = "rpc_op_test_base",
+ srcs = ["rpc_op_test_base.py"],
+ deps = [
+ ":test_example_proto_py",
+ "//tensorflow/contrib/proto",
+ "//tensorflow/contrib/rpc",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "rpc_op_test_servicer",
+ srcs = ["rpc_op_test_servicer.py"],
+ deps = [
+ ":py_test_deps",
+ ":rpc_op_test_base",
+ "//tensorflow/core:protos_all_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+tf_cc_shared_object(
+ name = "libtestexample.so",
+ linkstatic = 1,
+ deps = [
+ ":test_example_proto_cc",
+ ],
+)
+
+tf_py_test(
+ name = "rpc_op_test",
+ size = "small",
+ srcs = ["rpc_op_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ ":rpc_op_test_base",
+ ":rpc_op_test_servicer",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
new file mode 100644
index 0000000000..e2e0dbc7a2
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for RpcOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+import grpc
+from grpc.framework.foundation import logging_pool
+import portpicker
+
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_servicer
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
+from tensorflow.python.platform import test
+
+
+class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
+ _protocol = 'grpc'
+
+ invalid_method_string = 'Method not found'
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(RpcOpTest, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
+
+ def get_method_name(self, suffix):
+ return '/tensorflow.contrib.rpc.TestCaseService/%s' % suffix
+
+ def setUp(self):
+ super(RpcOpTest, self).setUp()
+
+ service_port = portpicker.pick_unused_port()
+
+ server = grpc.server(logging_pool.pool(max_workers=25))
+ servicer = rpc_op_test_servicer.RpcOpTestServicer()
+ test_example_pb2_grpc.add_TestCaseServiceServicer_to_server(
+ servicer, server)
+ self._address = 'localhost:%d' % service_port
+ server.add_insecure_port(self._address)
+ server.start()
+ self._server = server
+
+ def tearDown(self):
+ # TODO(ebrevdo): Figure out why this sometimes times out.
+ # self._service.ExitLoop()
+ # self._service_thread.join()
+ # self._server.stop()
+ super(RpcOpTest, self).tearDown()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
new file mode 100644
index 0000000000..aa03a103ed
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -0,0 +1,337 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Base class for RpcOp tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.contrib.proto import decode_proto
+from tensorflow.contrib.proto import encode_proto
+from tensorflow.contrib.rpc import rpc
+from tensorflow.contrib.rpc import try_rpc
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+
+__all__ = ['I_WARNED_YOU', 'RpcOpTestBase']
+
+I_WARNED_YOU = 'I warned you!'
+
+
+class RpcOpTestBase(object):
+ # pylint: disable=missing-docstring,invalid-name
+ """Base class for RpcOp tests."""
+
+ def get_method_name(self, suffix):
+ raise NotImplementedError
+
+ def rpc(self, *args, **kwargs):
+ return rpc(*args, protocol=self._protocol, **kwargs)
+
+ def try_rpc(self, *args, **kwargs):
+ return try_rpc(*args, protocol=self._protocol, **kwargs)
+
+ def testScalarHostPortRpc(self):
+ with self.test_session() as sess:
+ request_tensors = (
+ test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ response_tensors = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertEqual(response_tensors.shape, ())
+ response_values = sess.run(response_tensors)
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values))
+ self.assertAllEqual([2, 3, 4], response_message.shape)
+
+ def testScalarHostPortTryRpc(self):
+ with self.test_session() as sess:
+ request_tensors = (
+ test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertEqual(status_code.shape, ())
+ self.assertEqual(status_message.shape, ())
+ self.assertEqual(response_tensors.shape, ())
+ response_values, status_code_values, status_message_values = (
+ sess.run((response_tensors, status_code, status_message)))
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values))
+ self.assertAllEqual([2, 3, 4], response_message.shape)
+ # For the base Rpc op, don't expect to get error status back.
+ self.assertEqual(errors.OK, status_code_values)
+ self.assertEqual(b'', status_message_values)
+
+ def testEmptyHostPortRpc(self):
+ with self.test_session() as sess:
+ request_tensors = []
+ response_tensors = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertAllEqual(response_tensors.shape, [0])
+ response_values = sess.run(response_tensors)
+ self.assertAllEqual(response_values.shape, [0])
+
+ def testInvalidAddresses(self):
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(self.invalid_method_string):
+ sess.run(
+ self.rpc(
+ method='/InvalidService.IncrementTestShapes',
+ address=self._address,
+ request=''))
+
+ with self.assertRaisesOpError(self.invalid_method_string):
+ sess.run(
+ self.rpc(
+ method=self.get_method_name('InvalidMethodName'),
+ address=self._address,
+ request=''))
+
+ # This also covers the case of address=''
+ # and address='localhost:293874293874'
+ with self.assertRaises(errors.UnavailableError):
+ sess.run(
+ self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
+ request=''))
+
+ # Test invalid method with the TryRpc op
+ _, status_code_value, status_message_value = sess.run(
+ self.try_rpc(
+ method=self.get_method_name('InvalidMethodName'),
+ address=self._address,
+ request=''))
+ self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+ self.assertTrue(
+ self.invalid_method_string in status_message_value.decode('ascii'))
+
+ def testAlwaysFailingMethod(self):
+ with self.test_session() as sess:
+ response_tensors = self.rpc(
+ method=self.get_method_name('AlwaysFailWithInvalidArgument'),
+ address=self._address,
+ request='')
+ self.assertEqual(response_tensors.shape, ())
+ with self.assertRaisesOpError(I_WARNED_YOU):
+ sess.run(response_tensors)
+
+ def testSometimesFailingMethodWithManyRequests(self):
+ with self.test_session() as sess:
+ # Fail hard by default.
+ response_tensors = self.rpc(
+ method=self.get_method_name('SometimesFailWithInvalidArgument'),
+ address=self._address,
+ request=[''] * 20)
+ self.assertEqual(response_tensors.shape, (20,))
+ with self.assertRaisesOpError(I_WARNED_YOU):
+ sess.run(response_tensors)
+
+ # Don't fail hard, use TryRpc - return the failing status instead.
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('SometimesFailWithInvalidArgument'),
+ address=self._address,
+ request=[''] * 20)
+ self.assertEqual(response_tensors.shape, (20,))
+ self.assertEqual(status_code.shape, (20,))
+ self.assertEqual(status_message.shape, (20,))
+ status_code_values, status_message_values = sess.run((status_code,
+ status_message))
+ self.assertTrue([
+ x in (errors.OK, errors.INVALID_ARGUMENT) for x in status_code_values
+ ])
+ expected_message_values = np.where(
+ status_code_values == errors.INVALID_ARGUMENT,
+ I_WARNED_YOU.encode('ascii'), b'')
+ self.assertAllEqual(expected_message_values, status_message_values)
+
+ def testVecHostPortRpc(self):
+ with self.test_session() as sess:
+ request_tensors = [
+ test_example_pb2.TestCase(
+ shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ ]
+ response_tensors = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ self.assertEqual(response_tensors.shape, (20,))
+ response_values = sess.run(response_tensors)
+ self.assertEqual(response_values.shape, (20,))
+ for i in range(20):
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values[i]))
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+
+ def testVecHostPortManyParallelRpcs(self):
+ with self.test_session() as sess:
+ request_tensors = [
+ test_example_pb2.TestCase(
+ shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ ]
+ many_response_tensors = [
+ self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors) for _ in range(10)
+ ]
+ # Launch parallel 10 calls to the RpcOp, each containing
+ # 20 rpc requests.
+ many_response_values = sess.run(many_response_tensors)
+ self.assertEqual(10, len(many_response_values))
+ for response_values in many_response_values:
+ self.assertEqual(response_values.shape, (20,))
+ for i in range(20):
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(response_message.ParseFromString(response_values[i]))
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+
+ def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
+ with self.test_session() as sess:
+ request_tensors = encode_proto(
+ message_type='tensorflow.contrib.rpc.TestCase',
+ field_names=['shape'],
+ sizes=[[3]] * 20,
+ values=[
+ [[i, i + 1, i + 2] for i in range(20)],
+ ])
+ response_tensor_strings = self.rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=self._address,
+ request=request_tensors)
+ _, (response_shape,) = decode_proto(
+ bytes=response_tensor_strings,
+ message_type='tensorflow.contrib.rpc.TestCase',
+ field_names=['shape'],
+ output_types=[dtypes.int32])
+ response_shape_values = sess.run(response_shape)
+ self.assertAllEqual([[i + 1, i + 2, i + 3]
+ for i in range(20)], response_shape_values)
+
+ def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
+ with self.test_session() as sess:
+ request_tensors = [''] * 25 # This will launch 25 RPC requests.
+ response_tensors = self.rpc(
+ method=self.get_method_name('SleepForever'),
+ address=self._address,
+ request=request_tensors)
+ for timeout_ms in [1, 500, 1000]:
+ options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)
+ with self.assertRaises((errors.UnavailableError,
+ errors.DeadlineExceededError)):
+ sess.run(response_tensors, options=options)
+
+ def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
+ with self.test_session() as sess:
+ request_tensors = [''] * 25 # This will launch 25 RPC requests.
+ response_tensors = self.rpc(
+ method=self.get_method_name('SleepForever'),
+ address=self._address,
+ timeout_in_ms=1000,
+ request=request_tensors)
+ with self.assertRaises(errors.DeadlineExceededError):
+ sess.run(response_tensors)
+
+ def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
+ with self.test_session() as sess:
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('SometimesSleepForever'),
+ timeout_in_ms=1000,
+ address=self._address,
+ request=[''] * 20)
+ self.assertEqual(response_tensors.shape, (20,))
+ self.assertEqual(status_code.shape, (20,))
+ self.assertEqual(status_message.shape, (20,))
+ status_code_values = sess.run(status_code)
+ self.assertTrue([
+ x in (errors.OK, errors.DEADLINE_EXCEEDED) for x in status_code_values
+ ])
+
+ def testTryRpcWithMultipleAddressesSingleRequest(self):
+ flatten = lambda x: list(itertools.chain.from_iterable(x))
+ with self.test_session() as sess:
+ addresses = flatten([[
+ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+ ] for _ in range(10)])
+ request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ response_tensors, status_code, _ = self.try_rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=addresses,
+ request=request)
+ response_tensors_values, status_code_values = sess.run((response_tensors,
+ status_code))
+ self.assertAllEqual(
+ flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
+ status_code_values)
+ for i in range(10):
+ self.assertTrue(response_tensors_values[2 * i])
+ self.assertFalse(response_tensors_values[2 * i + 1])
+
+ def testTryRpcWithMultipleMethodsSingleRequest(self):
+ flatten = lambda x: list(itertools.chain.from_iterable(x))
+ with self.test_session() as sess:
+ methods = flatten(
+ [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
+ for _ in range(10)])
+ request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ response_tensors, status_code, _ = self.try_rpc(
+ method=methods, address=self._address, request=request)
+ response_tensors_values, status_code_values = sess.run((response_tensors,
+ status_code))
+ self.assertAllEqual(
+ flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
+ status_code_values)
+ for i in range(10):
+ self.assertTrue(response_tensors_values[2 * i])
+ self.assertFalse(response_tensors_values[2 * i + 1])
+
+ def testTryRpcWithMultipleAddressesAndRequests(self):
+ flatten = lambda x: list(itertools.chain.from_iterable(x))
+ with self.test_session() as sess:
+ addresses = flatten([[
+ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+ ] for _ in range(10)])
+ requests = [
+ test_example_pb2.TestCase(
+ shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ ]
+ response_tensors, status_code, _ = self.try_rpc(
+ method=self.get_method_name('IncrementTestShapes'),
+ address=addresses,
+ request=requests)
+ response_tensors_values, status_code_values = sess.run((response_tensors,
+ status_code))
+ self.assertAllEqual(
+ flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
+ status_code_values)
+ for i in range(20):
+ if i % 2 == 1:
+ self.assertFalse(response_tensors_values[i])
+ else:
+ response_message = test_example_pb2.TestCase()
+ self.assertTrue(
+ response_message.ParseFromString(response_tensors_values[i]))
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
new file mode 100644
index 0000000000..7cbd636cb1
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Test servicer for RpcOp tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import time
+
+import grpc
+
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
+
+
+class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
+ """Test servicer for RpcOp tests."""
+
+ def IncrementTestShapes(self, request, context):
+ """Increment the entries in the shape attribute of request.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ for i in range(len(request.shape)):
+ request.shape[i] += 1
+ return request
+
+ def AlwaysFailWithInvalidArgument(self, request, context):
+ """Always fails with an InvalidArgument status.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ del request
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details(rpc_op_test_base.I_WARNED_YOU)
+
+ def SometimesFailWithInvalidArgument(self, request, context):
+ """Sometimes fails with an InvalidArgument status.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ if random.randint(0, 1) == 1:
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details(rpc_op_test_base.I_WARNED_YOU)
+ return request
+
+ def SleepForever(self, request, context):
+ """Sleeps forever.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ # TODO(ebrevdo): Make this async wait like the stubby version.
+ time.sleep(5)
+
+ def SometimesSleepForever(self, request, context):
+ """Sometimes sleeps forever.
+
+ Args:
+ request: input TestCase.
+ context: the rpc context.
+
+ Returns:
+ output TestCase.
+ """
+ if random.randint(0, 1) == 1:
+ time.sleep(5)
+ return request
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
new file mode 100644
index 0000000000..96f4550f62
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
@@ -0,0 +1,171 @@
+// Test description and protos to work with it.
+//
+// Many of the protos in this file are for unit tests that haven't been written yet.
+
+syntax = "proto2";
+
+import "tensorflow/core/framework/types.proto";
+
+package tensorflow.contrib.rpc;
+
+// A TestCase holds a proto and a bunch of assertions
+// about how it should decode.
+message TestCase {
+ // A batch of primitives to be serialized and decoded.
+ repeated RepeatedPrimitiveValue primitive = 1;
+ // The shape of the batch.
+ repeated int32 shape = 2;
+ // Expected sizes for each field.
+ repeated int32 sizes = 3;
+ // Expected values for each field.
+ repeated FieldSpec field = 4;
+};
+
+service TestCaseService {
+ // Copy input, and increment each entry in 'shape' by 1.
+ rpc IncrementTestShapes(TestCase) returns (TestCase) {
+ }
+
+ // Sleep forever.
+ rpc SleepForever(TestCase) returns (TestCase) {
+ }
+
+ // Sleep forever 50% of the time, return immediately the other 50%.
+ rpc SometimesSleepForever(TestCase) returns (TestCase) {
+ }
+
+ // Always fails with InvalidArgument.
+ rpc AlwaysFailWithInvalidArgument(TestCase) returns (TestCase) {
+ }
+
+ // Fails with InvalidArgument 50% of the time.
+ rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
+ }
+};
+
+// FieldSpec describes the expected output for a single field.
+message FieldSpec {
+ optional string name = 1;
+ optional tensorflow.DataType dtype = 2;
+ optional RepeatedPrimitiveValue expected = 3;
+};
+
+message TestValue {
+ optional PrimitiveValue primitive_value = 1;
+ optional EnumValue enum_value = 2;
+ optional MessageValue message_value = 3;
+ optional RepeatedMessageValue repeated_message_value = 4;
+ optional RepeatedPrimitiveValue repeated_primitive_value = 6;
+}
+
+message PrimitiveValue {
+ optional double double_value = 1;
+ optional float float_value = 2;
+ optional int64 int64_value = 3;
+ optional uint64 uint64_value = 4;
+ optional int32 int32_value = 5;
+ optional fixed64 fixed64_value = 6;
+ optional fixed32 fixed32_value = 7;
+ optional bool bool_value = 8;
+ optional string string_value = 9;
+ optional bytes bytes_value = 12;
+ optional uint32 uint32_value = 13;
+ optional sfixed32 sfixed32_value = 15;
+ optional sfixed64 sfixed64_value = 16;
+ optional sint32 sint32_value = 17;
+ optional sint64 sint64_value = 18;
+}
+
+// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
+message RepeatedPrimitiveValue {
+ repeated double double_value = 1;
+ repeated float float_value = 2;
+ repeated int64 int64_value = 3;
+ repeated uint64 uint64_value = 4;
+ repeated int32 int32_value = 5;
+ repeated fixed64 fixed64_value = 6;
+ repeated fixed32 fixed32_value = 7;
+ repeated bool bool_value = 8;
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13;
+ repeated sfixed32 sfixed32_value = 15;
+ repeated sfixed64 sfixed64_value = 16;
+ repeated sint32 sint32_value = 17;
+ repeated sint64 sint64_value = 18;
+ repeated PrimitiveValue message_value = 19;
+}
+
+// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
+// in the text format, but the binary serializion is different.
+// We test the packed representations by loading the same test cases
+// using this definition instead of RepeatedPrimitiveValue.
+// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
+// in every way except the packed=true declaration.
+message PackedPrimitiveValue {
+ repeated double double_value = 1 [packed = true];
+ repeated float float_value = 2 [packed = true];
+ repeated int64 int64_value = 3 [packed = true];
+ repeated uint64 uint64_value = 4 [packed = true];
+ repeated int32 int32_value = 5 [packed = true];
+ repeated fixed64 fixed64_value = 6 [packed = true];
+ repeated fixed32 fixed32_value = 7 [packed = true];
+ repeated bool bool_value = 8 [packed = true];
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13 [packed = true];
+ repeated sfixed32 sfixed32_value = 15 [packed = true];
+ repeated sfixed64 sfixed64_value = 16 [packed = true];
+ repeated sint32 sint32_value = 17 [packed = true];
+ repeated sint64 sint64_value = 18 [packed = true];
+ repeated PrimitiveValue message_value = 19;
+}
+
+message EnumValue {
+ enum Color {
+ RED = 0;
+ ORANGE = 1;
+ YELLOW = 2;
+ GREEN = 3;
+ BLUE = 4;
+ INDIGO = 5;
+ VIOLET = 6;
+ };
+ optional Color enum_value = 14;
+ repeated Color repeated_enum_value = 15;
+}
+
+
+message InnerMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+}
+
+message MiddleMessageValue {
+ repeated int32 int32_values = 5;
+ optional InnerMessageValue message_value = 11;
+ optional uint32 uint32_value = 13;
+}
+
+message MessageValue {
+ optional double double_value = 1;
+ optional MiddleMessageValue message_value = 11;
+}
+
+message RepeatedMessageValue {
+ message NestedMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+ }
+
+ repeated NestedMessageValue message_values = 11;
+}
+
+// Message containing fields with field numbers higher than any field above. An
+// instance of this message is prepended to each binary message in the test to
+// exercise the code path that handles fields encoded out of order of field
+// number.
+message ExtraFields {
+ optional string string_value = 1776;
+ optional bool bool_value = 1777;
+}