diff options
Diffstat (limited to 'tensorflow/contrib/proto')
20 files changed, 1251 insertions, 1153 deletions
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD index 3e9b1a0b8d..b27142cf4a 100644 --- a/tensorflow/contrib/proto/BUILD +++ b/tensorflow/contrib/proto/BUILD @@ -16,17 +16,3 @@ py_library( "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", ], ) - -py_library( - name = "proto_pip", - data = [ - "//tensorflow/contrib/proto/python/kernel_tests:test_messages", - ] + if_static( - [], - otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"], - ), - deps = [ - ":proto", - "//tensorflow/contrib/proto/python/kernel_tests:py_test_deps", - ], -) diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD index a380a131f8..125c1cee29 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/BUILD +++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD @@ -4,47 +4,41 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -# Much of the work in this BUILD file actually happens in the corresponding -# build_defs.bzl, which creates an individual testcase for each example .pbtxt -# file in this directory. -# -load(":build_defs.bzl", "decode_proto_test_suite") -load(":build_defs.bzl", "encode_proto_test_suite") - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :decode_proto_op_tests. -decode_proto_test_suite( - name = "decode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :encode_proto_op_tests. -encode_proto_test_suite( - name = "encode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# Below here are tests that are not tied to an example text proto. -filegroup( - name = "test_messages", - srcs = glob(["*.pbtxt"]), -) - load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") tf_py_test( - name = "decode_proto_fail_test", + name = "decode_proto_op_test", size = "small", - srcs = ["decode_proto_fail_test.py"], + srcs = ["decode_proto_op_test.py"], additional_deps = [ + ":decode_proto_op_test_base", + ":py_test_deps", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + ], + data = if_static( + [], + otherwise = [":libtestexample.so"], + ), + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], +) + +tf_py_test( + name = "encode_proto_op_test", + size = "small", + srcs = ["encode_proto_op_test.py"], + additional_deps = [ + ":encode_proto_op_test_base", ":py_test_deps", - "//third_party/py/numpy", "//tensorflow/contrib/proto:proto", "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", ], data = if_static( [], @@ -57,19 +51,41 @@ tf_py_test( ) py_library( - name = "test_case", - srcs = ["test_case.py"], - deps = ["//tensorflow/python:client_testlib"], + name = "proto_op_test_base", + testonly = 1, + srcs = ["proto_op_test_base.py"], + deps = [ + ":test_example_proto_py", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "decode_proto_op_test_base", + testonly = 1, + srcs = ["decode_proto_op_test_base.py"], + deps = [ + ":proto_op_test_base", + ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], ) py_library( - name = "py_test_deps", + name = "encode_proto_op_test_base", + testonly = 1, + srcs = ["encode_proto_op_test_base.py"], deps = [ - ":test_case", + ":proto_op_test_base", ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) +py_library(name = "py_test_deps") + tf_proto_library( name = "test_example_proto", srcs = ["test_example.proto"], @@ -84,3 +100,30 @@ tf_cc_shared_object( ":test_example_proto_cc", ], ) + +py_library( + name = "descriptor_source_test_base", + testonly = 1, + srcs = ["descriptor_source_test_base.py"], + deps = [ + ":proto_op_test_base", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@protobuf_archive//:protobuf_python", + ], +) + +tf_py_test( + name = "descriptor_source_test", + size = "small", + srcs = ["descriptor_source_test.py"], + additional_deps = [ + ":descriptor_source_test_base", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", + "//tensorflow/python:client_testlib", + ], + tags = [ + "no_pip", + ], +) diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl deleted file mode 100644 index f425601691..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl +++ /dev/null @@ -1,89 +0,0 @@ -"""BUILD rules for generating file-driven proto test cases. - -The decode_proto_test_suite() and encode_proto_test_suite() rules take a list -of text protos and generates a tf_py_test() for each one. -""" - -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "register_extension_info") -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") - -def _test_name(test, path): - return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0]) - -def decode_proto_test_suite(name, examples): - """Build the decode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("decode_proto", test_filename), - srcs = ["decode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "decode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("decode_proto", test_filename) - for test_filename in examples], - ) - -def encode_proto_test_suite(name, examples): - """Build the encode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("encode_proto", test_filename), - srcs = ["encode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "encode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("encode_proto", test_filename) - for test_filename in examples], - ) - -register_extension_info( - extension_name = "decode_proto_test_suite", - label_regex_map = { - "deps": "deps:decode_example_.*", - }) - -register_extension_info( - extension_name = "encode_proto_test_suite", - label_regex_map = { - "deps": "deps:encode_example_.*", - }) diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py deleted file mode 100644 index 5298342ee7..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# ============================================================================= -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# Python3 preparedness imports. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class DecodeProtoFailTest(test_case.ProtoOpTestCase): - """Test failure cases for DecodeToProto.""" - - def _TestCorruptProtobuf(self, sanitize): - """Test failure cases for DecodeToProto.""" - - # The goal here is to check the error reporting. - # Testing against a variety of corrupt protobufs is - # done by fuzzing. - corrupt_proto = 'This is not a binary protobuf' - - # Numpy silently truncates the strings if you don't specify dtype=object. - batch = np.array(corrupt_proto, dtype=object) - msg_type = 'tensorflow.contrib.proto.TestCase' - field_names = ['sizes'] - field_types = [dtypes.int32] - - with self.test_session() as sess: - ctensor, vtensor = decode_proto_op.decode_proto( - batch, - message_type=msg_type, - field_names=field_names, - output_types=field_types, - sanitize=sanitize) - with self.assertRaisesRegexp(errors.DataLossError, - 'Unable to parse binary protobuf' - '|Failed to consume entire buffer'): - _ = sess.run([ctensor] + vtensor) - - def testCorrupt(self): - self._TestCorruptProtobuf(sanitize=False) - - def testSanitizerCorrupt(self): - self._TestCorruptProtobuf(sanitize=True) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py index d1c13c82bc..934035ec4c 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py @@ -13,287 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for decode_proto op. +"""Tests for decode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. -""" # Python3 preparedness imports. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import decode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - - -class DecodeProtoOpTest(test_case.ProtoOpTestCase): - - def _compareValues(self, fd, vs, evs): - """Compare lists/arrays of field values.""" - - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for i, ev in enumerate(evs): - # Special case fuzzy match for float32. TensorFlow seems to mess with - # MAX_FLT slightly and the test doesn't work otherwise. - # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. - if fd.cpp_type == fd.CPPTYPE_FLOAT: - # Numpy isclose() is better than assertIsClose() which uses an absolute - # value comparison. - self.assertTrue( - np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) - elif fd.cpp_type == fd.CPPTYPE_STRING: - # In Python3 string tensor values will be represented as bytes, so we - # reencode the proto values to match that. - self.assertEqual(vs[i], ev.encode('ascii')) - else: - # Doubles and other types pass through unscathed. - self.assertEqual(vs[i], ev) - - def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields, - field_dict): - """Compare protos of type RepeatedPrimitiveValue. - - Args: - batch_shape: the shape of the input tensor of serialized messages. - sizes: int matrix of repeat counts returned by decode_proto - fields: list of test_example_pb2.FieldSpec (types and expected values) - field_dict: map from field names to decoded numpy tensors of values - """ - - # Check that expected values match. - for field in fields: - values = field_dict[field.name] - self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) - - fd = field.expected.DESCRIPTOR.fields_by_name[field.name] - - # Values has the same shape as the input plus an extra - # dimension for repeats. - self.assertEqual(list(values.shape)[:-1], batch_shape) - - # Nested messages are represented as TF strings, requiring - # some special handling. - if field.name == 'message_value': - vs = [] - for buf in values.flat: - msg = test_example_pb2.PrimitiveValue() - msg.ParseFromString(buf) - vs.append(msg) - evs = getattr(field.expected, field.name) - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for v, ev in zip(vs, evs): - self.assertEqual(v, ev) - continue - - # This can be a little confusing. For testing we are using - # RepeatedPrimitiveValue in two ways: it's the proto that we - # decode for testing, and it's used in the expected value as a - # union type. The two cases are slightly different: this is the - # second case. - # We may be fetching the uint64_value from the test proto, but - # in the expected proto we store it in the int64_value field - # because TensorFlow doesn't support unsigned int64. - tf_type_to_primitive_value_field = { - dtypes.float32: - 'float_value', - dtypes.float64: - 'double_value', - dtypes.int32: - 'int32_value', - dtypes.uint8: - 'uint8_value', - dtypes.int8: - 'int8_value', - dtypes.string: - 'string_value', - dtypes.int64: - 'int64_value', - dtypes.bool: - 'bool_value', - # Unhandled TensorFlow types: - # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 - # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 - } - tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) - if tf_field_name is None: - self.fail('Unhandled tensorflow type %d' % field.dtype) - - self._compareValues(fd, values.flat, - getattr(field.expected, tf_field_name)) - - def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, - message_type, message_format, sanitize, - force_disordered=False): - """Run decode tests on a batch of messages. - - Args: - fields: list of test_example_pb2.FieldSpec (types and expected values) - case_sizes: expected sizes array - batch_shape: the shape of the input tensor of serialized messages - batch: list of serialized messages - message_type: descriptor name for messages - message_format: format of messages, 'text' or 'binary' - sanitize: whether to sanitize binary protobuf inputs - force_disordered: whether to force fields encoded out of order. - """ - - if force_disordered: - # Exercise code path that handles out-of-order fields by prepending extra - # fields with tag numbers higher than any real field. Note that this won't - # work with sanitization because that forces reserialization using a - # trusted decoder and encoder. - assert not sanitize - extra_fields = test_example_pb2.ExtraFields() - extra_fields.string_value = 'IGNORE ME' - extra_fields.bool_value = False - extra_msg = extra_fields.SerializeToString() - batch = [extra_msg + msg for msg in batch] - - # Numpy silently truncates the strings if you don't specify dtype=object. - batch = np.array(batch, dtype=object) - batch = np.reshape(batch, batch_shape) - - field_names = [f.name for f in fields] - output_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, vtensor = decode_proto_op.decode_proto( - batch, - message_type=message_type, - field_names=field_names, - output_types=output_types, - message_format=message_format, - sanitize=sanitize) - - vlist = sess.run([sizes] + vtensor) - sizes = vlist[0] - # Values is a list of tensors, one for each field. - value_tensors = vlist[1:] - - # Check that the repeat sizes are correct. - self.assertTrue( - np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) - - # Check that the decoded sizes match the expected sizes. - self.assertEqual(len(sizes.flat), len(case_sizes)) - self.assertTrue( - np.all(sizes.flat == np.array( - case_sizes, dtype=np.int32))) - - field_dict = dict(zip(field_names, value_tensors)) - - self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields, - field_dict) - - def testBinary(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False) - - def testBinaryDisordered(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False, - force_disordered=True) - - def testPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - # Now try with the packed serialization. - # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - packed_batch = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - packed_batch, - 'tensorflow.contrib.proto.PackedPrimitiveValue', - 'binary', - sanitize=False) - - def testText(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_batch = [ - text_format.MessageToString( - primitive, float_format='.17g') for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - text_batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'text', - sanitize=False) - def testSanitizerGood(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) +class DecodeProtoOpTest(test_base.DecodeProtoOpTestBase): - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=True) + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(DecodeProtoOpTest, self).__init__(decode_proto_op, methodName) if __name__ == '__main__': diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py new file mode 100644 index 0000000000..17b69c7b35 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -0,0 +1,303 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for decode_proto op.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors + + +class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): + """Base class for testing proto decoding ops.""" + + def __init__(self, decode_module, methodName='runTest'): # pylint: disable=invalid-name + """DecodeProtoOpTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(DecodeProtoOpTestBase, self).__init__(methodName) + self._decode_module = decode_module + + def _compareValues(self, fd, vs, evs): + """Compare lists/arrays of field values.""" + + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for i, ev in enumerate(evs): + # Special case fuzzy match for float32. TensorFlow seems to mess with + # MAX_FLT slightly and the test doesn't work otherwise. + # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. + if fd.cpp_type == fd.CPPTYPE_FLOAT: + # Numpy isclose() is better than assertIsClose() which uses an absolute + # value comparison. + self.assertTrue( + np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) + elif fd.cpp_type == fd.CPPTYPE_STRING: + # In Python3 string tensor values will be represented as bytes, so we + # reencode the proto values to match that. + self.assertEqual(vs[i], ev.encode('ascii')) + else: + # Doubles and other types pass through unscathed. + self.assertEqual(vs[i], ev) + + def _compareProtos(self, batch_shape, sizes, fields, field_dict): + """Compare protos of type TestValue. + + Args: + batch_shape: the shape of the input tensor of serialized messages. + sizes: int matrix of repeat counts returned by decode_proto + fields: list of test_example_pb2.FieldSpec (types and expected values) + field_dict: map from field names to decoded numpy tensors of values + """ + + # Check that expected values match. + for field in fields: + values = field_dict[field.name] + self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) + + fd = field.value.DESCRIPTOR.fields_by_name[field.name] + + # Values has the same shape as the input plus an extra + # dimension for repeats. + self.assertEqual(list(values.shape)[:-1], batch_shape) + + # Nested messages are represented as TF strings, requiring + # some special handling. + if field.name == 'message_value': + vs = [] + for buf in values.flat: + msg = test_example_pb2.PrimitiveValue() + msg.ParseFromString(buf) + vs.append(msg) + evs = getattr(field.value, field.name) + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for v, ev in zip(vs, evs): + self.assertEqual(v, ev) + continue + + tf_type_to_primitive_value_field = { + dtypes.bool: + 'bool_value', + dtypes.float32: + 'float_value', + dtypes.float64: + 'double_value', + dtypes.int8: + 'int8_value', + dtypes.int32: + 'int32_value', + dtypes.int64: + 'int64_value', + dtypes.string: + 'string_value', + dtypes.uint8: + 'uint8_value', + dtypes.uint32: + 'uint32_value', + dtypes.uint64: + 'uint64_value', + } + tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) + if tf_field_name is None: + self.fail('Unhandled tensorflow type %d' % field.dtype) + + self._compareValues(fd, values.flat, + getattr(field.value, tf_field_name)) + + def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, + message_type, message_format, sanitize, + force_disordered=False): + """Run decode tests on a batch of messages. + + Args: + fields: list of test_example_pb2.FieldSpec (types and expected values) + case_sizes: expected sizes array + batch_shape: the shape of the input tensor of serialized messages + batch: list of serialized messages + message_type: descriptor name for messages + message_format: format of messages, 'text' or 'binary' + sanitize: whether to sanitize binary protobuf inputs + force_disordered: whether to force fields encoded out of order. + """ + + if force_disordered: + # Exercise code path that handles out-of-order fields by prepending extra + # fields with tag numbers higher than any real field. Note that this won't + # work with sanitization because that forces reserialization using a + # trusted decoder and encoder. + assert not sanitize + extra_fields = test_example_pb2.ExtraFields() + extra_fields.string_value = 'IGNORE ME' + extra_fields.bool_value = False + extra_msg = extra_fields.SerializeToString() + batch = [extra_msg + msg for msg in batch] + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(batch, dtype=object) + batch = np.reshape(batch, batch_shape) + + field_names = [f.name for f in fields] + output_types = [f.dtype for f in fields] + + with self.cached_session() as sess: + sizes, vtensor = self._decode_module.decode_proto( + batch, + message_type=message_type, + field_names=field_names, + output_types=output_types, + message_format=message_format, + sanitize=sanitize) + + vlist = sess.run([sizes] + vtensor) + sizes = vlist[0] + # Values is a list of tensors, one for each field. + value_tensors = vlist[1:] + + # Check that the repeat sizes are correct. + self.assertTrue( + np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) + + # Check that the decoded sizes match the expected sizes. + self.assertEqual(len(sizes.flat), len(case_sizes)) + self.assertTrue( + np.all(sizes.flat == np.array( + case_sizes, dtype=np.int32))) + + field_dict = dict(zip(field_names, value_tensors)) + + self._compareProtos(batch_shape, sizes, fields, field_dict) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinary(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinaryDisordered(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False, + force_disordered=True) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testPacked(self, case): + # Now try with the packed serialization. + # + # We test the packed representations by loading the same test case using + # PackedTestValue instead of TestValue. To do this we rely on the text + # format being the same for packed and unpacked fields, and reparse the + # test message using the packed version of the proto. + packed_batch = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + value, float_format='.17g'), + test_example_pb2.PackedTestValue()).SerializeToString() + for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + packed_batch, + 'tensorflow.contrib.proto.PackedTestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testText(self, case): + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_batch = [ + text_format.MessageToString( + value, float_format='.17g') for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + text_batch, + 'tensorflow.contrib.proto.TestValue', + 'text', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testSanitizerGood(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=True) + + @parameterized.parameters((False), (True)) + def testCorruptProtobuf(self, sanitize): + corrupt_proto = 'This is not a binary protobuf' + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(corrupt_proto, dtype=object) + msg_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['sizes'] + field_types = [dtypes.int32] + + with self.cached_session() as sess: + ctensor, vtensor = self._decode_module.decode_proto( + batch, + message_type=msg_type, + field_names=field_names, + output_types=field_types, + sanitize=sanitize) + with self.assertRaisesRegexp(errors.DataLossError, + 'Unable to parse binary protobuf' + '|Failed to consume entire buffer'): + _ = sess.run([ctensor] + vtensor) diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt deleted file mode 100644 index 4e31681907..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt +++ /dev/null @@ -1,94 +0,0 @@ -primitive { - # No fields specified, so we get all defaults -} -shape: 1 -sizes: 0 -field { - name: "double_default" - dtype: DT_DOUBLE - expected { double_value: 1.0 } -} -sizes: 0 -field { - name: "float_default" - dtype: DT_DOUBLE # Try casting the float field to double. - expected { double_value: 2.0 } -} -sizes: 0 -field { - name: "int64_default" - dtype: DT_INT64 - expected { int64_value: 3 } -} -sizes: 0 -field { - name: "uint64_default" - dtype: DT_INT64 - expected { int64_value: 4 } -} -sizes: 0 -field { - name: "int32_default" - dtype: DT_INT32 - expected { int32_value: 5 } -} -sizes: 0 -field { - name: "fixed64_default" - dtype: DT_INT64 - expected { int64_value: 6 } -} -sizes: 0 -field { - name: "fixed32_default" - dtype: DT_INT32 - expected { int32_value: 7 } -} -sizes: 0 -field { - name: "bool_default" - dtype: DT_BOOL - expected { bool_value: true } -} -sizes: 0 -field { - name: "string_default" - dtype: DT_STRING - expected { string_value: "a" } -} -sizes: 0 -field { - name: "bytes_default" - dtype: DT_STRING - expected { string_value: "a longer default string" } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT32 - expected { int32_value: -1 } -} -sizes: 0 -field { - name: "sfixed32_default" - dtype: DT_INT32 - expected { int32_value: 10 } -} -sizes: 0 -field { - name: "sfixed64_default" - dtype: DT_INT64 - expected { int64_value: 11 } -} -sizes: 0 -field { - name: "sint32_default" - dtype: DT_INT32 - expected { int32_value: 12 } -} -sizes: 0 -field { - name: "sint64_default" - dtype: DT_INT64 - expected { int64_value: 13 } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py index b95202c5df..32ca318f73 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py +++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py @@ -13,23 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Test case base for testing proto operations.""" - +"""Tests for proto ops reading descriptors from other sources.""" # Python3 preparedness imports. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import ctypes as ct -import os - +from tensorflow.contrib.proto.python.kernel_tests import descriptor_source_test_base as test_base +from tensorflow.contrib.proto.python.ops import decode_proto_op +from tensorflow.contrib.proto.python.ops import encode_proto_op from tensorflow.python.platform import test -class ProtoOpTestCase(test.TestCase): +class DescriptorSourceTest(test_base.DescriptorSourceTestBase): def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super(ProtoOpTestCase, self).__init__(methodName) - lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so') - if os.path.isfile(lib): - ct.cdll.LoadLibrary(lib) + super(DescriptorSourceTest, self).__init__(decode_proto_op, encode_proto_op, + methodName) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py new file mode 100644 index 0000000000..7e9b355c69 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py @@ -0,0 +1,176 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for proto ops reading descriptors from other sources.""" +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from google.protobuf.descriptor_pb2 import FileDescriptorSet +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class DescriptorSourceTestBase(test.TestCase): + """Base class for testing descriptor sources.""" + + def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name + """DescriptorSourceTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + encode_module: a module containing the `encode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(DescriptorSourceTestBase, self).__init__(methodName) + self._decode_module = decode_module + self._encode_module = encode_module + + # NOTE: We generate the descriptor programmatically instead of via a compiler + # because of differences between different versions of the compiler. + # + # The generated descriptor should capture the subset of `test_example.proto` + # used in `test_base.simple_test_case()`. + def _createDescriptorFile(self): + set_proto = FileDescriptorSet() + + file_proto = set_proto.file.add( + name='types.proto', + package='tensorflow', + syntax='proto3') + enum_proto = file_proto.enum_type.add(name='DataType') + enum_proto.value.add(name='DT_DOUBLE', number=0) + enum_proto.value.add(name='DT_BOOL', number=1) + + file_proto = set_proto.file.add( + name='test_example.proto', + package='tensorflow.contrib.proto', + dependency=['types.proto']) + message_proto = file_proto.message_type.add(name='TestCase') + message_proto.field.add( + name='values', + number=1, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.TestValue', + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='shapes', + number=2, + type=FieldDescriptorProto.TYPE_INT32, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='sizes', + number=3, + type=FieldDescriptorProto.TYPE_INT32, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='fields', + number=4, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.FieldSpec', + label=FieldDescriptorProto.LABEL_REPEATED) + + message_proto = file_proto.message_type.add( + name='TestValue') + message_proto.field.add( + name='double_value', + number=1, + type=FieldDescriptorProto.TYPE_DOUBLE, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='bool_value', + number=2, + type=FieldDescriptorProto.TYPE_BOOL, + label=FieldDescriptorProto.LABEL_REPEATED) + + message_proto = file_proto.message_type.add( + name='FieldSpec') + message_proto.field.add( + name='name', + number=1, + type=FieldDescriptorProto.TYPE_STRING, + label=FieldDescriptorProto.LABEL_OPTIONAL) + message_proto.field.add( + name='dtype', + number=2, + type=FieldDescriptorProto.TYPE_ENUM, + type_name='.tensorflow.DataType', + label=FieldDescriptorProto.LABEL_OPTIONAL) + message_proto.field.add( + name='value', + number=3, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.TestValue', + label=FieldDescriptorProto.LABEL_OPTIONAL) + + fn = os.path.join(self.get_temp_dir(), 'descriptor.pb') + with open(fn, 'wb') as f: + f.write(set_proto.SerializeToString()) + return fn + + def _testRoundtrip(self, descriptor_source): + # Numpy silently truncates the strings if you don't specify dtype=object. + in_bufs = np.array( + [test_base.ProtoOpTestBase.simple_test_case().SerializeToString()], + dtype=object) + message_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['values', 'shapes', 'sizes', 'fields'] + tensor_types = [dtypes.string, dtypes.int32, dtypes.int32, dtypes.string] + + with self.cached_session() as sess: + sizes, field_tensors = self._decode_module.decode_proto( + in_bufs, + message_type=message_type, + field_names=field_names, + output_types=tensor_types, + descriptor_source=descriptor_source) + + out_tensors = self._encode_module.encode_proto( + sizes, + field_tensors, + message_type=message_type, + field_names=field_names, + descriptor_source=descriptor_source) + + out_bufs, = sess.run([out_tensors]) + + # Check that the re-encoded tensor has the same shape. + self.assertEqual(in_bufs.shape, out_bufs.shape) + + # Compare the input and output. + for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): + # Check that the input and output serialized messages are identical. + # If we fail here, there is a difference in the serialized + # representation but the new serialization still parses. This could + # be harmless (a change in map ordering?) or it could be bad (e.g. + # loss of packing in the encoding). + self.assertEqual(in_buf, out_buf) + + def testWithFileDescriptorSet(self): + # First try parsing with a local proto db, which should fail. + with self.assertRaisesOpError('No descriptor found for message type'): + self._testRoundtrip('local://') + + # Now try parsing with a FileDescriptorSet which contains the test proto. + descriptor_file = self._createDescriptorFile() + self._testRoundtrip(descriptor_file) diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py index 30e58e6336..fc5cd25d43 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py @@ -13,167 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for encode_proto op. +"""Tests for encode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. - -It tests that encode_proto is a lossless inverse of decode_proto -(for the specified fields). -""" # Python3 readiness boilerplate from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import encode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op from tensorflow.contrib.proto.python.ops import encode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - - -class EncodeProtoOpTest(test_case.ProtoOpTestCase): - - def testBadInputs(self): - # Invalid field name - with self.test_session(): - with self.assertRaisesOpError('Unknown field: non_existent_field'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['non_existent_field']).eval() - - # Incorrect types. - with self.test_session(): - with self.assertRaisesOpError( - 'Incompatible type for field double_value.'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval() - - # Incorrect shapes of sizes. - with self.test_session(): - with self.assertRaisesOpError( - r'sizes should be batch_size \+ \[len\(field_names\)\]'): - sizes = array_ops.placeholder(dtypes.int32) - values = array_ops.placeholder(dtypes.float64) - encode_proto_op.encode_proto( - sizes=sizes, - values=[values], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval(feed_dict={ - sizes: [[[0, 0]]], - values: [[0.0]] - }) - - # Inconsistent shapes of values. - with self.test_session(): - with self.assertRaisesOpError( - 'Values must match up to the last dimension'): - sizes = array_ops.placeholder(dtypes.int32) - values1 = array_ops.placeholder(dtypes.float64) - values2 = array_ops.placeholder(dtypes.int32) - (encode_proto_op.encode_proto( - sizes=[[1, 1]], - values=[values1, values2], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value', 'int32_value']).eval(feed_dict={ - values1: [[0.0]], - values2: [[0], [0]] - })) - - def _testRoundtrip(self, in_bufs, message_type, fields): - - field_names = [f.name for f in fields] - out_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, field_tensors = decode_proto_op.decode_proto( - in_bufs, - message_type=message_type, - field_names=field_names, - output_types=out_types) - - out_tensors = encode_proto_op.encode_proto( - sizes, - field_tensors, - message_type=message_type, - field_names=field_names) - - out_bufs, = sess.run([out_tensors]) - - # Check that the re-encoded tensor has the same shape. - self.assertEqual(in_bufs.shape, out_bufs.shape) - - # Compare the input and output. - for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): - in_obj = test_example_pb2.RepeatedPrimitiveValue() - in_obj.ParseFromString(in_buf) - - out_obj = test_example_pb2.RepeatedPrimitiveValue() - out_obj.ParseFromString(out_buf) - - # Check that the deserialized objects are identical. - self.assertEqual(in_obj, out_obj) - - # Check that the input and output serialized messages are identical. - # If we fail here, there is a difference in the serialized - # representation but the new serialization still parses. This could - # be harmless (a change in map ordering?) or it could be bad (e.g. - # loss of packing in the encoding). - self.assertEqual(in_buf, out_buf) - - def testRoundtrip(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - in_bufs = [primitive.SerializeToString() for primitive in case.primitive] - - # np.array silently truncates strings if you don't specify dtype=object. - in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) - return self._testRoundtrip( - in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field) - - def testRoundtripPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - # Now try with the packed serialization. - # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - in_bufs = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] +class EncodeProtoOpTest(test_base.EncodeProtoOpTestBase): - # np.array silently truncates strings if you don't specify dtype=object. - in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) - return self._testRoundtrip( - in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field) + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(EncodeProtoOpTest, self).__init__(decode_proto_op, encode_proto_op, + methodName) if __name__ == '__main__': diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py new file mode 100644 index 0000000000..01b3ccc7fd --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py @@ -0,0 +1,177 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Table-driven test for encode_proto op. + +This test is run once with each of the *.TestCase.pbtxt files +in the test directory. + +It tests that encode_proto is a lossless inverse of decode_proto +(for the specified fields). +""" +# Python3 readiness boilerplate +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops + + +class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): + """Base class for testing proto encoding ops.""" + + def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name + """EncodeProtoOpTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + encode_module: a module containing the `encode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(EncodeProtoOpTestBase, self).__init__(methodName) + self._decode_module = decode_module + self._encode_module = encode_module + + def testBadInputs(self): + # Invalid field name + with self.cached_session(): + with self.assertRaisesOpError('Unknown field: non_existent_field'): + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['non_existent_field']).eval() + + # Incorrect types. + with self.cached_session(): + with self.assertRaisesOpError( + 'Incompatible type for field double_value.'): + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval() + + # Incorrect shapes of sizes. + with self.cached_session(): + with self.assertRaisesOpError( + r'sizes should be batch_size \+ \[len\(field_names\)\]'): + sizes = array_ops.placeholder(dtypes.int32) + values = array_ops.placeholder(dtypes.float64) + self._encode_module.encode_proto( + sizes=sizes, + values=[values], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval(feed_dict={ + sizes: [[[0, 0]]], + values: [[0.0]] + }) + + # Inconsistent shapes of values. + with self.cached_session(): + with self.assertRaisesOpError( + 'Values must match up to the last dimension'): + sizes = array_ops.placeholder(dtypes.int32) + values1 = array_ops.placeholder(dtypes.float64) + values2 = array_ops.placeholder(dtypes.int32) + (self._encode_module.encode_proto( + sizes=[[1, 1]], + values=[values1, values2], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value', 'int32_value']).eval(feed_dict={ + values1: [[0.0]], + values2: [[0], [0]] + })) + + def _testRoundtrip(self, in_bufs, message_type, fields): + + field_names = [f.name for f in fields] + out_types = [f.dtype for f in fields] + + with self.cached_session() as sess: + sizes, field_tensors = self._decode_module.decode_proto( + in_bufs, + message_type=message_type, + field_names=field_names, + output_types=out_types) + + out_tensors = self._encode_module.encode_proto( + sizes, + field_tensors, + message_type=message_type, + field_names=field_names) + + out_bufs, = sess.run([out_tensors]) + + # Check that the re-encoded tensor has the same shape. + self.assertEqual(in_bufs.shape, out_bufs.shape) + + # Compare the input and output. + for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): + in_obj = test_example_pb2.TestValue() + in_obj.ParseFromString(in_buf) + + out_obj = test_example_pb2.TestValue() + out_obj.ParseFromString(out_buf) + + # Check that the deserialized objects are identical. + self.assertEqual(in_obj, out_obj) + + # Check that the input and output serialized messages are identical. + # If we fail here, there is a difference in the serialized + # representation but the new serialization still parses. This could + # be harmless (a change in map ordering?) or it could be bad (e.g. + # loss of packing in the encoding). + self.assertEqual(in_buf, out_buf) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtrip(self, case): + in_bufs = [value.SerializeToString() for value in case.values] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtripPacked(self, case): + # Now try with the packed serialization. + # We test the packed representations by loading the same test cases using + # PackedTestValue instead of TestValue. To do this we rely on the text + # format being the same for packed and unpacked fields, and reparse the test + # message using the packed version of the proto. + in_bufs = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + value, float_format='.17g'), + test_example_pb2.PackedTestValue()).SerializeToString() + for value in case.values + ] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.PackedTestValue', case.fields) diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt deleted file mode 100644 index b170f89c0f..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt +++ /dev/null @@ -1,161 +0,0 @@ -primitive { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - uint64_value: 0 - uint64_value: 18446744073709551615 - int32_value: -2147483648 - int32_value: 2147483647 - fixed64_value: 0 - fixed64_value: 18446744073709551615 - fixed32_value: 0 - fixed32_value: 4294967295 - bool_value: false - bool_value: true - string_value: "" - string_value: "I refer to the infinite." - uint32_value: 0 - uint32_value: 4294967295 - sfixed32_value: -2147483648 - sfixed32_value: 2147483647 - sfixed64_value: -9223372036854775808 - sfixed64_value: 9223372036854775807 - sint32_value: -2147483648 - sint32_value: 2147483647 - sint64_value: -9223372036854775808 - sint64_value: 9223372036854775807 -} -shape: 1 -sizes: 3 -sizes: 3 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - } -} -field { - name: "float_value" - dtype: DT_FLOAT - expected { - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - } -} -field { - name: "int64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "uint64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 - } -} -field { - name: "int32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "fixed64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 # unsigned is 18446744073709551615 - } -} -field { - name: "fixed32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: false - bool_value: true - } -} -field { - name: "string_value" - dtype: DT_STRING - expected { - string_value: "" - string_value: "I refer to the infinite." - } -} -field { - name: "uint32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "sfixed32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sfixed64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "sint32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sint64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt deleted file mode 100644 index c664e52851..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt +++ /dev/null @@ -1,16 +0,0 @@ -primitive { - message_value { - double_value: 23.5 - } -} -shape: 1 -sizes: 1 -field { - name: "message_value" - dtype: DT_STRING - expected { - message_value { - double_value: 23.5 - } - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt deleted file mode 100644 index 125651d7ea..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt +++ /dev/null @@ -1,20 +0,0 @@ -primitive { - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 0 -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 0.0 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt deleted file mode 100644 index bc07efc8f3..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt +++ /dev/null @@ -1,29 +0,0 @@ -primitive { - fixed32_value: 4294967295 - uint32_value: 4294967295 -} -shape: 1 -sizes: 1 -field { - name: "fixed32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 1 -field { - name: "uint32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT64 - expected { - int64_value: 4294967295 # Comes from an explicitly-specified default - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py new file mode 100644 index 0000000000..2950c7dfdc --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py @@ -0,0 +1,419 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Test case base for testing proto operations.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ctypes as ct +import os + +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import test + + +class ProtoOpTestBase(test.TestCase): + """Base class for testing proto decoding and encoding ops.""" + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(ProtoOpTestBase, self).__init__(methodName) + lib = os.path.join(os.path.dirname(__file__), "libtestexample.so") + if os.path.isfile(lib): + ct.cdll.LoadLibrary(lib) + + @staticmethod + def named_parameters(): + return ( + ("defaults", ProtoOpTestBase.defaults_test_case()), + ("minmax", ProtoOpTestBase.minmax_test_case()), + ("nested", ProtoOpTestBase.nested_test_case()), + ("optional", ProtoOpTestBase.optional_test_case()), + ("promote", ProtoOpTestBase.promote_test_case()), + ("ragged", ProtoOpTestBase.ragged_test_case()), + ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), + ("simple", ProtoOpTestBase.simple_test_case()), + ) + + @staticmethod + def defaults_test_case(): + test_case = test_example_pb2.TestCase() + test_case.values.add() # No fields specified, so we get all defaults. + test_case.shapes.append(1) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "double_value_with_default" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(1.0) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "float_value_with_default" + field.dtype = types_pb2.DT_FLOAT + field.value.float_value.append(2.0) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "int64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(3) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sfixed64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(11) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sint64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(13) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "uint64_value_with_default" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "fixed64_value_with_default" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(6) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "int32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(5) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sfixed32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(10) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sint32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(12) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "uint32_value_with_default" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(9) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "fixed32_value_with_default" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(7) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "bool_value_with_default" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "string_value_with_default" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("a") + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "bytes_value_with_default" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("a longer default string") + return test_case + + @staticmethod + def minmax_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(-1.7976931348623158e+308) + value.double_value.append(2.2250738585072014e-308) + value.double_value.append(1.7976931348623158e+308) + value.float_value.append(-3.402823466e+38) + value.float_value.append(1.175494351e-38) + value.float_value.append(3.402823466e+38) + value.int64_value.append(-9223372036854775808) + value.int64_value.append(9223372036854775807) + value.sfixed64_value.append(-9223372036854775808) + value.sfixed64_value.append(9223372036854775807) + value.sint64_value.append(-9223372036854775808) + value.sint64_value.append(9223372036854775807) + value.uint64_value.append(0) + value.uint64_value.append(18446744073709551615) + value.fixed64_value.append(0) + value.fixed64_value.append(18446744073709551615) + value.int32_value.append(-2147483648) + value.int32_value.append(2147483647) + value.sfixed32_value.append(-2147483648) + value.sfixed32_value.append(2147483647) + value.sint32_value.append(-2147483648) + value.sint32_value.append(2147483647) + value.uint32_value.append(0) + value.uint32_value.append(4294967295) + value.fixed32_value.append(0) + value.fixed32_value.append(4294967295) + value.bool_value.append(False) + value.bool_value.append(True) + value.string_value.append("") + value.string_value.append("I refer to the infinite.") + test_case.shapes.append(1) + test_case.sizes.append(3) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(-1.7976931348623158e+308) + field.value.double_value.append(2.2250738585072014e-308) + field.value.double_value.append(1.7976931348623158e+308) + test_case.sizes.append(3) + field = test_case.fields.add() + field.name = "float_value" + field.dtype = types_pb2.DT_FLOAT + field.value.float_value.append(-3.402823466e+38) + field.value.float_value.append(1.175494351e-38) + field.value.float_value.append(3.402823466e+38) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "int64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sfixed64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sint64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "uint64_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "fixed64_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "int32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sfixed32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sint32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(False) + field.value.bool_value.append(True) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "string_value" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("") + field.value.string_value.append("I refer to the infinite.") + return test_case + + @staticmethod + def nested_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + message_value = value.message_value.add() + message_value.double_value = 23.5 + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "message_value" + field.dtype = types_pb2.DT_STRING + message_value = field.value.message_value.add() + message_value.double_value = 23.5 + return test_case + + @staticmethod + def optional_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.bool_value.append(True) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(0.0) + return test_case + + @staticmethod + def promote_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.sint32_value.append(2147483647) + value.sfixed32_value.append(2147483647) + value.int32_value.append(2147483647) + value.fixed32_value.append(4294967295) + value.uint32_value.append(4294967295) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "sint32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "sfixed32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "int32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) + return test_case + + @staticmethod + def ragged_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.double_value.append(123.0) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(3.1) + value.bool_value.append(False) + test_case.shapes.append(2) + test_case.sizes.append(2) + test_case.sizes.append(1) + test_case.sizes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + field.value.double_value.append(123.0) + field.value.double_value.append(3.1) + field.value.double_value.append(0.0) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + field.value.bool_value.append(False) + return test_case + + @staticmethod + def shaped_batch_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(44.0) + value.bool_value.append(False) + value = test_case.values.add() + value.double_value.append(3.14159) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(1.414) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(-32.2) + value.bool_value.append(False) + value = test_case.values.add() + value.double_value.append(0.0001) + value.bool_value.append(True) + test_case.shapes.append(3) + test_case.shapes.append(2) + for _ in range(12): + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + field.value.double_value.append(44.0) + field.value.double_value.append(3.14159) + field.value.double_value.append(1.414) + field.value.double_value.append(-32.2) + field.value.double_value.append(0.0001) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + field.value.bool_value.append(False) + field.value.bool_value.append(True) + field.value.bool_value.append(True) + field.value.bool_value.append(False) + field.value.bool_value.append(True) + return test_case + + @staticmethod + def simple_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.bool_value.append(True) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + return test_case diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt deleted file mode 100644 index 61c7ac53f7..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt +++ /dev/null @@ -1,32 +0,0 @@ -primitive { - double_value: 23.5 - double_value: 123.0 - bool_value: true -} -primitive { - double_value: 3.1 - bool_value: false -} -shape: 2 -sizes: 2 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 123.0 - double_value: 3.1 - double_value: 0.0 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt deleted file mode 100644 index f4828076d5..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt +++ /dev/null @@ -1,62 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -primitive { - double_value: 44.0 - bool_value: false -} -primitive { - double_value: 3.14159 - bool_value: true -} -primitive { - double_value: 1.414 - bool_value: true -} -primitive { - double_value: -32.2 - bool_value: false -} -primitive { - double_value: 0.0001 - bool_value: true -} -shape: 3 -shape: 2 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 44.0 - double_value: 3.14159 - double_value: 1.414 - double_value: -32.2 - double_value: 0.0001 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - bool_value: true - bool_value: true - bool_value: false - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt deleted file mode 100644 index dc20ac147b..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt +++ /dev/null @@ -1,21 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto index a2c88e372b..674d881220 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -1,6 +1,4 @@ // Test description and protos to work with it. -// -// Many of the protos in this file are for unit tests that haven't been written yet. syntax = "proto2"; @@ -8,54 +6,27 @@ import "tensorflow/core/framework/types.proto"; package tensorflow.contrib.proto; -// A TestCase holds a proto and a bunch of assertions -// about how it should decode. +// A TestCase holds a proto and assertions about how it should decode. message TestCase { - // A batch of primitives to be serialized and decoded. - repeated RepeatedPrimitiveValue primitive = 1; - // The shape of the batch. - repeated int32 shape = 2; + // Batches of primitive values. + repeated TestValue values = 1; + // The batch shapes. + repeated int32 shapes = 2; // Expected sizes for each field. repeated int32 sizes = 3; // Expected values for each field. - repeated FieldSpec field = 4; + repeated FieldSpec fields = 4; }; // FieldSpec describes the expected output for a single field. message FieldSpec { optional string name = 1; optional tensorflow.DataType dtype = 2; - optional RepeatedPrimitiveValue expected = 3; + optional TestValue value = 3; }; +// NOTE: This definition must be kept in sync with PackedTestValue. message TestValue { - optional PrimitiveValue primitive_value = 1; - optional EnumValue enum_value = 2; - optional MessageValue message_value = 3; - optional RepeatedMessageValue repeated_message_value = 4; - optional RepeatedPrimitiveValue repeated_primitive_value = 6; -} - -message PrimitiveValue { - optional double double_value = 1; - optional float float_value = 2; - optional int64 int64_value = 3; - optional uint64 uint64_value = 4; - optional int32 int32_value = 5; - optional fixed64 fixed64_value = 6; - optional fixed32 fixed32_value = 7; - optional bool bool_value = 8; - optional string string_value = 9; - optional bytes bytes_value = 12; - optional uint32 uint32_value = 13; - optional sfixed32 sfixed32_value = 15; - optional sfixed64 sfixed64_value = 16; - optional sint32 sint32_value = 17; - optional sint64 sint64_value = 18; -} - -// NOTE: This definition must be kept in sync with PackedPrimitiveValue. -message RepeatedPrimitiveValue { repeated double double_value = 1; repeated float float_value = 2; repeated int64 int64_value = 3; @@ -74,30 +45,31 @@ message RepeatedPrimitiveValue { repeated PrimitiveValue message_value = 19; // Optional fields with explicitly-specified defaults. - optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } -// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue -// in the text format, but the binary serializion is different. -// We test the packed representations by loading the same test cases -// using this definition instead of RepeatedPrimitiveValue. -// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue -// in every way except the packed=true declaration. -message PackedPrimitiveValue { +// A PackedTestValue looks exactly the same as a TestValue in the text format, +// but the binary serializion is different. We test the packed representations +// by loading the same test cases using this definition instead of TestValue. +// +// NOTE: This definition must be kept in sync with TestValue in every way except +// the packed=true declaration. +message PackedTestValue { repeated double double_value = 1 [packed = true]; repeated float float_value = 2 [packed = true]; repeated int64 int64_value = 3 [packed = true]; @@ -115,23 +87,53 @@ message PackedPrimitiveValue { repeated sint64 sint64_value = 18 [packed = true]; repeated PrimitiveValue message_value = 19; - optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } +message PrimitiveValue { + optional double double_value = 1; + optional float float_value = 2; + optional int64 int64_value = 3; + optional uint64 uint64_value = 4; + optional int32 int32_value = 5; + optional fixed64 fixed64_value = 6; + optional fixed32 fixed32_value = 7; + optional bool bool_value = 8; + optional string string_value = 9; + optional bytes bytes_value = 12; + optional uint32 uint32_value = 13; + optional sfixed32 sfixed32_value = 15; + optional sfixed64 sfixed64_value = 16; + optional sint32 sint32_value = 17; + optional sint64 sint64_value = 18; +} + +// Message containing fields with field numbers higher than any field above. +// An instance of this message is prepended to each binary message in the test +// to exercise the code path that handles fields encoded out of order of field +// number. +message ExtraFields { + optional string string_value = 1776; + optional bool bool_value = 1777; +} + +// The messages below are for yet-to-be created tests. + message EnumValue { enum Color { RED = 0; @@ -171,12 +173,3 @@ message RepeatedMessageValue { repeated NestedMessageValue message_values = 11; } - -// Message containing fields with field numbers higher than any field above. An -// instance of this message is prepended to each binary message in the test to -// exercise the code path that handles fields encoded out of order of field -// number. -message ExtraFields { - optional string string_value = 1776; - optional bool bool_value = 1777; -} |