aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/proto
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/proto')
-rw-r--r--tensorflow/contrib/proto/BUILD14
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/BUILD113
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl89
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py68
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py275
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py303
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt94
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py (renamed from tensorflow/contrib/proto/python/kernel_tests/test_case.py)21
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py176
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py155
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py177
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt161
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt16
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt20
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt29
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py419
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt32
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt62
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt21
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/test_example.proto159
20 files changed, 1251 insertions, 1153 deletions
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
index 3e9b1a0b8d..b27142cf4a 100644
--- a/tensorflow/contrib/proto/BUILD
+++ b/tensorflow/contrib/proto/BUILD
@@ -16,17 +16,3 @@ py_library(
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
],
)
-
-py_library(
- name = "proto_pip",
- data = [
- "//tensorflow/contrib/proto/python/kernel_tests:test_messages",
- ] + if_static(
- [],
- otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"],
- ),
- deps = [
- ":proto",
- "//tensorflow/contrib/proto/python/kernel_tests:py_test_deps",
- ],
-)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD
index a380a131f8..125c1cee29 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD
@@ -4,47 +4,41 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-# Much of the work in this BUILD file actually happens in the corresponding
-# build_defs.bzl, which creates an individual testcase for each example .pbtxt
-# file in this directory.
-#
-load(":build_defs.bzl", "decode_proto_test_suite")
-load(":build_defs.bzl", "encode_proto_test_suite")
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :decode_proto_op_tests.
-decode_proto_test_suite(
- name = "decode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :encode_proto_op_tests.
-encode_proto_test_suite(
- name = "encode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# Below here are tests that are not tied to an example text proto.
-filegroup(
- name = "test_messages",
- srcs = glob(["*.pbtxt"]),
-)
-
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
tf_py_test(
- name = "decode_proto_fail_test",
+ name = "decode_proto_op_test",
size = "small",
- srcs = ["decode_proto_fail_test.py"],
+ srcs = ["decode_proto_op_test.py"],
additional_deps = [
+ ":decode_proto_op_test_base",
+ ":py_test_deps",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
+
+tf_py_test(
+ name = "encode_proto_op_test",
+ size = "small",
+ srcs = ["encode_proto_op_test.py"],
+ additional_deps = [
+ ":encode_proto_op_test_base",
":py_test_deps",
- "//third_party/py/numpy",
"//tensorflow/contrib/proto:proto",
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
],
data = if_static(
[],
@@ -57,19 +51,41 @@ tf_py_test(
)
py_library(
- name = "test_case",
- srcs = ["test_case.py"],
- deps = ["//tensorflow/python:client_testlib"],
+ name = "proto_op_test_base",
+ testonly = 1,
+ srcs = ["proto_op_test_base.py"],
+ deps = [
+ ":test_example_proto_py",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
+ name = "decode_proto_op_test_base",
+ testonly = 1,
+ srcs = ["decode_proto_op_test_base.py"],
+ deps = [
+ ":proto_op_test_base",
+ ":test_example_proto_py",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
)
py_library(
- name = "py_test_deps",
+ name = "encode_proto_op_test_base",
+ testonly = 1,
+ srcs = ["encode_proto_op_test_base.py"],
deps = [
- ":test_case",
+ ":proto_op_test_base",
":test_example_proto_py",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
+py_library(name = "py_test_deps")
+
tf_proto_library(
name = "test_example_proto",
srcs = ["test_example.proto"],
@@ -84,3 +100,30 @@ tf_cc_shared_object(
":test_example_proto_cc",
],
)
+
+py_library(
+ name = "descriptor_source_test_base",
+ testonly = 1,
+ srcs = ["descriptor_source_test_base.py"],
+ deps = [
+ ":proto_op_test_base",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@protobuf_archive//:protobuf_python",
+ ],
+)
+
+tf_py_test(
+ name = "descriptor_source_test",
+ size = "small",
+ srcs = ["descriptor_source_test.py"],
+ additional_deps = [
+ ":descriptor_source_test_base",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
+ "//tensorflow/python:client_testlib",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
deleted file mode 100644
index f425601691..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
+++ /dev/null
@@ -1,89 +0,0 @@
-"""BUILD rules for generating file-driven proto test cases.
-
-The decode_proto_test_suite() and encode_proto_test_suite() rules take a list
-of text protos and generates a tf_py_test() for each one.
-"""
-
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "register_extension_info")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
-
-def _test_name(test, path):
- return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
-
-def decode_proto_test_suite(name, examples):
- """Build the decode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("decode_proto", test_filename),
- srcs = ["decode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "decode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("decode_proto", test_filename)
- for test_filename in examples],
- )
-
-def encode_proto_test_suite(name, examples):
- """Build the encode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("encode_proto", test_filename),
- srcs = ["encode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "encode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("encode_proto", test_filename)
- for test_filename in examples],
- )
-
-register_extension_info(
- extension_name = "decode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:decode_example_.*",
- })
-
-register_extension_info(
- extension_name = "encode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:encode_example_.*",
- })
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
deleted file mode 100644
index 5298342ee7..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# =============================================================================
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-
-# Python3 preparedness imports.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.proto.python.kernel_tests import test_case
-from tensorflow.contrib.proto.python.ops import decode_proto_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class DecodeProtoFailTest(test_case.ProtoOpTestCase):
- """Test failure cases for DecodeToProto."""
-
- def _TestCorruptProtobuf(self, sanitize):
- """Test failure cases for DecodeToProto."""
-
- # The goal here is to check the error reporting.
- # Testing against a variety of corrupt protobufs is
- # done by fuzzing.
- corrupt_proto = 'This is not a binary protobuf'
-
- # Numpy silently truncates the strings if you don't specify dtype=object.
- batch = np.array(corrupt_proto, dtype=object)
- msg_type = 'tensorflow.contrib.proto.TestCase'
- field_names = ['sizes']
- field_types = [dtypes.int32]
-
- with self.test_session() as sess:
- ctensor, vtensor = decode_proto_op.decode_proto(
- batch,
- message_type=msg_type,
- field_names=field_names,
- output_types=field_types,
- sanitize=sanitize)
- with self.assertRaisesRegexp(errors.DataLossError,
- 'Unable to parse binary protobuf'
- '|Failed to consume entire buffer'):
- _ = sess.run([ctensor] + vtensor)
-
- def testCorrupt(self):
- self._TestCorruptProtobuf(sanitize=False)
-
- def testSanitizerCorrupt(self):
- self._TestCorruptProtobuf(sanitize=True)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
index d1c13c82bc..934035ec4c 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
@@ -13,287 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""Table-driven test for decode_proto op.
+"""Tests for decode_proto op."""
-This test is run once with each of the *.TestCase.pbtxt files
-in the test directory.
-"""
# Python3 preparedness imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from google.protobuf import text_format
-
-from tensorflow.contrib.proto.python.kernel_tests import test_case
-from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.contrib.proto.python.kernel_tests import decode_proto_op_test_base as test_base
from tensorflow.contrib.proto.python.ops import decode_proto_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import flags
from tensorflow.python.platform import test
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('message_text_file', None,
- 'A file containing a text serialized TestCase protobuf.')
-
-
-class DecodeProtoOpTest(test_case.ProtoOpTestCase):
-
- def _compareValues(self, fd, vs, evs):
- """Compare lists/arrays of field values."""
-
- if len(vs) != len(evs):
- self.fail('Field %s decoded %d outputs, expected %d' %
- (fd.name, len(vs), len(evs)))
- for i, ev in enumerate(evs):
- # Special case fuzzy match for float32. TensorFlow seems to mess with
- # MAX_FLT slightly and the test doesn't work otherwise.
- # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
- if fd.cpp_type == fd.CPPTYPE_FLOAT:
- # Numpy isclose() is better than assertIsClose() which uses an absolute
- # value comparison.
- self.assertTrue(
- np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i]))
- elif fd.cpp_type == fd.CPPTYPE_STRING:
- # In Python3 string tensor values will be represented as bytes, so we
- # reencode the proto values to match that.
- self.assertEqual(vs[i], ev.encode('ascii'))
- else:
- # Doubles and other types pass through unscathed.
- self.assertEqual(vs[i], ev)
-
- def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields,
- field_dict):
- """Compare protos of type RepeatedPrimitiveValue.
-
- Args:
- batch_shape: the shape of the input tensor of serialized messages.
- sizes: int matrix of repeat counts returned by decode_proto
- fields: list of test_example_pb2.FieldSpec (types and expected values)
- field_dict: map from field names to decoded numpy tensors of values
- """
-
- # Check that expected values match.
- for field in fields:
- values = field_dict[field.name]
- self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)
-
- fd = field.expected.DESCRIPTOR.fields_by_name[field.name]
-
- # Values has the same shape as the input plus an extra
- # dimension for repeats.
- self.assertEqual(list(values.shape)[:-1], batch_shape)
-
- # Nested messages are represented as TF strings, requiring
- # some special handling.
- if field.name == 'message_value':
- vs = []
- for buf in values.flat:
- msg = test_example_pb2.PrimitiveValue()
- msg.ParseFromString(buf)
- vs.append(msg)
- evs = getattr(field.expected, field.name)
- if len(vs) != len(evs):
- self.fail('Field %s decoded %d outputs, expected %d' %
- (fd.name, len(vs), len(evs)))
- for v, ev in zip(vs, evs):
- self.assertEqual(v, ev)
- continue
-
- # This can be a little confusing. For testing we are using
- # RepeatedPrimitiveValue in two ways: it's the proto that we
- # decode for testing, and it's used in the expected value as a
- # union type. The two cases are slightly different: this is the
- # second case.
- # We may be fetching the uint64_value from the test proto, but
- # in the expected proto we store it in the int64_value field
- # because TensorFlow doesn't support unsigned int64.
- tf_type_to_primitive_value_field = {
- dtypes.float32:
- 'float_value',
- dtypes.float64:
- 'double_value',
- dtypes.int32:
- 'int32_value',
- dtypes.uint8:
- 'uint8_value',
- dtypes.int8:
- 'int8_value',
- dtypes.string:
- 'string_value',
- dtypes.int64:
- 'int64_value',
- dtypes.bool:
- 'bool_value',
- # Unhandled TensorFlow types:
- # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
- # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
- }
- tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
- if tf_field_name is None:
- self.fail('Unhandled tensorflow type %d' % field.dtype)
-
- self._compareValues(fd, values.flat,
- getattr(field.expected, tf_field_name))
-
- def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
- message_type, message_format, sanitize,
- force_disordered=False):
- """Run decode tests on a batch of messages.
-
- Args:
- fields: list of test_example_pb2.FieldSpec (types and expected values)
- case_sizes: expected sizes array
- batch_shape: the shape of the input tensor of serialized messages
- batch: list of serialized messages
- message_type: descriptor name for messages
- message_format: format of messages, 'text' or 'binary'
- sanitize: whether to sanitize binary protobuf inputs
- force_disordered: whether to force fields encoded out of order.
- """
-
- if force_disordered:
- # Exercise code path that handles out-of-order fields by prepending extra
- # fields with tag numbers higher than any real field. Note that this won't
- # work with sanitization because that forces reserialization using a
- # trusted decoder and encoder.
- assert not sanitize
- extra_fields = test_example_pb2.ExtraFields()
- extra_fields.string_value = 'IGNORE ME'
- extra_fields.bool_value = False
- extra_msg = extra_fields.SerializeToString()
- batch = [extra_msg + msg for msg in batch]
-
- # Numpy silently truncates the strings if you don't specify dtype=object.
- batch = np.array(batch, dtype=object)
- batch = np.reshape(batch, batch_shape)
-
- field_names = [f.name for f in fields]
- output_types = [f.dtype for f in fields]
-
- with self.test_session() as sess:
- sizes, vtensor = decode_proto_op.decode_proto(
- batch,
- message_type=message_type,
- field_names=field_names,
- output_types=output_types,
- message_format=message_format,
- sanitize=sanitize)
-
- vlist = sess.run([sizes] + vtensor)
- sizes = vlist[0]
- # Values is a list of tensors, one for each field.
- value_tensors = vlist[1:]
-
- # Check that the repeat sizes are correct.
- self.assertTrue(
- np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))
-
- # Check that the decoded sizes match the expected sizes.
- self.assertEqual(len(sizes.flat), len(case_sizes))
- self.assertTrue(
- np.all(sizes.flat == np.array(
- case_sizes, dtype=np.int32)))
-
- field_dict = dict(zip(field_names, value_tensors))
-
- self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
- field_dict)
-
- def testBinary(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- batch = [primitive.SerializeToString() for primitive in case.primitive]
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'binary',
- sanitize=False)
-
- def testBinaryDisordered(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- batch = [primitive.SerializeToString() for primitive in case.primitive]
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'binary',
- sanitize=False,
- force_disordered=True)
-
- def testPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- # Now try with the packed serialization.
- # We test the packed representations by loading the same test cases
- # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
- # To do this we rely on the text format being the same for packed and
- # unpacked fields, and reparse the test message using the packed version
- # of the proto.
- packed_batch = [
- # Note: float_format='.17g' is necessary to ensure preservation of
- # doubles and floats in text format.
- text_format.Parse(
- text_format.MessageToString(
- primitive, float_format='.17g'),
- test_example_pb2.PackedPrimitiveValue()).SerializeToString()
- for primitive in case.primitive
- ]
-
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- packed_batch,
- 'tensorflow.contrib.proto.PackedPrimitiveValue',
- 'binary',
- sanitize=False)
-
- def testText(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- # Note: float_format='.17g' is necessary to ensure preservation of
- # doubles and floats in text format.
- text_batch = [
- text_format.MessageToString(
- primitive, float_format='.17g') for primitive in case.primitive
- ]
-
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- text_batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'text',
- sanitize=False)
- def testSanitizerGood(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+class DecodeProtoOpTest(test_base.DecodeProtoOpTestBase):
- batch = [primitive.SerializeToString() for primitive in case.primitive]
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'binary',
- sanitize=True)
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(DecodeProtoOpTest, self).__init__(decode_proto_op, methodName)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
new file mode 100644
index 0000000000..17b69c7b35
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
@@ -0,0 +1,303 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for decode_proto op."""
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+
+
+class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
+ """Base class for testing proto decoding ops."""
+
+ def __init__(self, decode_module, methodName='runTest'): # pylint: disable=invalid-name
+ """DecodeProtoOpTestBase initializer.
+
+ Args:
+ decode_module: a module containing the `decode_proto_op` method
+ methodName: the name of the test method (same as for test.TestCase)
+ """
+
+ super(DecodeProtoOpTestBase, self).__init__(methodName)
+ self._decode_module = decode_module
+
+ def _compareValues(self, fd, vs, evs):
+ """Compare lists/arrays of field values."""
+
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for i, ev in enumerate(evs):
+ # Special case fuzzy match for float32. TensorFlow seems to mess with
+ # MAX_FLT slightly and the test doesn't work otherwise.
+ # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
+ if fd.cpp_type == fd.CPPTYPE_FLOAT:
+ # Numpy isclose() is better than assertIsClose() which uses an absolute
+ # value comparison.
+ self.assertTrue(
+ np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i]))
+ elif fd.cpp_type == fd.CPPTYPE_STRING:
+ # In Python3 string tensor values will be represented as bytes, so we
+ # reencode the proto values to match that.
+ self.assertEqual(vs[i], ev.encode('ascii'))
+ else:
+ # Doubles and other types pass through unscathed.
+ self.assertEqual(vs[i], ev)
+
+ def _compareProtos(self, batch_shape, sizes, fields, field_dict):
+ """Compare protos of type TestValue.
+
+ Args:
+ batch_shape: the shape of the input tensor of serialized messages.
+ sizes: int matrix of repeat counts returned by decode_proto
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ field_dict: map from field names to decoded numpy tensors of values
+ """
+
+ # Check that expected values match.
+ for field in fields:
+ values = field_dict[field.name]
+ self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)
+
+ fd = field.value.DESCRIPTOR.fields_by_name[field.name]
+
+ # Values has the same shape as the input plus an extra
+ # dimension for repeats.
+ self.assertEqual(list(values.shape)[:-1], batch_shape)
+
+ # Nested messages are represented as TF strings, requiring
+ # some special handling.
+ if field.name == 'message_value':
+ vs = []
+ for buf in values.flat:
+ msg = test_example_pb2.PrimitiveValue()
+ msg.ParseFromString(buf)
+ vs.append(msg)
+ evs = getattr(field.value, field.name)
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for v, ev in zip(vs, evs):
+ self.assertEqual(v, ev)
+ continue
+
+ tf_type_to_primitive_value_field = {
+ dtypes.bool:
+ 'bool_value',
+ dtypes.float32:
+ 'float_value',
+ dtypes.float64:
+ 'double_value',
+ dtypes.int8:
+ 'int8_value',
+ dtypes.int32:
+ 'int32_value',
+ dtypes.int64:
+ 'int64_value',
+ dtypes.string:
+ 'string_value',
+ dtypes.uint8:
+ 'uint8_value',
+ dtypes.uint32:
+ 'uint32_value',
+ dtypes.uint64:
+ 'uint64_value',
+ }
+ tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
+ if tf_field_name is None:
+ self.fail('Unhandled tensorflow type %d' % field.dtype)
+
+ self._compareValues(fd, values.flat,
+ getattr(field.value, tf_field_name))
+
+ def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
+ message_type, message_format, sanitize,
+ force_disordered=False):
+ """Run decode tests on a batch of messages.
+
+ Args:
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ case_sizes: expected sizes array
+ batch_shape: the shape of the input tensor of serialized messages
+ batch: list of serialized messages
+ message_type: descriptor name for messages
+ message_format: format of messages, 'text' or 'binary'
+ sanitize: whether to sanitize binary protobuf inputs
+ force_disordered: whether to force fields encoded out of order.
+ """
+
+ if force_disordered:
+ # Exercise code path that handles out-of-order fields by prepending extra
+ # fields with tag numbers higher than any real field. Note that this won't
+ # work with sanitization because that forces reserialization using a
+ # trusted decoder and encoder.
+ assert not sanitize
+ extra_fields = test_example_pb2.ExtraFields()
+ extra_fields.string_value = 'IGNORE ME'
+ extra_fields.bool_value = False
+ extra_msg = extra_fields.SerializeToString()
+ batch = [extra_msg + msg for msg in batch]
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(batch, dtype=object)
+ batch = np.reshape(batch, batch_shape)
+
+ field_names = [f.name for f in fields]
+ output_types = [f.dtype for f in fields]
+
+ with self.cached_session() as sess:
+ sizes, vtensor = self._decode_module.decode_proto(
+ batch,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=output_types,
+ message_format=message_format,
+ sanitize=sanitize)
+
+ vlist = sess.run([sizes] + vtensor)
+ sizes = vlist[0]
+ # Values is a list of tensors, one for each field.
+ value_tensors = vlist[1:]
+
+ # Check that the repeat sizes are correct.
+ self.assertTrue(
+ np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))
+
+ # Check that the decoded sizes match the expected sizes.
+ self.assertEqual(len(sizes.flat), len(case_sizes))
+ self.assertTrue(
+ np.all(sizes.flat == np.array(
+ case_sizes, dtype=np.int32)))
+
+ field_dict = dict(zip(field_names, value_tensors))
+
+ self._compareProtos(batch_shape, sizes, fields, field_dict)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinary(self, case):
+ batch = [value.SerializeToString() for value in case.values]
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'binary',
+ sanitize=False)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinaryDisordered(self, case):
+ batch = [value.SerializeToString() for value in case.values]
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'binary',
+ sanitize=False,
+ force_disordered=True)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testPacked(self, case):
+ # Now try with the packed serialization.
+ #
+ # We test the packed representations by loading the same test case using
+ # PackedTestValue instead of TestValue. To do this we rely on the text
+ # format being the same for packed and unpacked fields, and reparse the
+ # test message using the packed version of the proto.
+ packed_batch = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ value, float_format='.17g'),
+ test_example_pb2.PackedTestValue()).SerializeToString()
+ for value in case.values
+ ]
+
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ packed_batch,
+ 'tensorflow.contrib.proto.PackedTestValue',
+ 'binary',
+ sanitize=False)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testText(self, case):
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_batch = [
+ text_format.MessageToString(
+ value, float_format='.17g') for value in case.values
+ ]
+
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ text_batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'text',
+ sanitize=False)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testSanitizerGood(self, case):
+ batch = [value.SerializeToString() for value in case.values]
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'binary',
+ sanitize=True)
+
+ @parameterized.parameters((False), (True))
+ def testCorruptProtobuf(self, sanitize):
+ corrupt_proto = 'This is not a binary protobuf'
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(corrupt_proto, dtype=object)
+ msg_type = 'tensorflow.contrib.proto.TestCase'
+ field_names = ['sizes']
+ field_types = [dtypes.int32]
+
+ with self.cached_session() as sess:
+ ctensor, vtensor = self._decode_module.decode_proto(
+ batch,
+ message_type=msg_type,
+ field_names=field_names,
+ output_types=field_types,
+ sanitize=sanitize)
+ with self.assertRaisesRegexp(errors.DataLossError,
+ 'Unable to parse binary protobuf'
+ '|Failed to consume entire buffer'):
+ _ = sess.run([ctensor] + vtensor)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
deleted file mode 100644
index 4e31681907..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
+++ /dev/null
@@ -1,94 +0,0 @@
-primitive {
- # No fields specified, so we get all defaults
-}
-shape: 1
-sizes: 0
-field {
- name: "double_default"
- dtype: DT_DOUBLE
- expected { double_value: 1.0 }
-}
-sizes: 0
-field {
- name: "float_default"
- dtype: DT_DOUBLE # Try casting the float field to double.
- expected { double_value: 2.0 }
-}
-sizes: 0
-field {
- name: "int64_default"
- dtype: DT_INT64
- expected { int64_value: 3 }
-}
-sizes: 0
-field {
- name: "uint64_default"
- dtype: DT_INT64
- expected { int64_value: 4 }
-}
-sizes: 0
-field {
- name: "int32_default"
- dtype: DT_INT32
- expected { int32_value: 5 }
-}
-sizes: 0
-field {
- name: "fixed64_default"
- dtype: DT_INT64
- expected { int64_value: 6 }
-}
-sizes: 0
-field {
- name: "fixed32_default"
- dtype: DT_INT32
- expected { int32_value: 7 }
-}
-sizes: 0
-field {
- name: "bool_default"
- dtype: DT_BOOL
- expected { bool_value: true }
-}
-sizes: 0
-field {
- name: "string_default"
- dtype: DT_STRING
- expected { string_value: "a" }
-}
-sizes: 0
-field {
- name: "bytes_default"
- dtype: DT_STRING
- expected { string_value: "a longer default string" }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT32
- expected { int32_value: -1 }
-}
-sizes: 0
-field {
- name: "sfixed32_default"
- dtype: DT_INT32
- expected { int32_value: 10 }
-}
-sizes: 0
-field {
- name: "sfixed64_default"
- dtype: DT_INT64
- expected { int64_value: 11 }
-}
-sizes: 0
-field {
- name: "sint32_default"
- dtype: DT_INT32
- expected { int32_value: 12 }
-}
-sizes: 0
-field {
- name: "sint64_default"
- dtype: DT_INT64
- expected { int64_value: 13 }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py
index b95202c5df..32ca318f73 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py
@@ -13,23 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""Test case base for testing proto operations."""
-
+"""Tests for proto ops reading descriptors from other sources."""
# Python3 preparedness imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import ctypes as ct
-import os
-
+from tensorflow.contrib.proto.python.kernel_tests import descriptor_source_test_base as test_base
+from tensorflow.contrib.proto.python.ops import decode_proto_op
+from tensorflow.contrib.proto.python.ops import encode_proto_op
from tensorflow.python.platform import test
-class ProtoOpTestCase(test.TestCase):
+class DescriptorSourceTest(test_base.DescriptorSourceTestBase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
- super(ProtoOpTestCase, self).__init__(methodName)
- lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
- if os.path.isfile(lib):
- ct.cdll.LoadLibrary(lib)
+ super(DescriptorSourceTest, self).__init__(decode_proto_op, encode_proto_op,
+ methodName)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py
new file mode 100644
index 0000000000..7e9b355c69
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py
@@ -0,0 +1,176 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for proto ops reading descriptors from other sources."""
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from google.protobuf.descriptor_pb2 import FieldDescriptorProto
+from google.protobuf.descriptor_pb2 import FileDescriptorSet
+from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+
+class DescriptorSourceTestBase(test.TestCase):
+ """Base class for testing descriptor sources."""
+
+ def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name
+ """DescriptorSourceTestBase initializer.
+
+ Args:
+ decode_module: a module containing the `decode_proto_op` method
+ encode_module: a module containing the `encode_proto_op` method
+ methodName: the name of the test method (same as for test.TestCase)
+ """
+
+ super(DescriptorSourceTestBase, self).__init__(methodName)
+ self._decode_module = decode_module
+ self._encode_module = encode_module
+
+ # NOTE: We generate the descriptor programmatically instead of via a compiler
+ # because of differences between different versions of the compiler.
+ #
+ # The generated descriptor should capture the subset of `test_example.proto`
+ # used in `test_base.simple_test_case()`.
+ def _createDescriptorFile(self):
+ set_proto = FileDescriptorSet()
+
+ file_proto = set_proto.file.add(
+ name='types.proto',
+ package='tensorflow',
+ syntax='proto3')
+ enum_proto = file_proto.enum_type.add(name='DataType')
+ enum_proto.value.add(name='DT_DOUBLE', number=0)
+ enum_proto.value.add(name='DT_BOOL', number=1)
+
+ file_proto = set_proto.file.add(
+ name='test_example.proto',
+ package='tensorflow.contrib.proto',
+ dependency=['types.proto'])
+ message_proto = file_proto.message_type.add(name='TestCase')
+ message_proto.field.add(
+ name='values',
+ number=1,
+ type=FieldDescriptorProto.TYPE_MESSAGE,
+ type_name='.tensorflow.contrib.proto.TestValue',
+ label=FieldDescriptorProto.LABEL_REPEATED)
+ message_proto.field.add(
+ name='shapes',
+ number=2,
+ type=FieldDescriptorProto.TYPE_INT32,
+ label=FieldDescriptorProto.LABEL_REPEATED)
+ message_proto.field.add(
+ name='sizes',
+ number=3,
+ type=FieldDescriptorProto.TYPE_INT32,
+ label=FieldDescriptorProto.LABEL_REPEATED)
+ message_proto.field.add(
+ name='fields',
+ number=4,
+ type=FieldDescriptorProto.TYPE_MESSAGE,
+ type_name='.tensorflow.contrib.proto.FieldSpec',
+ label=FieldDescriptorProto.LABEL_REPEATED)
+
+ message_proto = file_proto.message_type.add(
+ name='TestValue')
+ message_proto.field.add(
+ name='double_value',
+ number=1,
+ type=FieldDescriptorProto.TYPE_DOUBLE,
+ label=FieldDescriptorProto.LABEL_REPEATED)
+ message_proto.field.add(
+ name='bool_value',
+ number=2,
+ type=FieldDescriptorProto.TYPE_BOOL,
+ label=FieldDescriptorProto.LABEL_REPEATED)
+
+ message_proto = file_proto.message_type.add(
+ name='FieldSpec')
+ message_proto.field.add(
+ name='name',
+ number=1,
+ type=FieldDescriptorProto.TYPE_STRING,
+ label=FieldDescriptorProto.LABEL_OPTIONAL)
+ message_proto.field.add(
+ name='dtype',
+ number=2,
+ type=FieldDescriptorProto.TYPE_ENUM,
+ type_name='.tensorflow.DataType',
+ label=FieldDescriptorProto.LABEL_OPTIONAL)
+ message_proto.field.add(
+ name='value',
+ number=3,
+ type=FieldDescriptorProto.TYPE_MESSAGE,
+ type_name='.tensorflow.contrib.proto.TestValue',
+ label=FieldDescriptorProto.LABEL_OPTIONAL)
+
+ fn = os.path.join(self.get_temp_dir(), 'descriptor.pb')
+ with open(fn, 'wb') as f:
+ f.write(set_proto.SerializeToString())
+ return fn
+
+ def _testRoundtrip(self, descriptor_source):
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ in_bufs = np.array(
+ [test_base.ProtoOpTestBase.simple_test_case().SerializeToString()],
+ dtype=object)
+ message_type = 'tensorflow.contrib.proto.TestCase'
+ field_names = ['values', 'shapes', 'sizes', 'fields']
+ tensor_types = [dtypes.string, dtypes.int32, dtypes.int32, dtypes.string]
+
+ with self.cached_session() as sess:
+ sizes, field_tensors = self._decode_module.decode_proto(
+ in_bufs,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=tensor_types,
+ descriptor_source=descriptor_source)
+
+ out_tensors = self._encode_module.encode_proto(
+ sizes,
+ field_tensors,
+ message_type=message_type,
+ field_names=field_names,
+ descriptor_source=descriptor_source)
+
+ out_bufs, = sess.run([out_tensors])
+
+ # Check that the re-encoded tensor has the same shape.
+ self.assertEqual(in_bufs.shape, out_bufs.shape)
+
+ # Compare the input and output.
+ for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
+ # Check that the input and output serialized messages are identical.
+ # If we fail here, there is a difference in the serialized
+ # representation but the new serialization still parses. This could
+ # be harmless (a change in map ordering?) or it could be bad (e.g.
+ # loss of packing in the encoding).
+ self.assertEqual(in_buf, out_buf)
+
+ def testWithFileDescriptorSet(self):
+ # First try parsing with a local proto db, which should fail.
+ with self.assertRaisesOpError('No descriptor found for message type'):
+ self._testRoundtrip('local://')
+
+ # Now try parsing with a FileDescriptorSet which contains the test proto.
+ descriptor_file = self._createDescriptorFile()
+ self._testRoundtrip(descriptor_file)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
index 30e58e6336..fc5cd25d43 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
@@ -13,167 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""Table-driven test for encode_proto op.
+"""Tests for encode_proto op."""
-This test is run once with each of the *.TestCase.pbtxt files
-in the test directory.
-
-It tests that encode_proto is a lossless inverse of decode_proto
-(for the specified fields).
-"""
# Python3 readiness boilerplate
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from google.protobuf import text_format
-
-from tensorflow.contrib.proto.python.kernel_tests import test_case
-from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.contrib.proto.python.kernel_tests import encode_proto_op_test_base as test_base
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.contrib.proto.python.ops import encode_proto_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import flags
from tensorflow.python.platform import test
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('message_text_file', None,
- 'A file containing a text serialized TestCase protobuf.')
-
-
-class EncodeProtoOpTest(test_case.ProtoOpTestCase):
-
- def testBadInputs(self):
- # Invalid field name
- with self.test_session():
- with self.assertRaisesOpError('Unknown field: non_existent_field'):
- encode_proto_op.encode_proto(
- sizes=[[1]],
- values=[np.array([[0.0]], dtype=np.int32)],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['non_existent_field']).eval()
-
- # Incorrect types.
- with self.test_session():
- with self.assertRaisesOpError(
- 'Incompatible type for field double_value.'):
- encode_proto_op.encode_proto(
- sizes=[[1]],
- values=[np.array([[0.0]], dtype=np.int32)],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['double_value']).eval()
-
- # Incorrect shapes of sizes.
- with self.test_session():
- with self.assertRaisesOpError(
- r'sizes should be batch_size \+ \[len\(field_names\)\]'):
- sizes = array_ops.placeholder(dtypes.int32)
- values = array_ops.placeholder(dtypes.float64)
- encode_proto_op.encode_proto(
- sizes=sizes,
- values=[values],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['double_value']).eval(feed_dict={
- sizes: [[[0, 0]]],
- values: [[0.0]]
- })
-
- # Inconsistent shapes of values.
- with self.test_session():
- with self.assertRaisesOpError(
- 'Values must match up to the last dimension'):
- sizes = array_ops.placeholder(dtypes.int32)
- values1 = array_ops.placeholder(dtypes.float64)
- values2 = array_ops.placeholder(dtypes.int32)
- (encode_proto_op.encode_proto(
- sizes=[[1, 1]],
- values=[values1, values2],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['double_value', 'int32_value']).eval(feed_dict={
- values1: [[0.0]],
- values2: [[0], [0]]
- }))
-
- def _testRoundtrip(self, in_bufs, message_type, fields):
-
- field_names = [f.name for f in fields]
- out_types = [f.dtype for f in fields]
-
- with self.test_session() as sess:
- sizes, field_tensors = decode_proto_op.decode_proto(
- in_bufs,
- message_type=message_type,
- field_names=field_names,
- output_types=out_types)
-
- out_tensors = encode_proto_op.encode_proto(
- sizes,
- field_tensors,
- message_type=message_type,
- field_names=field_names)
-
- out_bufs, = sess.run([out_tensors])
-
- # Check that the re-encoded tensor has the same shape.
- self.assertEqual(in_bufs.shape, out_bufs.shape)
-
- # Compare the input and output.
- for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
- in_obj = test_example_pb2.RepeatedPrimitiveValue()
- in_obj.ParseFromString(in_buf)
-
- out_obj = test_example_pb2.RepeatedPrimitiveValue()
- out_obj.ParseFromString(out_buf)
-
- # Check that the deserialized objects are identical.
- self.assertEqual(in_obj, out_obj)
-
- # Check that the input and output serialized messages are identical.
- # If we fail here, there is a difference in the serialized
- # representation but the new serialization still parses. This could
- # be harmless (a change in map ordering?) or it could be bad (e.g.
- # loss of packing in the encoding).
- self.assertEqual(in_buf, out_buf)
-
- def testRoundtrip(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- in_bufs = [primitive.SerializeToString() for primitive in case.primitive]
-
- # np.array silently truncates strings if you don't specify dtype=object.
- in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
- return self._testRoundtrip(
- in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field)
-
- def testRoundtripPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
- # Now try with the packed serialization.
- # We test the packed representations by loading the same test cases
- # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
- # To do this we rely on the text format being the same for packed and
- # unpacked fields, and reparse the test message using the packed version
- # of the proto.
- in_bufs = [
- # Note: float_format='.17g' is necessary to ensure preservation of
- # doubles and floats in text format.
- text_format.Parse(
- text_format.MessageToString(
- primitive, float_format='.17g'),
- test_example_pb2.PackedPrimitiveValue()).SerializeToString()
- for primitive in case.primitive
- ]
+class EncodeProtoOpTest(test_base.EncodeProtoOpTestBase):
- # np.array silently truncates strings if you don't specify dtype=object.
- in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
- return self._testRoundtrip(
- in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field)
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(EncodeProtoOpTest, self).__init__(decode_proto_op, encode_proto_op,
+ methodName)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
new file mode 100644
index 0000000000..01b3ccc7fd
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
@@ -0,0 +1,177 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Table-driven test for encode_proto op.
+
+This test is run once with each of the *.TestCase.pbtxt files
+in the test directory.
+
+It tests that encode_proto is a lossless inverse of decode_proto
+(for the specified fields).
+"""
+# Python3 readiness boilerplate
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+
+
+class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
+ """Base class for testing proto encoding ops."""
+
+ def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name
+ """EncodeProtoOpTestBase initializer.
+
+ Args:
+ decode_module: a module containing the `decode_proto_op` method
+ encode_module: a module containing the `encode_proto_op` method
+ methodName: the name of the test method (same as for test.TestCase)
+ """
+
+ super(EncodeProtoOpTestBase, self).__init__(methodName)
+ self._decode_module = decode_module
+ self._encode_module = encode_module
+
+ def testBadInputs(self):
+ # Invalid field name
+ with self.cached_session():
+ with self.assertRaisesOpError('Unknown field: non_existent_field'):
+ self._encode_module.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['non_existent_field']).eval()
+
+ # Incorrect types.
+ with self.cached_session():
+ with self.assertRaisesOpError(
+ 'Incompatible type for field double_value.'):
+ self._encode_module.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['double_value']).eval()
+
+ # Incorrect shapes of sizes.
+ with self.cached_session():
+ with self.assertRaisesOpError(
+ r'sizes should be batch_size \+ \[len\(field_names\)\]'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values = array_ops.placeholder(dtypes.float64)
+ self._encode_module.encode_proto(
+ sizes=sizes,
+ values=[values],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['double_value']).eval(feed_dict={
+ sizes: [[[0, 0]]],
+ values: [[0.0]]
+ })
+
+ # Inconsistent shapes of values.
+ with self.cached_session():
+ with self.assertRaisesOpError(
+ 'Values must match up to the last dimension'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values1 = array_ops.placeholder(dtypes.float64)
+ values2 = array_ops.placeholder(dtypes.int32)
+ (self._encode_module.encode_proto(
+ sizes=[[1, 1]],
+ values=[values1, values2],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['double_value', 'int32_value']).eval(feed_dict={
+ values1: [[0.0]],
+ values2: [[0], [0]]
+ }))
+
+ def _testRoundtrip(self, in_bufs, message_type, fields):
+
+ field_names = [f.name for f in fields]
+ out_types = [f.dtype for f in fields]
+
+ with self.cached_session() as sess:
+ sizes, field_tensors = self._decode_module.decode_proto(
+ in_bufs,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=out_types)
+
+ out_tensors = self._encode_module.encode_proto(
+ sizes,
+ field_tensors,
+ message_type=message_type,
+ field_names=field_names)
+
+ out_bufs, = sess.run([out_tensors])
+
+ # Check that the re-encoded tensor has the same shape.
+ self.assertEqual(in_bufs.shape, out_bufs.shape)
+
+ # Compare the input and output.
+ for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
+ in_obj = test_example_pb2.TestValue()
+ in_obj.ParseFromString(in_buf)
+
+ out_obj = test_example_pb2.TestValue()
+ out_obj.ParseFromString(out_buf)
+
+ # Check that the deserialized objects are identical.
+ self.assertEqual(in_obj, out_obj)
+
+ # Check that the input and output serialized messages are identical.
+ # If we fail here, there is a difference in the serialized
+ # representation but the new serialization still parses. This could
+ # be harmless (a change in map ordering?) or it could be bad (e.g.
+ # loss of packing in the encoding).
+ self.assertEqual(in_buf, out_buf)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtrip(self, case):
+ in_bufs = [value.SerializeToString() for value in case.values]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtripPacked(self, case):
+ # Now try with the packed serialization.
+ # We test the packed representations by loading the same test cases using
+ # PackedTestValue instead of TestValue. To do this we rely on the text
+ # format being the same for packed and unpacked fields, and reparse the test
+ # message using the packed version of the proto.
+ in_bufs = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ value, float_format='.17g'),
+ test_example_pb2.PackedTestValue()).SerializeToString()
+ for value in case.values
+ ]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.PackedTestValue', case.fields)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
deleted file mode 100644
index b170f89c0f..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
+++ /dev/null
@@ -1,161 +0,0 @@
-primitive {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- uint64_value: 0
- uint64_value: 18446744073709551615
- int32_value: -2147483648
- int32_value: 2147483647
- fixed64_value: 0
- fixed64_value: 18446744073709551615
- fixed32_value: 0
- fixed32_value: 4294967295
- bool_value: false
- bool_value: true
- string_value: ""
- string_value: "I refer to the infinite."
- uint32_value: 0
- uint32_value: 4294967295
- sfixed32_value: -2147483648
- sfixed32_value: 2147483647
- sfixed64_value: -9223372036854775808
- sfixed64_value: 9223372036854775807
- sint32_value: -2147483648
- sint32_value: 2147483647
- sint64_value: -9223372036854775808
- sint64_value: 9223372036854775807
-}
-shape: 1
-sizes: 3
-sizes: 3
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- }
-}
-field {
- name: "float_value"
- dtype: DT_FLOAT
- expected {
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- }
-}
-field {
- name: "int64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "uint64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1
- }
-}
-field {
- name: "int32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "fixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1 # unsigned is 18446744073709551615
- }
-}
-field {
- name: "fixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: false
- bool_value: true
- }
-}
-field {
- name: "string_value"
- dtype: DT_STRING
- expected {
- string_value: ""
- string_value: "I refer to the infinite."
- }
-}
-field {
- name: "uint32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "sfixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sfixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "sint32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sint64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
deleted file mode 100644
index c664e52851..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
+++ /dev/null
@@ -1,16 +0,0 @@
-primitive {
- message_value {
- double_value: 23.5
- }
-}
-shape: 1
-sizes: 1
-field {
- name: "message_value"
- dtype: DT_STRING
- expected {
- message_value {
- double_value: 23.5
- }
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
deleted file mode 100644
index 125651d7ea..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
+++ /dev/null
@@ -1,20 +0,0 @@
-primitive {
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 0
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 0.0
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
deleted file mode 100644
index bc07efc8f3..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
+++ /dev/null
@@ -1,29 +0,0 @@
-primitive {
- fixed32_value: 4294967295
- uint32_value: 4294967295
-}
-shape: 1
-sizes: 1
-field {
- name: "fixed32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 1
-field {
- name: "uint32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295 # Comes from an explicitly-specified default
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
new file mode 100644
index 0000000000..2950c7dfdc
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
@@ -0,0 +1,419 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Test case base for testing proto operations."""
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.platform import test
+
+
+class ProtoOpTestBase(test.TestCase):
+ """Base class for testing proto decoding and encoding ops."""
+
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(ProtoOpTestBase, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), "libtestexample.so")
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
+
+ @staticmethod
+ def named_parameters():
+ return (
+ ("defaults", ProtoOpTestBase.defaults_test_case()),
+ ("minmax", ProtoOpTestBase.minmax_test_case()),
+ ("nested", ProtoOpTestBase.nested_test_case()),
+ ("optional", ProtoOpTestBase.optional_test_case()),
+ ("promote", ProtoOpTestBase.promote_test_case()),
+ ("ragged", ProtoOpTestBase.ragged_test_case()),
+ ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
+ ("simple", ProtoOpTestBase.simple_test_case()),
+ )
+
+ @staticmethod
+ def defaults_test_case():
+ test_case = test_example_pb2.TestCase()
+ test_case.values.add() # No fields specified, so we get all defaults.
+ test_case.shapes.append(1)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "double_value_with_default"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(1.0)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "float_value_with_default"
+ field.dtype = types_pb2.DT_FLOAT
+ field.value.float_value.append(2.0)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "int64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(3)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sfixed64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(11)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sint64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(13)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "uint64_value_with_default"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "fixed64_value_with_default"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(6)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "int32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(5)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sfixed32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(10)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sint32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(12)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "uint32_value_with_default"
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(9)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "fixed32_value_with_default"
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(7)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "bool_value_with_default"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "string_value_with_default"
+ field.dtype = types_pb2.DT_STRING
+ field.value.string_value.append("a")
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "bytes_value_with_default"
+ field.dtype = types_pb2.DT_STRING
+ field.value.string_value.append("a longer default string")
+ return test_case
+
+ @staticmethod
+ def minmax_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(-1.7976931348623158e+308)
+ value.double_value.append(2.2250738585072014e-308)
+ value.double_value.append(1.7976931348623158e+308)
+ value.float_value.append(-3.402823466e+38)
+ value.float_value.append(1.175494351e-38)
+ value.float_value.append(3.402823466e+38)
+ value.int64_value.append(-9223372036854775808)
+ value.int64_value.append(9223372036854775807)
+ value.sfixed64_value.append(-9223372036854775808)
+ value.sfixed64_value.append(9223372036854775807)
+ value.sint64_value.append(-9223372036854775808)
+ value.sint64_value.append(9223372036854775807)
+ value.uint64_value.append(0)
+ value.uint64_value.append(18446744073709551615)
+ value.fixed64_value.append(0)
+ value.fixed64_value.append(18446744073709551615)
+ value.int32_value.append(-2147483648)
+ value.int32_value.append(2147483647)
+ value.sfixed32_value.append(-2147483648)
+ value.sfixed32_value.append(2147483647)
+ value.sint32_value.append(-2147483648)
+ value.sint32_value.append(2147483647)
+ value.uint32_value.append(0)
+ value.uint32_value.append(4294967295)
+ value.fixed32_value.append(0)
+ value.fixed32_value.append(4294967295)
+ value.bool_value.append(False)
+ value.bool_value.append(True)
+ value.string_value.append("")
+ value.string_value.append("I refer to the infinite.")
+ test_case.shapes.append(1)
+ test_case.sizes.append(3)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(-1.7976931348623158e+308)
+ field.value.double_value.append(2.2250738585072014e-308)
+ field.value.double_value.append(1.7976931348623158e+308)
+ test_case.sizes.append(3)
+ field = test_case.fields.add()
+ field.name = "float_value"
+ field.dtype = types_pb2.DT_FLOAT
+ field.value.float_value.append(-3.402823466e+38)
+ field.value.float_value.append(1.175494351e-38)
+ field.value.float_value.append(3.402823466e+38)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "int64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(-9223372036854775808)
+ field.value.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sfixed64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(-9223372036854775808)
+ field.value.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sint64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(-9223372036854775808)
+ field.value.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "uint64_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(0)
+ field.value.uint64_value.append(18446744073709551615)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "fixed64_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(0)
+ field.value.uint64_value.append(18446744073709551615)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "int32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(-2147483648)
+ field.value.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sfixed32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(-2147483648)
+ field.value.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sint32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(-2147483648)
+ field.value.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(0)
+ field.value.uint32_value.append(4294967295)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(0)
+ field.value.uint32_value.append(4294967295)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(False)
+ field.value.bool_value.append(True)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "string_value"
+ field.dtype = types_pb2.DT_STRING
+ field.value.string_value.append("")
+ field.value.string_value.append("I refer to the infinite.")
+ return test_case
+
+ @staticmethod
+ def nested_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ message_value = value.message_value.add()
+ message_value.double_value = 23.5
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "message_value"
+ field.dtype = types_pb2.DT_STRING
+ message_value = field.value.message_value.add()
+ message_value.double_value = 23.5
+ return test_case
+
+ @staticmethod
+ def optional_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.bool_value.append(True)
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(0.0)
+ return test_case
+
+ @staticmethod
+ def promote_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.sint32_value.append(2147483647)
+ value.sfixed32_value.append(2147483647)
+ value.int32_value.append(2147483647)
+ value.fixed32_value.append(4294967295)
+ value.uint32_value.append(4294967295)
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "sint32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(2147483647)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "sfixed32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(2147483647)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "int32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(2147483647)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4294967295)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4294967295)
+ return test_case
+
+ @staticmethod
+ def ragged_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(23.5)
+ value.double_value.append(123.0)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(3.1)
+ value.bool_value.append(False)
+ test_case.shapes.append(2)
+ test_case.sizes.append(2)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(23.5)
+ field.value.double_value.append(123.0)
+ field.value.double_value.append(3.1)
+ field.value.double_value.append(0.0)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(False)
+ return test_case
+
+ @staticmethod
+ def shaped_batch_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(23.5)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(44.0)
+ value.bool_value.append(False)
+ value = test_case.values.add()
+ value.double_value.append(3.14159)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(1.414)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(-32.2)
+ value.bool_value.append(False)
+ value = test_case.values.add()
+ value.double_value.append(0.0001)
+ value.bool_value.append(True)
+ test_case.shapes.append(3)
+ test_case.shapes.append(2)
+ for _ in range(12):
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(23.5)
+ field.value.double_value.append(44.0)
+ field.value.double_value.append(3.14159)
+ field.value.double_value.append(1.414)
+ field.value.double_value.append(-32.2)
+ field.value.double_value.append(0.0001)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(False)
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(False)
+ field.value.bool_value.append(True)
+ return test_case
+
+ @staticmethod
+ def simple_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(23.5)
+ value.bool_value.append(True)
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(23.5)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ return test_case
diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
deleted file mode 100644
index 61c7ac53f7..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
+++ /dev/null
@@ -1,32 +0,0 @@
-primitive {
- double_value: 23.5
- double_value: 123.0
- bool_value: true
-}
-primitive {
- double_value: 3.1
- bool_value: false
-}
-shape: 2
-sizes: 2
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 123.0
- double_value: 3.1
- double_value: 0.0
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
deleted file mode 100644
index f4828076d5..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
+++ /dev/null
@@ -1,62 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-primitive {
- double_value: 44.0
- bool_value: false
-}
-primitive {
- double_value: 3.14159
- bool_value: true
-}
-primitive {
- double_value: 1.414
- bool_value: true
-}
-primitive {
- double_value: -32.2
- bool_value: false
-}
-primitive {
- double_value: 0.0001
- bool_value: true
-}
-shape: 3
-shape: 2
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 44.0
- double_value: 3.14159
- double_value: 1.414
- double_value: -32.2
- double_value: 0.0001
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- bool_value: true
- bool_value: true
- bool_value: false
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
deleted file mode 100644
index dc20ac147b..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
+++ /dev/null
@@ -1,21 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
index a2c88e372b..674d881220 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
@@ -1,6 +1,4 @@
// Test description and protos to work with it.
-//
-// Many of the protos in this file are for unit tests that haven't been written yet.
syntax = "proto2";
@@ -8,54 +6,27 @@ import "tensorflow/core/framework/types.proto";
package tensorflow.contrib.proto;
-// A TestCase holds a proto and a bunch of assertions
-// about how it should decode.
+// A TestCase holds a proto and assertions about how it should decode.
message TestCase {
- // A batch of primitives to be serialized and decoded.
- repeated RepeatedPrimitiveValue primitive = 1;
- // The shape of the batch.
- repeated int32 shape = 2;
+ // Batches of primitive values.
+ repeated TestValue values = 1;
+ // The batch shapes.
+ repeated int32 shapes = 2;
// Expected sizes for each field.
repeated int32 sizes = 3;
// Expected values for each field.
- repeated FieldSpec field = 4;
+ repeated FieldSpec fields = 4;
};
// FieldSpec describes the expected output for a single field.
message FieldSpec {
optional string name = 1;
optional tensorflow.DataType dtype = 2;
- optional RepeatedPrimitiveValue expected = 3;
+ optional TestValue value = 3;
};
+// NOTE: This definition must be kept in sync with PackedTestValue.
message TestValue {
- optional PrimitiveValue primitive_value = 1;
- optional EnumValue enum_value = 2;
- optional MessageValue message_value = 3;
- optional RepeatedMessageValue repeated_message_value = 4;
- optional RepeatedPrimitiveValue repeated_primitive_value = 6;
-}
-
-message PrimitiveValue {
- optional double double_value = 1;
- optional float float_value = 2;
- optional int64 int64_value = 3;
- optional uint64 uint64_value = 4;
- optional int32 int32_value = 5;
- optional fixed64 fixed64_value = 6;
- optional fixed32 fixed32_value = 7;
- optional bool bool_value = 8;
- optional string string_value = 9;
- optional bytes bytes_value = 12;
- optional uint32 uint32_value = 13;
- optional sfixed32 sfixed32_value = 15;
- optional sfixed64 sfixed64_value = 16;
- optional sint32 sint32_value = 17;
- optional sint64 sint64_value = 18;
-}
-
-// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
-message RepeatedPrimitiveValue {
repeated double double_value = 1;
repeated float float_value = 2;
repeated int64 int64_value = 3;
@@ -74,30 +45,31 @@ message RepeatedPrimitiveValue {
repeated PrimitiveValue message_value = 19;
// Optional fields with explicitly-specified defaults.
- optional double double_default = 20 [default = 1.0];
- optional float float_default = 21 [default = 2.0];
- optional int64 int64_default = 22 [default = 3];
- optional uint64 uint64_default = 23 [default = 4];
- optional int32 int32_default = 24 [default = 5];
- optional fixed64 fixed64_default = 25 [default = 6];
- optional fixed32 fixed32_default = 26 [default = 7];
- optional bool bool_default = 27 [default = true];
- optional string string_default = 28 [default = "a"];
- optional bytes bytes_default = 29 [default = "a longer default string"];
- optional uint32 uint32_default = 30 [default = 4294967295];
- optional sfixed32 sfixed32_default = 31 [default = 10];
- optional sfixed64 sfixed64_default = 32 [default = 11];
- optional sint32 sint32_default = 33 [default = 12];
- optional sint64 sint64_default = 34 [default = 13];
+ optional double double_value_with_default = 20 [default = 1.0];
+ optional float float_value_with_default = 21 [default = 2.0];
+ optional int64 int64_value_with_default = 22 [default = 3];
+ optional uint64 uint64_value_with_default = 23 [default = 4];
+ optional int32 int32_value_with_default = 24 [default = 5];
+ optional fixed64 fixed64_value_with_default = 25 [default = 6];
+ optional fixed32 fixed32_value_with_default = 26 [default = 7];
+ optional bool bool_value_with_default = 27 [default = true];
+ optional string string_value_with_default = 28 [default = "a"];
+ optional bytes bytes_value_with_default = 29
+ [default = "a longer default string"];
+ optional uint32 uint32_value_with_default = 30 [default = 9];
+ optional sfixed32 sfixed32_value_with_default = 31 [default = 10];
+ optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
+ optional sint32 sint32_value_with_default = 33 [default = 12];
+ optional sint64 sint64_value_with_default = 34 [default = 13];
}
-// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
-// in the text format, but the binary serializion is different.
-// We test the packed representations by loading the same test cases
-// using this definition instead of RepeatedPrimitiveValue.
-// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
-// in every way except the packed=true declaration.
-message PackedPrimitiveValue {
+// A PackedTestValue looks exactly the same as a TestValue in the text format,
+// but the binary serializion is different. We test the packed representations
+// by loading the same test cases using this definition instead of TestValue.
+//
+// NOTE: This definition must be kept in sync with TestValue in every way except
+// the packed=true declaration.
+message PackedTestValue {
repeated double double_value = 1 [packed = true];
repeated float float_value = 2 [packed = true];
repeated int64 int64_value = 3 [packed = true];
@@ -115,23 +87,53 @@ message PackedPrimitiveValue {
repeated sint64 sint64_value = 18 [packed = true];
repeated PrimitiveValue message_value = 19;
- optional double double_default = 20 [default = 1.0];
- optional float float_default = 21 [default = 2.0];
- optional int64 int64_default = 22 [default = 3];
- optional uint64 uint64_default = 23 [default = 4];
- optional int32 int32_default = 24 [default = 5];
- optional fixed64 fixed64_default = 25 [default = 6];
- optional fixed32 fixed32_default = 26 [default = 7];
- optional bool bool_default = 27 [default = true];
- optional string string_default = 28 [default = "a"];
- optional bytes bytes_default = 29 [default = "a longer default string"];
- optional uint32 uint32_default = 30 [default = 4294967295];
- optional sfixed32 sfixed32_default = 31 [default = 10];
- optional sfixed64 sfixed64_default = 32 [default = 11];
- optional sint32 sint32_default = 33 [default = 12];
- optional sint64 sint64_default = 34 [default = 13];
+ optional double double_value_with_default = 20 [default = 1.0];
+ optional float float_value_with_default = 21 [default = 2.0];
+ optional int64 int64_value_with_default = 22 [default = 3];
+ optional uint64 uint64_value_with_default = 23 [default = 4];
+ optional int32 int32_value_with_default = 24 [default = 5];
+ optional fixed64 fixed64_value_with_default = 25 [default = 6];
+ optional fixed32 fixed32_value_with_default = 26 [default = 7];
+ optional bool bool_value_with_default = 27 [default = true];
+ optional string string_value_with_default = 28 [default = "a"];
+ optional bytes bytes_value_with_default = 29
+ [default = "a longer default string"];
+ optional uint32 uint32_value_with_default = 30 [default = 9];
+ optional sfixed32 sfixed32_value_with_default = 31 [default = 10];
+ optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
+ optional sint32 sint32_value_with_default = 33 [default = 12];
+ optional sint64 sint64_value_with_default = 34 [default = 13];
}
+message PrimitiveValue {
+ optional double double_value = 1;
+ optional float float_value = 2;
+ optional int64 int64_value = 3;
+ optional uint64 uint64_value = 4;
+ optional int32 int32_value = 5;
+ optional fixed64 fixed64_value = 6;
+ optional fixed32 fixed32_value = 7;
+ optional bool bool_value = 8;
+ optional string string_value = 9;
+ optional bytes bytes_value = 12;
+ optional uint32 uint32_value = 13;
+ optional sfixed32 sfixed32_value = 15;
+ optional sfixed64 sfixed64_value = 16;
+ optional sint32 sint32_value = 17;
+ optional sint64 sint64_value = 18;
+}
+
+// Message containing fields with field numbers higher than any field above.
+// An instance of this message is prepended to each binary message in the test
+// to exercise the code path that handles fields encoded out of order of field
+// number.
+message ExtraFields {
+ optional string string_value = 1776;
+ optional bool bool_value = 1777;
+}
+
+// The messages below are for yet-to-be created tests.
+
message EnumValue {
enum Color {
RED = 0;
@@ -171,12 +173,3 @@ message RepeatedMessageValue {
repeated NestedMessageValue message_values = 11;
}
-
-// Message containing fields with field numbers higher than any field above. An
-// instance of this message is prepended to each binary message in the test to
-// exercise the code path that handles fields encoded out of order of field
-// number.
-message ExtraFields {
- optional string string_value = 1776;
- optional bool bool_value = 1777;
-}