aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rpc
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-07-12 08:54:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-12 08:58:02 -0700
commit0ca8c47bfe47da178d976c7fd8c8ac8df1b2ba19 (patch)
tree99c46e707aa71d61921d7ad9bee5c204e22e729e /tensorflow/contrib/rpc
parent34a1b6780b55764802cd490e50481d4d2ed8355c (diff)
Cleaning up test proto for `tensorflow/contrib/rpc`.
PiperOrigin-RevId: 204307008
Diffstat (limited to 'tensorflow/contrib/rpc')
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py52
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py8
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/test_example.proto147
4 files changed, 34 insertions, 176 deletions
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
index 2311c15a68..cb0b89ae55 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -1,5 +1,3 @@
-# TODO(b/76425722): Port everything in here to OS (currently excluded).
-
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
@@ -17,7 +15,6 @@ tf_proto_library(
srcs = ["test_example.proto"],
has_services = 1,
cc_api_version = 2,
- protodeps = ["//tensorflow/core:protos_all"],
)
py_library(
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 27273d16b1..1c23c28860 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -51,23 +51,23 @@ class RpcOpTestBase(object):
def testScalarHostPortRpc(self):
with self.test_session() as sess:
request_tensors = (
- test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(response_tensors.shape, ())
response_values = sess.run(response_tensors)
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values))
- self.assertAllEqual([2, 3, 4], response_message.shape)
+ self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self):
with self.test_session() as sess:
request_tensors = (
- test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(status_code.shape, ())
@@ -77,7 +77,7 @@ class RpcOpTestBase(object):
sess.run((response_tensors, status_code, status_message)))
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values))
- self.assertAllEqual([2, 3, 4], response_message.shape)
+ self.assertAllEqual([2, 3, 4], response_message.values)
# For the base Rpc op, don't expect to get error status back.
self.assertEqual(errors.OK, status_code_values)
self.assertEqual(b'', status_message_values)
@@ -86,7 +86,7 @@ class RpcOpTestBase(object):
with self.test_session() as sess:
request_tensors = []
response_tensors = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertAllEqual(response_tensors.shape, [0])
@@ -95,7 +95,7 @@ class RpcOpTestBase(object):
def testInvalidMethod(self):
for method in [
- '/InvalidService.IncrementTestShapes',
+ '/InvalidService.Increment',
self.get_method_name('InvalidMethodName')
]:
with self.test_session() as sess:
@@ -115,12 +115,12 @@ class RpcOpTestBase(object):
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=address,
request=''))
_, status_code_value, status_message_value = sess.run(
self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=address,
request=''))
self.assertEqual(errors.UNAVAILABLE, status_code_value)
@@ -182,10 +182,10 @@ class RpcOpTestBase(object):
with self.test_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
- shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
response_tensors = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(response_tensors.shape, (20,))
@@ -194,17 +194,17 @@ class RpcOpTestBase(object):
for i in range(20):
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i]))
- self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self):
with self.test_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
- shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
many_response_tensors = [
self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors) for _ in range(10)
]
@@ -216,25 +216,25 @@ class RpcOpTestBase(object):
for i in range(20):
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i]))
- self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
with self.test_session() as sess:
request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase',
- field_names=['shape'],
+ field_names=['values'],
sizes=[[3]] * 20,
values=[
[[i, i + 1, i + 2] for i in range(20)],
])
response_tensor_strings = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
_, (response_shape,) = decode_proto_op.decode_proto(
bytes=response_tensor_strings,
message_type='tensorflow.contrib.rpc.TestCase',
- field_names=['shape'],
+ field_names=['values'],
output_types=[dtypes.int32])
response_shape_values = sess.run(response_shape)
self.assertAllEqual([[i + 1, i + 2, i + 3]
@@ -285,9 +285,9 @@ class RpcOpTestBase(object):
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
- request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=addresses,
request=request)
response_tensors_values, status_code_values = sess.run((response_tensors,
@@ -303,9 +303,9 @@ class RpcOpTestBase(object):
flatten = lambda x: list(itertools.chain.from_iterable(x))
with self.test_session() as sess:
methods = flatten(
- [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
+ [[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)])
- request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc(
method=methods, address=self._address, request=request)
response_tensors_values, status_code_values = sess.run((response_tensors,
@@ -325,10 +325,10 @@ class RpcOpTestBase(object):
] for _ in range(10)])
requests = [
test_example_pb2.TestCase(
- shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
response_tensors, status_code, _ = self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=addresses,
request=requests)
response_tensors_values, status_code_values = sess.run((response_tensors,
@@ -343,4 +343,4 @@ class RpcOpTestBase(object):
response_message = test_example_pb2.TestCase()
self.assertTrue(
response_message.ParseFromString(response_tensors_values[i]))
- self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
index 7cbd636cb1..265254aa51 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
@@ -30,8 +30,8 @@ from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
"""Test servicer for RpcOp tests."""
- def IncrementTestShapes(self, request, context):
- """Increment the entries in the shape attribute of request.
+ def Increment(self, request, context):
+ """Increment the entries in the `values` attribute of request.
Args:
request: input TestCase.
@@ -40,8 +40,8 @@ class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
Returns:
output TestCase.
"""
- for i in range(len(request.shape)):
- request.shape[i] += 1
+ for i in range(len(request.values)):
+ request.values[i] += 1
return request
def AlwaysFailWithInvalidArgument(self, request, context):
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
index 96f4550f62..8141466349 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
+++ b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
@@ -1,29 +1,17 @@
// Test description and protos to work with it.
-//
-// Many of the protos in this file are for unit tests that haven't been written yet.
syntax = "proto2";
-import "tensorflow/core/framework/types.proto";
-
package tensorflow.contrib.rpc;
-// A TestCase holds a proto and a bunch of assertions
-// about how it should decode.
+// A TestCase holds a sequence of values.
message TestCase {
- // A batch of primitives to be serialized and decoded.
- repeated RepeatedPrimitiveValue primitive = 1;
- // The shape of the batch.
- repeated int32 shape = 2;
- // Expected sizes for each field.
- repeated int32 sizes = 3;
- // Expected values for each field.
- repeated FieldSpec field = 4;
+ repeated int32 values = 1;
};
service TestCaseService {
- // Copy input, and increment each entry in 'shape' by 1.
- rpc IncrementTestShapes(TestCase) returns (TestCase) {
+ // Copy input, and increment each entry in 'values' by 1.
+ rpc Increment(TestCase) returns (TestCase) {
}
// Sleep forever.
@@ -42,130 +30,3 @@ service TestCaseService {
rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
}
};
-
-// FieldSpec describes the expected output for a single field.
-message FieldSpec {
- optional string name = 1;
- optional tensorflow.DataType dtype = 2;
- optional RepeatedPrimitiveValue expected = 3;
-};
-
-message TestValue {
- optional PrimitiveValue primitive_value = 1;
- optional EnumValue enum_value = 2;
- optional MessageValue message_value = 3;
- optional RepeatedMessageValue repeated_message_value = 4;
- optional RepeatedPrimitiveValue repeated_primitive_value = 6;
-}
-
-message PrimitiveValue {
- optional double double_value = 1;
- optional float float_value = 2;
- optional int64 int64_value = 3;
- optional uint64 uint64_value = 4;
- optional int32 int32_value = 5;
- optional fixed64 fixed64_value = 6;
- optional fixed32 fixed32_value = 7;
- optional bool bool_value = 8;
- optional string string_value = 9;
- optional bytes bytes_value = 12;
- optional uint32 uint32_value = 13;
- optional sfixed32 sfixed32_value = 15;
- optional sfixed64 sfixed64_value = 16;
- optional sint32 sint32_value = 17;
- optional sint64 sint64_value = 18;
-}
-
-// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
-message RepeatedPrimitiveValue {
- repeated double double_value = 1;
- repeated float float_value = 2;
- repeated int64 int64_value = 3;
- repeated uint64 uint64_value = 4;
- repeated int32 int32_value = 5;
- repeated fixed64 fixed64_value = 6;
- repeated fixed32 fixed32_value = 7;
- repeated bool bool_value = 8;
- repeated string string_value = 9;
- repeated bytes bytes_value = 12;
- repeated uint32 uint32_value = 13;
- repeated sfixed32 sfixed32_value = 15;
- repeated sfixed64 sfixed64_value = 16;
- repeated sint32 sint32_value = 17;
- repeated sint64 sint64_value = 18;
- repeated PrimitiveValue message_value = 19;
-}
-
-// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
-// in the text format, but the binary serializion is different.
-// We test the packed representations by loading the same test cases
-// using this definition instead of RepeatedPrimitiveValue.
-// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
-// in every way except the packed=true declaration.
-message PackedPrimitiveValue {
- repeated double double_value = 1 [packed = true];
- repeated float float_value = 2 [packed = true];
- repeated int64 int64_value = 3 [packed = true];
- repeated uint64 uint64_value = 4 [packed = true];
- repeated int32 int32_value = 5 [packed = true];
- repeated fixed64 fixed64_value = 6 [packed = true];
- repeated fixed32 fixed32_value = 7 [packed = true];
- repeated bool bool_value = 8 [packed = true];
- repeated string string_value = 9;
- repeated bytes bytes_value = 12;
- repeated uint32 uint32_value = 13 [packed = true];
- repeated sfixed32 sfixed32_value = 15 [packed = true];
- repeated sfixed64 sfixed64_value = 16 [packed = true];
- repeated sint32 sint32_value = 17 [packed = true];
- repeated sint64 sint64_value = 18 [packed = true];
- repeated PrimitiveValue message_value = 19;
-}
-
-message EnumValue {
- enum Color {
- RED = 0;
- ORANGE = 1;
- YELLOW = 2;
- GREEN = 3;
- BLUE = 4;
- INDIGO = 5;
- VIOLET = 6;
- };
- optional Color enum_value = 14;
- repeated Color repeated_enum_value = 15;
-}
-
-
-message InnerMessageValue {
- optional float float_value = 2;
- repeated bytes bytes_values = 8;
-}
-
-message MiddleMessageValue {
- repeated int32 int32_values = 5;
- optional InnerMessageValue message_value = 11;
- optional uint32 uint32_value = 13;
-}
-
-message MessageValue {
- optional double double_value = 1;
- optional MiddleMessageValue message_value = 11;
-}
-
-message RepeatedMessageValue {
- message NestedMessageValue {
- optional float float_value = 2;
- repeated bytes bytes_values = 8;
- }
-
- repeated NestedMessageValue message_values = 11;
-}
-
-// Message containing fields with field numbers higher than any field above. An
-// instance of this message is prepended to each binary message in the test to
-// exercise the code path that handles fields encoded out of order of field
-// number.
-message ExtraFields {
- optional string string_value = 1776;
- optional bool bool_value = 1777;
-}