aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/proto
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-04-11 13:59:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 14:02:37 -0700
commit1e283d64816b92de6c398bee6df2122409c87d73 (patch)
treedf55ba2b29dcc1802ea57bf7c8b9d2141e64f2c8 /tensorflow/contrib/proto
parent73aef57c451a13e07e48933d0bae3ad3ed2c64bd (diff)
Porting tests for the `decode_proto` and `encode_proto` to OS.
PiperOrigin-RevId: 192504411
Diffstat (limited to 'tensorflow/contrib/proto')
-rw-r--r--tensorflow/contrib/proto/BUILD16
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/BUILD81
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl78
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py68
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py300
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py179
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt161
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt16
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt20
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt21
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt32
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt62
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt21
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/test_case.py35
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/test_example.proto149
15 files changed, 1239 insertions, 0 deletions
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
index 046652cbc5..3e9b1a0b8d 100644
--- a/tensorflow/contrib/proto/BUILD
+++ b/tensorflow/contrib/proto/BUILD
@@ -4,6 +4,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+
py_library(
name = "proto",
srcs = [
@@ -14,3 +16,17 @@ py_library(
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
],
)
+
+py_library(
+ name = "proto_pip",
+ data = [
+ "//tensorflow/contrib/proto/python/kernel_tests:test_messages",
+ ] + if_static(
+ [],
+ otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"],
+ ),
+ deps = [
+ ":proto",
+ "//tensorflow/contrib/proto/python/kernel_tests:py_test_deps",
+ ],
+)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD
new file mode 100644
index 0000000000..4125ea8a2a
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD
@@ -0,0 +1,81 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+# Much of the work in this BUILD file actually happens in the corresponding
+# build_defs.bzl, which creates an individual testcase for each example .pbtxt
+# file in this directory.
+#
+load(":build_defs.bzl", "decode_proto_test_suite")
+load(":build_defs.bzl", "encode_proto_test_suite")
+
+# This expands to a tf_py_test for each test file.
+# It defines the test_suite :decode_proto_op_tests.
+decode_proto_test_suite(
+ name = "decode_proto_tests",
+ examples = glob(["*.pbtxt"]),
+)
+
+# This expands to a tf_py_test for each test file.
+# It defines the test_suite :encode_proto_op_tests.
+encode_proto_test_suite(
+ name = "encode_proto_tests",
+ examples = glob(["*.pbtxt"]),
+)
+
+# Below here are tests that are not tied to an example text proto.
+filegroup(
+ name = "test_messages",
+ srcs = glob(["*.pbtxt"]),
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+tf_py_test(
+ name = "decode_proto_fail_test",
+ size = "small",
+ srcs = ["decode_proto_fail_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+)
+
+py_library(
+ name = "test_case",
+ srcs = ["test_case.py"],
+ deps = ["//tensorflow/python:client_testlib"],
+)
+
+py_library(
+ name = "py_test_deps",
+ deps = [
+ ":test_case",
+ ":test_example_proto_py",
+ ],
+)
+
+tf_proto_library(
+ name = "test_example_proto",
+ srcs = ["test_example.proto"],
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+)
+
+tf_cc_shared_object(
+ name = "libtestexample.so",
+ linkstatic = 1,
+ deps = [
+ ":test_example_proto_cc",
+ ],
+)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
new file mode 100644
index 0000000000..6fe48ae807
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
@@ -0,0 +1,78 @@
+"""BUILD rules for generating file-driven proto test cases.
+
+The decode_proto_test_suite() and encode_proto_test_suite() rules take a list
+of text protos and generates a tf_py_test() for each one.
+"""
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "register_extension_info")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+
+def _test_name(test, path):
+ return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
+
+def decode_proto_test_suite(name, examples):
+ """Build the decode_proto py_test for each test filename."""
+ for test_filename in examples:
+ tf_py_test(
+ name = _test_name("decode_proto", test_filename),
+ srcs = ["decode_proto_op_test.py"],
+ size = "small",
+ data = [test_filename] + if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ main = "decode_proto_op_test.py",
+ args = [
+ "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
+ ],
+ additional_deps = [
+ ":py_test_deps",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ ],
+ )
+ native.test_suite(
+ name = name,
+ tests = [":" + _test_name("decode_proto", test_filename)
+ for test_filename in examples],
+ )
+
+def encode_proto_test_suite(name, examples):
+ """Build the encode_proto py_test for each test filename."""
+ for test_filename in examples:
+ tf_py_test(
+ name = _test_name("encode_proto", test_filename),
+ srcs = ["encode_proto_op_test.py"],
+ size = "small",
+ data = [test_filename] + if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ main = "encode_proto_op_test.py",
+ args = [
+ "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
+ ],
+ additional_deps = [
+ ":py_test_deps",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ ],
+ )
+ native.test_suite(
+ name = name,
+ tests = [":" + _test_name("encode_proto", test_filename)
+ for test_filename in examples],
+ )
+
+register_extension_info(
+ extension_name = "decode_proto_test_suite",
+ label_regex_map = {
+ "deps": "deps:decode_example_.*",
+ })
+
+register_extension_info(
+ extension_name = "encode_proto_test_suite",
+ label_regex_map = {
+ "deps": "deps:encode_example_.*",
+ })
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
new file mode 100644
index 0000000000..f019833905
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
@@ -0,0 +1,68 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import proto
+from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class DecodeProtoFailTest(test_case.ProtoOpTestCase):
+ """Test failure cases for DecodeToProto."""
+
+ def _TestCorruptProtobuf(self, sanitize):
+ """Test failure cases for DecodeToProto."""
+
+ # The goal here is to check the error reporting.
+ # Testing against a variety of corrupt protobufs is
+ # done by fuzzing.
+ corrupt_proto = 'This is not a binary protobuf'
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(corrupt_proto, dtype=object)
+ msg_type = 'tensorflow.contrib.proto.TestCase'
+ field_names = ['sizes']
+ field_types = [dtypes.int32]
+
+ with self.test_session() as sess:
+ ctensor, vtensor = proto.decode_proto(
+ batch,
+ message_type=msg_type,
+ field_names=field_names,
+ output_types=field_types,
+ sanitize=sanitize)
+ with self.assertRaisesRegexp(errors.DataLossError,
+ 'Unable to parse binary protobuf'
+ '|Failed to consume entire buffer'):
+ _ = sess.run([ctensor] + vtensor)
+
+ def testCorrupt(self):
+ self._TestCorruptProtobuf(sanitize=False)
+
+ def testSanitizerCorrupt(self):
+ self._TestCorruptProtobuf(sanitize=True)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
new file mode 100644
index 0000000000..30ceac5f5f
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
@@ -0,0 +1,300 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Table-driven test for decode_proto op.
+
+This test is run once with each of the *.TestCase.pbtxt files
+in the test directory.
+"""
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib import proto
+from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('message_text_file', None,
+ 'A file containing a text serialized TestCase protobuf.')
+
+
+class DecodeProtoOpTest(test_case.ProtoOpTestCase):
+
+ def _compareValues(self, fd, vs, evs):
+ """Compare lists/arrays of field values."""
+
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for i, ev in enumerate(evs):
+ # Special case fuzzy match for float32. TensorFlow seems to mess with
+ # MAX_FLT slightly and the test doesn't work otherwise.
+ # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
+ if fd.cpp_type == fd.CPPTYPE_FLOAT:
+ # Numpy isclose() is better than assertIsClose() which uses an absolute
+ # value comparison.
+ self.assertTrue(
+ np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i]))
+ elif fd.cpp_type == fd.CPPTYPE_STRING:
+ # In Python3 string tensor values will be represented as bytes, so we
+ # reencode the proto values to match that.
+ self.assertEqual(vs[i], ev.encode('ascii'))
+ else:
+ # Doubles and other types pass through unscathed.
+ self.assertEqual(vs[i], ev)
+
+ def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields,
+ field_dict):
+ """Compare protos of type RepeatedPrimitiveValue.
+
+ Args:
+ batch_shape: the shape of the input tensor of serialized messages.
+ sizes: int matrix of repeat counts returned by decode_proto
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ field_dict: map from field names to decoded numpy tensors of values
+ """
+
+ # Check that expected values match.
+ for field in fields:
+ values = field_dict[field.name]
+ self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)
+
+ fd = field.expected.DESCRIPTOR.fields_by_name[field.name]
+
+ # Values has the same shape as the input plus an extra
+ # dimension for repeats.
+ self.assertEqual(list(values.shape)[:-1], batch_shape)
+
+ # Nested messages are represented as TF strings, requiring
+ # some special handling.
+ if field.name == 'message_value':
+ vs = []
+ for buf in values.flat:
+ msg = test_example_pb2.PrimitiveValue()
+ msg.ParseFromString(buf)
+ vs.append(msg)
+ evs = getattr(field.expected, field.name)
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for v, ev in zip(vs, evs):
+ self.assertEqual(v, ev)
+ continue
+
+ # This can be a little confusing. For testing we are using
+ # RepeatedPrimitiveValue in two ways: it's the proto that we
+ # decode for testing, and it's used in the expected value as a
+ # union type. The two cases are slightly different: this is the
+ # second case.
+ # We may be fetching the uint64_value from the test proto, but
+ # in the expected proto we store it in the int64_value field
+ # because TensorFlow doesn't support unsigned int64.
+ tf_type_to_primitive_value_field = {
+ dtypes.float32:
+ 'float_value',
+ dtypes.float64:
+ 'double_value',
+ dtypes.int32:
+ 'int32_value',
+ dtypes.uint8:
+ 'uint8_value',
+ dtypes.int8:
+ 'int8_value',
+ dtypes.string:
+ 'string_value',
+ dtypes.int64:
+ 'int64_value',
+ dtypes.bool:
+ 'bool_value',
+ # Unhandled TensorFlow types:
+ # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
+ # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
+ }
+ tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
+ if tf_field_name is None:
+ self.fail('Unhandled tensorflow type %d' % field.dtype)
+
+ self._compareValues(fd, values.flat,
+ getattr(field.expected, tf_field_name))
+
+ def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
+ message_type, message_format, sanitize,
+ force_disordered=False):
+ """Run decode tests on a batch of messages.
+
+ Args:
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ case_sizes: expected sizes array
+ batch_shape: the shape of the input tensor of serialized messages
+ batch: list of serialized messages
+ message_type: descriptor name for messages
+ message_format: format of messages, 'text' or 'binary'
+ sanitize: whether to sanitize binary protobuf inputs
+ force_disordered: whether to force fields encoded out of order.
+ """
+
+ if force_disordered:
+ # Exercise code path that handles out-of-order fields by prepending extra
+ # fields with tag numbers higher than any real field. Note that this won't
+ # work with sanitization because that forces reserialization using a
+ # trusted decoder and encoder.
+ assert not sanitize
+ extra_fields = test_example_pb2.ExtraFields()
+ extra_fields.string_value = 'IGNORE ME'
+ extra_fields.bool_value = False
+ extra_msg = extra_fields.SerializeToString()
+ batch = [extra_msg + msg for msg in batch]
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(batch, dtype=object)
+ batch = np.reshape(batch, batch_shape)
+
+ field_names = [f.name for f in fields]
+ output_types = [f.dtype for f in fields]
+
+ with self.test_session() as sess:
+ sizes, vtensor = proto.decode_proto(
+ batch,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=output_types,
+ message_format=message_format,
+ sanitize=sanitize)
+
+ vlist = sess.run([sizes] + vtensor)
+ sizes = vlist[0]
+ # Values is a list of tensors, one for each field.
+ value_tensors = vlist[1:]
+
+ # Check that the repeat sizes are correct.
+ self.assertTrue(
+ np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))
+
+ # Check that the decoded sizes match the expected sizes.
+ self.assertEqual(len(sizes.flat), len(case_sizes))
+ self.assertTrue(
+ np.all(sizes.flat == np.array(
+ case_sizes, dtype=np.int32)))
+
+ field_dict = dict(zip(field_names, value_tensors))
+
+ self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
+ field_dict)
+
+ def testBinary(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ batch = [primitive.SerializeToString() for primitive in case.primitive]
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'binary',
+ sanitize=False)
+
+ def testBinaryDisordered(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ batch = [primitive.SerializeToString() for primitive in case.primitive]
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'binary',
+ sanitize=False,
+ force_disordered=True)
+
+ def testPacked(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ # Now try with the packed serialization.
+ # We test the packed representations by loading the same test cases
+ # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
+ # To do this we rely on the text format being the same for packed and
+ # unpacked fields, and reparse the test message using the packed version
+ # of the proto.
+ packed_batch = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ primitive, float_format='.17g'),
+ test_example_pb2.PackedPrimitiveValue()).SerializeToString()
+ for primitive in case.primitive
+ ]
+
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ packed_batch,
+ 'tensorflow.contrib.proto.PackedPrimitiveValue',
+ 'binary',
+ sanitize=False)
+
+ def testText(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_batch = [
+ text_format.MessageToString(
+ primitive, float_format='.17g') for primitive in case.primitive
+ ]
+
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ text_batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'text',
+ sanitize=False)
+
+ def testSanitizerGood(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ batch = [primitive.SerializeToString() for primitive in case.primitive]
+ self._runDecodeProtoTests(
+ case.field,
+ case.sizes,
+ list(case.shape),
+ batch,
+ 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ 'binary',
+ sanitize=True)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
new file mode 100644
index 0000000000..2a24c3b8ce
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
@@ -0,0 +1,179 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Table-driven test for encode_proto op.
+
+This test is run once with each of the *.TestCase.pbtxt files
+in the test directory.
+
+It tests that encode_proto is a lossless inverse of decode_proto
+(for the specified fields).
+"""
+# Python3 readiness boilerplate
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib import proto
+from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('message_text_file', None,
+ 'A file containing a text serialized TestCase protobuf.')
+
+
+class EncodeProtoOpTest(test_case.ProtoOpTestCase):
+
+ def testBadInputs(self):
+ # Invalid field name
+ with self.test_session():
+ with self.assertRaisesOpError('Unknown field: non_existent_field'):
+ proto.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['non_existent_field']).eval()
+
+ # Incorrect types.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Incompatible type for field double_value.'):
+ proto.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['double_value']).eval()
+
+ # Incorrect shapes of sizes.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ r'sizes should be batch_size \+ \[len\(field_names\)\]'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values = array_ops.placeholder(dtypes.float64)
+ proto.encode_proto(
+ sizes=sizes,
+ values=[values],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['double_value']).eval(feed_dict={
+ sizes: [[[0, 0]]],
+ values: [[0.0]]
+ })
+
+ # Inconsistent shapes of values.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Values must match up to the last dimension'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values1 = array_ops.placeholder(dtypes.float64)
+ values2 = array_ops.placeholder(dtypes.int32)
+ (proto.encode_proto(
+ sizes=[[1, 1]],
+ values=[values1, values2],
+ message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
+ field_names=['double_value', 'int32_value']).eval(feed_dict={
+ values1: [[0.0]],
+ values2: [[0], [0]]
+ }))
+
+ def _testRoundtrip(self, in_bufs, message_type, fields):
+
+ field_names = [f.name for f in fields]
+ out_types = [f.dtype for f in fields]
+
+ with self.test_session() as sess:
+ sizes, field_tensors = proto.decode_proto(
+ in_bufs,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=out_types)
+
+ out_tensors = proto.encode_proto(
+ sizes,
+ field_tensors,
+ message_type=message_type,
+ field_names=field_names)
+
+ out_bufs, = sess.run([out_tensors])
+
+ # Check that the re-encoded tensor has the same shape.
+ self.assertEqual(in_bufs.shape, out_bufs.shape)
+
+ # Compare the input and output.
+ for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
+ in_obj = test_example_pb2.RepeatedPrimitiveValue()
+ in_obj.ParseFromString(in_buf)
+
+ out_obj = test_example_pb2.RepeatedPrimitiveValue()
+ out_obj.ParseFromString(out_buf)
+
+ # Check that the deserialized objects are identical.
+ self.assertEqual(in_obj, out_obj)
+
+ # Check that the input and output serialized messages are identical.
+ # If we fail here, there is a difference in the serialized
+ # representation but the new serialization still parses. This could
+ # be harmless (a change in map ordering?) or it could be bad (e.g.
+ # loss of packing in the encoding).
+ self.assertEqual(in_buf, out_buf)
+
+ def testRoundtrip(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ in_bufs = [primitive.SerializeToString() for primitive in case.primitive]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field)
+
+ def testRoundtripPacked(self):
+ with open(FLAGS.message_text_file, 'r') as fp:
+ case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+
+ # Now try with the packed serialization.
+ # We test the packed representations by loading the same test cases
+ # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
+ # To do this we rely on the text format being the same for packed and
+ # unpacked fields, and reparse the test message using the packed version
+ # of the proto.
+ in_bufs = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ primitive, float_format='.17g'),
+ test_example_pb2.PackedPrimitiveValue()).SerializeToString()
+ for primitive in case.primitive
+ ]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
new file mode 100644
index 0000000000..b170f89c0f
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
@@ -0,0 +1,161 @@
+primitive {
+ double_value: -1.7976931348623158e+308
+ double_value: 2.2250738585072014e-308
+ double_value: 1.7976931348623158e+308
+ float_value: -3.402823466e+38
+ float_value: 1.175494351e-38
+ float_value: 3.402823466e+38
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ uint64_value: 0
+ uint64_value: 18446744073709551615
+ int32_value: -2147483648
+ int32_value: 2147483647
+ fixed64_value: 0
+ fixed64_value: 18446744073709551615
+ fixed32_value: 0
+ fixed32_value: 4294967295
+ bool_value: false
+ bool_value: true
+ string_value: ""
+ string_value: "I refer to the infinite."
+ uint32_value: 0
+ uint32_value: 4294967295
+ sfixed32_value: -2147483648
+ sfixed32_value: 2147483647
+ sfixed64_value: -9223372036854775808
+ sfixed64_value: 9223372036854775807
+ sint32_value: -2147483648
+ sint32_value: 2147483647
+ sint64_value: -9223372036854775808
+ sint64_value: 9223372036854775807
+}
+shape: 1
+sizes: 3
+sizes: 3
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+sizes: 2
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: -1.7976931348623158e+308
+ double_value: 2.2250738585072014e-308
+ double_value: 1.7976931348623158e+308
+ }
+}
+field {
+ name: "float_value"
+ dtype: DT_FLOAT
+ expected {
+ float_value: -3.402823466e+38
+ float_value: 1.175494351e-38
+ float_value: 3.402823466e+38
+ }
+}
+field {
+ name: "int64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ }
+}
+field {
+ name: "uint64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 0
+ int64_value: -1
+ }
+}
+field {
+ name: "int32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: -2147483648
+ int32_value: 2147483647
+ }
+}
+field {
+ name: "fixed64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 0
+ int64_value: -1 # unsigned is 18446744073709551615
+ }
+}
+field {
+ name: "fixed32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: 0
+ int32_value: -1 # unsigned is 4294967295
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: false
+ bool_value: true
+ }
+}
+field {
+ name: "string_value"
+ dtype: DT_STRING
+ expected {
+ string_value: ""
+ string_value: "I refer to the infinite."
+ }
+}
+field {
+ name: "uint32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: 0
+ int32_value: -1 # unsigned is 4294967295
+ }
+}
+field {
+ name: "sfixed32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: -2147483648
+ int32_value: 2147483647
+ }
+}
+field {
+ name: "sfixed64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ }
+}
+field {
+ name: "sint32_value"
+ dtype: DT_INT32
+ expected {
+ int32_value: -2147483648
+ int32_value: 2147483647
+ }
+}
+field {
+ name: "sint64_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: -9223372036854775808
+ int64_value: 9223372036854775807
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
new file mode 100644
index 0000000000..c664e52851
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
@@ -0,0 +1,16 @@
+primitive {
+ message_value {
+ double_value: 23.5
+ }
+}
+shape: 1
+sizes: 1
+field {
+ name: "message_value"
+ dtype: DT_STRING
+ expected {
+ message_value {
+ double_value: 23.5
+ }
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
new file mode 100644
index 0000000000..125651d7ea
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
@@ -0,0 +1,20 @@
+primitive {
+ bool_value: true
+}
+shape: 1
+sizes: 1
+sizes: 0
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ }
+}
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 0.0
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
new file mode 100644
index 0000000000..db7555bf2d
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
@@ -0,0 +1,21 @@
+primitive {
+ fixed32_value: 4294967295
+ uint32_value: 4294967295
+}
+shape: 1
+sizes: 1
+sizes: 1
+field {
+ name: "fixed32_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 4294967295
+ }
+}
+field {
+ name: "uint32_value"
+ dtype: DT_INT64
+ expected {
+ int64_value: 4294967295
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
new file mode 100644
index 0000000000..61c7ac53f7
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
@@ -0,0 +1,32 @@
+primitive {
+ double_value: 23.5
+ double_value: 123.0
+ bool_value: true
+}
+primitive {
+ double_value: 3.1
+ bool_value: false
+}
+shape: 2
+sizes: 2
+sizes: 1
+sizes: 1
+sizes: 1
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 23.5
+ double_value: 123.0
+ double_value: 3.1
+ double_value: 0.0
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ bool_value: false
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
new file mode 100644
index 0000000000..f4828076d5
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
@@ -0,0 +1,62 @@
+primitive {
+ double_value: 23.5
+ bool_value: true
+}
+primitive {
+ double_value: 44.0
+ bool_value: false
+}
+primitive {
+ double_value: 3.14159
+ bool_value: true
+}
+primitive {
+ double_value: 1.414
+ bool_value: true
+}
+primitive {
+ double_value: -32.2
+ bool_value: false
+}
+primitive {
+ double_value: 0.0001
+ bool_value: true
+}
+shape: 3
+shape: 2
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+sizes: 1
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 23.5
+ double_value: 44.0
+ double_value: 3.14159
+ double_value: 1.414
+ double_value: -32.2
+ double_value: 0.0001
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ bool_value: false
+ bool_value: true
+ bool_value: true
+ bool_value: false
+ bool_value: true
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
new file mode 100644
index 0000000000..dc20ac147b
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
@@ -0,0 +1,21 @@
+primitive {
+ double_value: 23.5
+ bool_value: true
+}
+shape: 1
+sizes: 1
+sizes: 1
+field {
+ name: "double_value"
+ dtype: DT_DOUBLE
+ expected {
+ double_value: 23.5
+ }
+}
+field {
+ name: "bool_value"
+ dtype: DT_BOOL
+ expected {
+ bool_value: true
+ }
+}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/test_case.py
new file mode 100644
index 0000000000..b95202c5df
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_case.py
@@ -0,0 +1,35 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Test case base for testing proto operations."""
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+from tensorflow.python.platform import test
+
+
+class ProtoOpTestCase(test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(ProtoOpTestCase, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
new file mode 100644
index 0000000000..dc495034ff
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
@@ -0,0 +1,149 @@
+// Test description and protos to work with it.
+//
+// Many of the protos in this file are for unit tests that haven't been written yet.
+
+syntax = "proto2";
+
+import "tensorflow/core/framework/types.proto";
+
+package tensorflow.contrib.proto;
+
+// A TestCase holds a proto and a bunch of assertions
+// about how it should decode.
+message TestCase {
+ // A batch of primitives to be serialized and decoded.
+ repeated RepeatedPrimitiveValue primitive = 1;
+ // The shape of the batch.
+ repeated int32 shape = 2;
+ // Expected sizes for each field.
+ repeated int32 sizes = 3;
+ // Expected values for each field.
+ repeated FieldSpec field = 4;
+};
+
+// FieldSpec describes the expected output for a single field.
+message FieldSpec {
+ optional string name = 1;
+ optional tensorflow.DataType dtype = 2;
+ optional RepeatedPrimitiveValue expected = 3;
+};
+
+message TestValue {
+ optional PrimitiveValue primitive_value = 1;
+ optional EnumValue enum_value = 2;
+ optional MessageValue message_value = 3;
+ optional RepeatedMessageValue repeated_message_value = 4;
+ optional RepeatedPrimitiveValue repeated_primitive_value = 6;
+}
+
+message PrimitiveValue {
+ optional double double_value = 1;
+ optional float float_value = 2;
+ optional int64 int64_value = 3;
+ optional uint64 uint64_value = 4;
+ optional int32 int32_value = 5;
+ optional fixed64 fixed64_value = 6;
+ optional fixed32 fixed32_value = 7;
+ optional bool bool_value = 8;
+ optional string string_value = 9;
+ optional bytes bytes_value = 12;
+ optional uint32 uint32_value = 13;
+ optional sfixed32 sfixed32_value = 15;
+ optional sfixed64 sfixed64_value = 16;
+ optional sint32 sint32_value = 17;
+ optional sint64 sint64_value = 18;
+}
+
+// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
+message RepeatedPrimitiveValue {
+ repeated double double_value = 1;
+ repeated float float_value = 2;
+ repeated int64 int64_value = 3;
+ repeated uint64 uint64_value = 4;
+ repeated int32 int32_value = 5;
+ repeated fixed64 fixed64_value = 6;
+ repeated fixed32 fixed32_value = 7;
+ repeated bool bool_value = 8;
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13;
+ repeated sfixed32 sfixed32_value = 15;
+ repeated sfixed64 sfixed64_value = 16;
+ repeated sint32 sint32_value = 17;
+ repeated sint64 sint64_value = 18;
+ repeated PrimitiveValue message_value = 19;
+}
+
+// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
+// in the text format, but the binary serializion is different.
+// We test the packed representations by loading the same test cases
+// using this definition instead of RepeatedPrimitiveValue.
+// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
+// in every way except the packed=true declaration.
+message PackedPrimitiveValue {
+ repeated double double_value = 1 [packed = true];
+ repeated float float_value = 2 [packed = true];
+ repeated int64 int64_value = 3 [packed = true];
+ repeated uint64 uint64_value = 4 [packed = true];
+ repeated int32 int32_value = 5 [packed = true];
+ repeated fixed64 fixed64_value = 6 [packed = true];
+ repeated fixed32 fixed32_value = 7 [packed = true];
+ repeated bool bool_value = 8 [packed = true];
+ repeated string string_value = 9;
+ repeated bytes bytes_value = 12;
+ repeated uint32 uint32_value = 13 [packed = true];
+ repeated sfixed32 sfixed32_value = 15 [packed = true];
+ repeated sfixed64 sfixed64_value = 16 [packed = true];
+ repeated sint32 sint32_value = 17 [packed = true];
+ repeated sint64 sint64_value = 18 [packed = true];
+ repeated PrimitiveValue message_value = 19;
+}
+
+message EnumValue {
+ enum Color {
+ RED = 0;
+ ORANGE = 1;
+ YELLOW = 2;
+ GREEN = 3;
+ BLUE = 4;
+ INDIGO = 5;
+ VIOLET = 6;
+ };
+ optional Color enum_value = 14;
+ repeated Color repeated_enum_value = 15;
+}
+
+
+message InnerMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+}
+
+message MiddleMessageValue {
+ repeated int32 int32_values = 5;
+ optional InnerMessageValue message_value = 11;
+ optional uint32 uint32_value = 13;
+}
+
+message MessageValue {
+ optional double double_value = 1;
+ optional MiddleMessageValue message_value = 11;
+}
+
+message RepeatedMessageValue {
+ message NestedMessageValue {
+ optional float float_value = 2;
+ repeated bytes bytes_values = 8;
+ }
+
+ repeated NestedMessageValue message_values = 11;
+}
+
+// Message containing fields with field numbers higher than any field above. An
+// instance of this message is prepended to each binary message in the test to
+// exercise the code path that handles fields encoded out of order of field
+// number.
+message ExtraFields {
+ optional string string_value = 1776;
+ optional bool bool_value = 1777;
+}