From 1cbf75706031247460e588f281a47f0ae00d6812 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Mon, 16 Apr 2018 11:24:43 -0700 Subject: Porting tests for the `decode_proto` and `encode_proto` to OS. PiperOrigin-RevId: 193070420 --- tensorflow/contrib/BUILD | 1 + tensorflow/contrib/__init__.py | 1 + tensorflow/contrib/cmake/tf_python.cmake | 6 +- tensorflow/contrib/proto/BUILD | 16 ++ tensorflow/contrib/proto/python/kernel_tests/BUILD | 86 ++++++ .../proto/python/kernel_tests/build_defs.bzl | 89 ++++++ .../python/kernel_tests/decode_proto_fail_test.py | 68 +++++ .../python/kernel_tests/decode_proto_op_test.py | 300 +++++++++++++++++++++ .../python/kernel_tests/encode_proto_op_test.py | 180 +++++++++++++ .../python/kernel_tests/minmax.TestCase.pbtxt | 161 +++++++++++ .../python/kernel_tests/nested.TestCase.pbtxt | 16 ++ .../python/kernel_tests/optional.TestCase.pbtxt | 20 ++ .../kernel_tests/promote_unsigned.TestCase.pbtxt | 21 ++ .../python/kernel_tests/ragged.TestCase.pbtxt | 32 +++ .../kernel_tests/shaped_batch.TestCase.pbtxt | 62 +++++ .../python/kernel_tests/simple.TestCase.pbtxt | 21 ++ .../contrib/proto/python/kernel_tests/test_case.py | 35 +++ .../proto/python/kernel_tests/test_example.proto | 149 ++++++++++ 18 files changed, 1262 insertions(+), 2 deletions(-) create mode 100644 tensorflow/contrib/proto/python/kernel_tests/BUILD create mode 100644 tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl create mode 100644 tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py create mode 100644 tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py create mode 100644 tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py create mode 100644 tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt create mode 100644 tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt create mode 100644 tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt create mode 100644 tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt create mode 100644 tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt create mode 100644 tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt create mode 100644 tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt create mode 100644 tensorflow/contrib/proto/python/kernel_tests/test_case.py create mode 100644 tensorflow/contrib/proto/python/kernel_tests/test_example.proto (limited to 'tensorflow/contrib') diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 9bef0d8b61..ae68f4aec4 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -77,6 +77,7 @@ py_library( "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", + "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", "//tensorflow/contrib/autograph", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index aaddb06fa0..e27ece8fa5 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -64,6 +64,7 @@ from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor +from tensorflow.contrib import proto from tensorflow.contrib import quantization from tensorflow.contrib import quantize from tensorflow.contrib import recurrent diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index ded15b4b66..21f59d2563 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -330,8 +330,10 @@ GENERATE_PYTHON_OP_LIB("ctc_ops") GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops") GENERATE_PYTHON_OP_LIB("data_flow_ops") GENERATE_PYTHON_OP_LIB("dataset_ops") -GENERATE_PYTHON_OP_LIB("decode_proto_ops") -GENERATE_PYTHON_OP_LIB("encode_proto_ops") +GENERATE_PYTHON_OP_LIB("decode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) +GENERATE_PYTHON_OP_LIB("encode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD index 046652cbc5..3e9b1a0b8d 100644 --- a/tensorflow/contrib/proto/BUILD +++ b/tensorflow/contrib/proto/BUILD @@ -4,6 +4,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") + py_library( name = "proto", srcs = [ @@ -14,3 +16,17 @@ py_library( "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", ], ) + +py_library( + name = "proto_pip", + data = [ + "//tensorflow/contrib/proto/python/kernel_tests:test_messages", + ] + if_static( + [], + otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"], + ), + deps = [ + ":proto", + "//tensorflow/contrib/proto/python/kernel_tests:py_test_deps", + ], +) diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD new file mode 100644 index 0000000000..a380a131f8 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD @@ -0,0 +1,86 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +# Much of the work in this BUILD file actually happens in the corresponding +# build_defs.bzl, which creates an individual testcase for each example .pbtxt +# file in this directory. +# +load(":build_defs.bzl", "decode_proto_test_suite") +load(":build_defs.bzl", "encode_proto_test_suite") + +# This expands to a tf_py_test for each test file. +# It defines the test_suite :decode_proto_op_tests. +decode_proto_test_suite( + name = "decode_proto_tests", + examples = glob(["*.pbtxt"]), +) + +# This expands to a tf_py_test for each test file. +# It defines the test_suite :encode_proto_op_tests. +encode_proto_test_suite( + name = "encode_proto_tests", + examples = glob(["*.pbtxt"]), +) + +# Below here are tests that are not tied to an example text proto. +filegroup( + name = "test_messages", + srcs = glob(["*.pbtxt"]), +) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") +load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +tf_py_test( + name = "decode_proto_fail_test", + size = "small", + srcs = ["decode_proto_fail_test.py"], + additional_deps = [ + ":py_test_deps", + "//third_party/py/numpy", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + ], + data = if_static( + [], + otherwise = [":libtestexample.so"], + ), + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], +) + +py_library( + name = "test_case", + srcs = ["test_case.py"], + deps = ["//tensorflow/python:client_testlib"], +) + +py_library( + name = "py_test_deps", + deps = [ + ":test_case", + ":test_example_proto_py", + ], +) + +tf_proto_library( + name = "test_example_proto", + srcs = ["test_example.proto"], + cc_api_version = 2, + protodeps = ["//tensorflow/core:protos_all"], +) + +tf_cc_shared_object( + name = "libtestexample.so", + linkstatic = 1, + deps = [ + ":test_example_proto_cc", + ], +) diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl new file mode 100644 index 0000000000..f425601691 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl @@ -0,0 +1,89 @@ +"""BUILD rules for generating file-driven proto test cases. + +The decode_proto_test_suite() and encode_proto_test_suite() rules take a list +of text protos and generates a tf_py_test() for each one. +""" + +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "register_extension_info") +load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") + +def _test_name(test, path): + return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0]) + +def decode_proto_test_suite(name, examples): + """Build the decode_proto py_test for each test filename.""" + for test_filename in examples: + tf_py_test( + name = _test_name("decode_proto", test_filename), + srcs = ["decode_proto_op_test.py"], + size = "small", + data = [test_filename] + if_static( + [], + otherwise = [":libtestexample.so"], + ), + main = "decode_proto_op_test.py", + args = [ + "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), + ], + additional_deps = [ + ":py_test_deps", + "//third_party/py/numpy", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + ], + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], + ) + native.test_suite( + name = name, + tests = [":" + _test_name("decode_proto", test_filename) + for test_filename in examples], + ) + +def encode_proto_test_suite(name, examples): + """Build the encode_proto py_test for each test filename.""" + for test_filename in examples: + tf_py_test( + name = _test_name("encode_proto", test_filename), + srcs = ["encode_proto_op_test.py"], + size = "small", + data = [test_filename] + if_static( + [], + otherwise = [":libtestexample.so"], + ), + main = "encode_proto_op_test.py", + args = [ + "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), + ], + additional_deps = [ + ":py_test_deps", + "//third_party/py/numpy", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", + ], + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], + ) + native.test_suite( + name = name, + tests = [":" + _test_name("encode_proto", test_filename) + for test_filename in examples], + ) + +register_extension_info( + extension_name = "decode_proto_test_suite", + label_regex_map = { + "deps": "deps:decode_example_.*", + }) + +register_extension_info( + extension_name = "encode_proto_test_suite", + label_regex_map = { + "deps": "deps:encode_example_.*", + }) diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py new file mode 100644 index 0000000000..5298342ee7 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py @@ -0,0 +1,68 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.proto.python.kernel_tests import test_case +from tensorflow.contrib.proto.python.ops import decode_proto_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class DecodeProtoFailTest(test_case.ProtoOpTestCase): + """Test failure cases for DecodeToProto.""" + + def _TestCorruptProtobuf(self, sanitize): + """Test failure cases for DecodeToProto.""" + + # The goal here is to check the error reporting. + # Testing against a variety of corrupt protobufs is + # done by fuzzing. + corrupt_proto = 'This is not a binary protobuf' + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(corrupt_proto, dtype=object) + msg_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['sizes'] + field_types = [dtypes.int32] + + with self.test_session() as sess: + ctensor, vtensor = decode_proto_op.decode_proto( + batch, + message_type=msg_type, + field_names=field_names, + output_types=field_types, + sanitize=sanitize) + with self.assertRaisesRegexp(errors.DataLossError, + 'Unable to parse binary protobuf' + '|Failed to consume entire buffer'): + _ = sess.run([ctensor] + vtensor) + + def testCorrupt(self): + self._TestCorruptProtobuf(sanitize=False) + + def testSanitizerCorrupt(self): + self._TestCorruptProtobuf(sanitize=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py new file mode 100644 index 0000000000..d1c13c82bc --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py @@ -0,0 +1,300 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Table-driven test for decode_proto op. + +This test is run once with each of the *.TestCase.pbtxt files +in the test directory. +""" +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import test_case +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.ops import decode_proto_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import flags +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +flags.DEFINE_string('message_text_file', None, + 'A file containing a text serialized TestCase protobuf.') + + +class DecodeProtoOpTest(test_case.ProtoOpTestCase): + + def _compareValues(self, fd, vs, evs): + """Compare lists/arrays of field values.""" + + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for i, ev in enumerate(evs): + # Special case fuzzy match for float32. TensorFlow seems to mess with + # MAX_FLT slightly and the test doesn't work otherwise. + # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. + if fd.cpp_type == fd.CPPTYPE_FLOAT: + # Numpy isclose() is better than assertIsClose() which uses an absolute + # value comparison. + self.assertTrue( + np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) + elif fd.cpp_type == fd.CPPTYPE_STRING: + # In Python3 string tensor values will be represented as bytes, so we + # reencode the proto values to match that. + self.assertEqual(vs[i], ev.encode('ascii')) + else: + # Doubles and other types pass through unscathed. + self.assertEqual(vs[i], ev) + + def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields, + field_dict): + """Compare protos of type RepeatedPrimitiveValue. + + Args: + batch_shape: the shape of the input tensor of serialized messages. + sizes: int matrix of repeat counts returned by decode_proto + fields: list of test_example_pb2.FieldSpec (types and expected values) + field_dict: map from field names to decoded numpy tensors of values + """ + + # Check that expected values match. + for field in fields: + values = field_dict[field.name] + self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) + + fd = field.expected.DESCRIPTOR.fields_by_name[field.name] + + # Values has the same shape as the input plus an extra + # dimension for repeats. + self.assertEqual(list(values.shape)[:-1], batch_shape) + + # Nested messages are represented as TF strings, requiring + # some special handling. + if field.name == 'message_value': + vs = [] + for buf in values.flat: + msg = test_example_pb2.PrimitiveValue() + msg.ParseFromString(buf) + vs.append(msg) + evs = getattr(field.expected, field.name) + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for v, ev in zip(vs, evs): + self.assertEqual(v, ev) + continue + + # This can be a little confusing. For testing we are using + # RepeatedPrimitiveValue in two ways: it's the proto that we + # decode for testing, and it's used in the expected value as a + # union type. The two cases are slightly different: this is the + # second case. + # We may be fetching the uint64_value from the test proto, but + # in the expected proto we store it in the int64_value field + # because TensorFlow doesn't support unsigned int64. + tf_type_to_primitive_value_field = { + dtypes.float32: + 'float_value', + dtypes.float64: + 'double_value', + dtypes.int32: + 'int32_value', + dtypes.uint8: + 'uint8_value', + dtypes.int8: + 'int8_value', + dtypes.string: + 'string_value', + dtypes.int64: + 'int64_value', + dtypes.bool: + 'bool_value', + # Unhandled TensorFlow types: + # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 + # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 + } + tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) + if tf_field_name is None: + self.fail('Unhandled tensorflow type %d' % field.dtype) + + self._compareValues(fd, values.flat, + getattr(field.expected, tf_field_name)) + + def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, + message_type, message_format, sanitize, + force_disordered=False): + """Run decode tests on a batch of messages. + + Args: + fields: list of test_example_pb2.FieldSpec (types and expected values) + case_sizes: expected sizes array + batch_shape: the shape of the input tensor of serialized messages + batch: list of serialized messages + message_type: descriptor name for messages + message_format: format of messages, 'text' or 'binary' + sanitize: whether to sanitize binary protobuf inputs + force_disordered: whether to force fields encoded out of order. + """ + + if force_disordered: + # Exercise code path that handles out-of-order fields by prepending extra + # fields with tag numbers higher than any real field. Note that this won't + # work with sanitization because that forces reserialization using a + # trusted decoder and encoder. + assert not sanitize + extra_fields = test_example_pb2.ExtraFields() + extra_fields.string_value = 'IGNORE ME' + extra_fields.bool_value = False + extra_msg = extra_fields.SerializeToString() + batch = [extra_msg + msg for msg in batch] + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(batch, dtype=object) + batch = np.reshape(batch, batch_shape) + + field_names = [f.name for f in fields] + output_types = [f.dtype for f in fields] + + with self.test_session() as sess: + sizes, vtensor = decode_proto_op.decode_proto( + batch, + message_type=message_type, + field_names=field_names, + output_types=output_types, + message_format=message_format, + sanitize=sanitize) + + vlist = sess.run([sizes] + vtensor) + sizes = vlist[0] + # Values is a list of tensors, one for each field. + value_tensors = vlist[1:] + + # Check that the repeat sizes are correct. + self.assertTrue( + np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) + + # Check that the decoded sizes match the expected sizes. + self.assertEqual(len(sizes.flat), len(case_sizes)) + self.assertTrue( + np.all(sizes.flat == np.array( + case_sizes, dtype=np.int32))) + + field_dict = dict(zip(field_names, value_tensors)) + + self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields, + field_dict) + + def testBinary(self): + with open(FLAGS.message_text_file, 'r') as fp: + case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) + + batch = [primitive.SerializeToString() for primitive in case.primitive] + self._runDecodeProtoTests( + case.field, + case.sizes, + list(case.shape), + batch, + 'tensorflow.contrib.proto.RepeatedPrimitiveValue', + 'binary', + sanitize=False) + + def testBinaryDisordered(self): + with open(FLAGS.message_text_file, 'r') as fp: + case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) + + batch = [primitive.SerializeToString() for primitive in case.primitive] + self._runDecodeProtoTests( + case.field, + case.sizes, + list(case.shape), + batch, + 'tensorflow.contrib.proto.RepeatedPrimitiveValue', + 'binary', + sanitize=False, + force_disordered=True) + + def testPacked(self): + with open(FLAGS.message_text_file, 'r') as fp: + case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) + + # Now try with the packed serialization. + # We test the packed representations by loading the same test cases + # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. + # To do this we rely on the text format being the same for packed and + # unpacked fields, and reparse the test message using the packed version + # of the proto. + packed_batch = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + primitive, float_format='.17g'), + test_example_pb2.PackedPrimitiveValue()).SerializeToString() + for primitive in case.primitive + ] + + self._runDecodeProtoTests( + case.field, + case.sizes, + list(case.shape), + packed_batch, + 'tensorflow.contrib.proto.PackedPrimitiveValue', + 'binary', + sanitize=False) + + def testText(self): + with open(FLAGS.message_text_file, 'r') as fp: + case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) + + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_batch = [ + text_format.MessageToString( + primitive, float_format='.17g') for primitive in case.primitive + ] + + self._runDecodeProtoTests( + case.field, + case.sizes, + list(case.shape), + text_batch, + 'tensorflow.contrib.proto.RepeatedPrimitiveValue', + 'text', + sanitize=False) + + def testSanitizerGood(self): + with open(FLAGS.message_text_file, 'r') as fp: + case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) + + batch = [primitive.SerializeToString() for primitive in case.primitive] + self._runDecodeProtoTests( + case.field, + case.sizes, + list(case.shape), + batch, + 'tensorflow.contrib.proto.RepeatedPrimitiveValue', + 'binary', + sanitize=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py new file mode 100644 index 0000000000..30e58e6336 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py @@ -0,0 +1,180 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Table-driven test for encode_proto op. + +This test is run once with each of the *.TestCase.pbtxt files +in the test directory. + +It tests that encode_proto is a lossless inverse of decode_proto +(for the specified fields). +""" +# Python3 readiness boilerplate +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import test_case +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.ops import decode_proto_op +from tensorflow.contrib.proto.python.ops import encode_proto_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import flags +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +flags.DEFINE_string('message_text_file', None, + 'A file containing a text serialized TestCase protobuf.') + + +class EncodeProtoOpTest(test_case.ProtoOpTestCase): + + def testBadInputs(self): + # Invalid field name + with self.test_session(): + with self.assertRaisesOpError('Unknown field: non_existent_field'): + encode_proto_op.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', + field_names=['non_existent_field']).eval() + + # Incorrect types. + with self.test_session(): + with self.assertRaisesOpError( + 'Incompatible type for field double_value.'): + encode_proto_op.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', + field_names=['double_value']).eval() + + # Incorrect shapes of sizes. + with self.test_session(): + with self.assertRaisesOpError( + r'sizes should be batch_size \+ \[len\(field_names\)\]'): + sizes = array_ops.placeholder(dtypes.int32) + values = array_ops.placeholder(dtypes.float64) + encode_proto_op.encode_proto( + sizes=sizes, + values=[values], + message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', + field_names=['double_value']).eval(feed_dict={ + sizes: [[[0, 0]]], + values: [[0.0]] + }) + + # Inconsistent shapes of values. + with self.test_session(): + with self.assertRaisesOpError( + 'Values must match up to the last dimension'): + sizes = array_ops.placeholder(dtypes.int32) + values1 = array_ops.placeholder(dtypes.float64) + values2 = array_ops.placeholder(dtypes.int32) + (encode_proto_op.encode_proto( + sizes=[[1, 1]], + values=[values1, values2], + message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', + field_names=['double_value', 'int32_value']).eval(feed_dict={ + values1: [[0.0]], + values2: [[0], [0]] + })) + + def _testRoundtrip(self, in_bufs, message_type, fields): + + field_names = [f.name for f in fields] + out_types = [f.dtype for f in fields] + + with self.test_session() as sess: + sizes, field_tensors = decode_proto_op.decode_proto( + in_bufs, + message_type=message_type, + field_names=field_names, + output_types=out_types) + + out_tensors = encode_proto_op.encode_proto( + sizes, + field_tensors, + message_type=message_type, + field_names=field_names) + + out_bufs, = sess.run([out_tensors]) + + # Check that the re-encoded tensor has the same shape. + self.assertEqual(in_bufs.shape, out_bufs.shape) + + # Compare the input and output. + for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): + in_obj = test_example_pb2.RepeatedPrimitiveValue() + in_obj.ParseFromString(in_buf) + + out_obj = test_example_pb2.RepeatedPrimitiveValue() + out_obj.ParseFromString(out_buf) + + # Check that the deserialized objects are identical. + self.assertEqual(in_obj, out_obj) + + # Check that the input and output serialized messages are identical. + # If we fail here, there is a difference in the serialized + # representation but the new serialization still parses. This could + # be harmless (a change in map ordering?) or it could be bad (e.g. + # loss of packing in the encoding). + self.assertEqual(in_buf, out_buf) + + def testRoundtrip(self): + with open(FLAGS.message_text_file, 'r') as fp: + case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) + + in_bufs = [primitive.SerializeToString() for primitive in case.primitive] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field) + + def testRoundtripPacked(self): + with open(FLAGS.message_text_file, 'r') as fp: + case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) + + # Now try with the packed serialization. + # We test the packed representations by loading the same test cases + # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. + # To do this we rely on the text format being the same for packed and + # unpacked fields, and reparse the test message using the packed version + # of the proto. + in_bufs = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + primitive, float_format='.17g'), + test_example_pb2.PackedPrimitiveValue()).SerializeToString() + for primitive in case.primitive + ] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt new file mode 100644 index 0000000000..b170f89c0f --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt @@ -0,0 +1,161 @@ +primitive { + double_value: -1.7976931348623158e+308 + double_value: 2.2250738585072014e-308 + double_value: 1.7976931348623158e+308 + float_value: -3.402823466e+38 + float_value: 1.175494351e-38 + float_value: 3.402823466e+38 + int64_value: -9223372036854775808 + int64_value: 9223372036854775807 + uint64_value: 0 + uint64_value: 18446744073709551615 + int32_value: -2147483648 + int32_value: 2147483647 + fixed64_value: 0 + fixed64_value: 18446744073709551615 + fixed32_value: 0 + fixed32_value: 4294967295 + bool_value: false + bool_value: true + string_value: "" + string_value: "I refer to the infinite." + uint32_value: 0 + uint32_value: 4294967295 + sfixed32_value: -2147483648 + sfixed32_value: 2147483647 + sfixed64_value: -9223372036854775808 + sfixed64_value: 9223372036854775807 + sint32_value: -2147483648 + sint32_value: 2147483647 + sint64_value: -9223372036854775808 + sint64_value: 9223372036854775807 +} +shape: 1 +sizes: 3 +sizes: 3 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +sizes: 2 +field { + name: "double_value" + dtype: DT_DOUBLE + expected { + double_value: -1.7976931348623158e+308 + double_value: 2.2250738585072014e-308 + double_value: 1.7976931348623158e+308 + } +} +field { + name: "float_value" + dtype: DT_FLOAT + expected { + float_value: -3.402823466e+38 + float_value: 1.175494351e-38 + float_value: 3.402823466e+38 + } +} +field { + name: "int64_value" + dtype: DT_INT64 + expected { + int64_value: -9223372036854775808 + int64_value: 9223372036854775807 + } +} +field { + name: "uint64_value" + dtype: DT_INT64 + expected { + int64_value: 0 + int64_value: -1 + } +} +field { + name: "int32_value" + dtype: DT_INT32 + expected { + int32_value: -2147483648 + int32_value: 2147483647 + } +} +field { + name: "fixed64_value" + dtype: DT_INT64 + expected { + int64_value: 0 + int64_value: -1 # unsigned is 18446744073709551615 + } +} +field { + name: "fixed32_value" + dtype: DT_INT32 + expected { + int32_value: 0 + int32_value: -1 # unsigned is 4294967295 + } +} +field { + name: "bool_value" + dtype: DT_BOOL + expected { + bool_value: false + bool_value: true + } +} +field { + name: "string_value" + dtype: DT_STRING + expected { + string_value: "" + string_value: "I refer to the infinite." + } +} +field { + name: "uint32_value" + dtype: DT_INT32 + expected { + int32_value: 0 + int32_value: -1 # unsigned is 4294967295 + } +} +field { + name: "sfixed32_value" + dtype: DT_INT32 + expected { + int32_value: -2147483648 + int32_value: 2147483647 + } +} +field { + name: "sfixed64_value" + dtype: DT_INT64 + expected { + int64_value: -9223372036854775808 + int64_value: 9223372036854775807 + } +} +field { + name: "sint32_value" + dtype: DT_INT32 + expected { + int32_value: -2147483648 + int32_value: 2147483647 + } +} +field { + name: "sint64_value" + dtype: DT_INT64 + expected { + int64_value: -9223372036854775808 + int64_value: 9223372036854775807 + } +} diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt new file mode 100644 index 0000000000..c664e52851 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt @@ -0,0 +1,16 @@ +primitive { + message_value { + double_value: 23.5 + } +} +shape: 1 +sizes: 1 +field { + name: "message_value" + dtype: DT_STRING + expected { + message_value { + double_value: 23.5 + } + } +} diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt new file mode 100644 index 0000000000..125651d7ea --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt @@ -0,0 +1,20 @@ +primitive { + bool_value: true +} +shape: 1 +sizes: 1 +sizes: 0 +field { + name: "bool_value" + dtype: DT_BOOL + expected { + bool_value: true + } +} +field { + name: "double_value" + dtype: DT_DOUBLE + expected { + double_value: 0.0 + } +} diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt new file mode 100644 index 0000000000..db7555bf2d --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt @@ -0,0 +1,21 @@ +primitive { + fixed32_value: 4294967295 + uint32_value: 4294967295 +} +shape: 1 +sizes: 1 +sizes: 1 +field { + name: "fixed32_value" + dtype: DT_INT64 + expected { + int64_value: 4294967295 + } +} +field { + name: "uint32_value" + dtype: DT_INT64 + expected { + int64_value: 4294967295 + } +} diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt new file mode 100644 index 0000000000..61c7ac53f7 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt @@ -0,0 +1,32 @@ +primitive { + double_value: 23.5 + double_value: 123.0 + bool_value: true +} +primitive { + double_value: 3.1 + bool_value: false +} +shape: 2 +sizes: 2 +sizes: 1 +sizes: 1 +sizes: 1 +field { + name: "double_value" + dtype: DT_DOUBLE + expected { + double_value: 23.5 + double_value: 123.0 + double_value: 3.1 + double_value: 0.0 + } +} +field { + name: "bool_value" + dtype: DT_BOOL + expected { + bool_value: true + bool_value: false + } +} diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt new file mode 100644 index 0000000000..f4828076d5 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt @@ -0,0 +1,62 @@ +primitive { + double_value: 23.5 + bool_value: true +} +primitive { + double_value: 44.0 + bool_value: false +} +primitive { + double_value: 3.14159 + bool_value: true +} +primitive { + double_value: 1.414 + bool_value: true +} +primitive { + double_value: -32.2 + bool_value: false +} +primitive { + double_value: 0.0001 + bool_value: true +} +shape: 3 +shape: 2 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +sizes: 1 +field { + name: "double_value" + dtype: DT_DOUBLE + expected { + double_value: 23.5 + double_value: 44.0 + double_value: 3.14159 + double_value: 1.414 + double_value: -32.2 + double_value: 0.0001 + } +} +field { + name: "bool_value" + dtype: DT_BOOL + expected { + bool_value: true + bool_value: false + bool_value: true + bool_value: true + bool_value: false + bool_value: true + } +} diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt new file mode 100644 index 0000000000..dc20ac147b --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt @@ -0,0 +1,21 @@ +primitive { + double_value: 23.5 + bool_value: true +} +shape: 1 +sizes: 1 +sizes: 1 +field { + name: "double_value" + dtype: DT_DOUBLE + expected { + double_value: 23.5 + } +} +field { + name: "bool_value" + dtype: DT_BOOL + expected { + bool_value: true + } +} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/test_case.py new file mode 100644 index 0000000000..b95202c5df --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/test_case.py @@ -0,0 +1,35 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Test case base for testing proto operations.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ctypes as ct +import os + +from tensorflow.python.platform import test + + +class ProtoOpTestCase(test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(ProtoOpTestCase, self).__init__(methodName) + lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so') + if os.path.isfile(lib): + ct.cdll.LoadLibrary(lib) diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto new file mode 100644 index 0000000000..dc495034ff --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -0,0 +1,149 @@ +// Test description and protos to work with it. +// +// Many of the protos in this file are for unit tests that haven't been written yet. + +syntax = "proto2"; + +import "tensorflow/core/framework/types.proto"; + +package tensorflow.contrib.proto; + +// A TestCase holds a proto and a bunch of assertions +// about how it should decode. +message TestCase { + // A batch of primitives to be serialized and decoded. + repeated RepeatedPrimitiveValue primitive = 1; + // The shape of the batch. + repeated int32 shape = 2; + // Expected sizes for each field. + repeated int32 sizes = 3; + // Expected values for each field. + repeated FieldSpec field = 4; +}; + +// FieldSpec describes the expected output for a single field. +message FieldSpec { + optional string name = 1; + optional tensorflow.DataType dtype = 2; + optional RepeatedPrimitiveValue expected = 3; +}; + +message TestValue { + optional PrimitiveValue primitive_value = 1; + optional EnumValue enum_value = 2; + optional MessageValue message_value = 3; + optional RepeatedMessageValue repeated_message_value = 4; + optional RepeatedPrimitiveValue repeated_primitive_value = 6; +} + +message PrimitiveValue { + optional double double_value = 1; + optional float float_value = 2; + optional int64 int64_value = 3; + optional uint64 uint64_value = 4; + optional int32 int32_value = 5; + optional fixed64 fixed64_value = 6; + optional fixed32 fixed32_value = 7; + optional bool bool_value = 8; + optional string string_value = 9; + optional bytes bytes_value = 12; + optional uint32 uint32_value = 13; + optional sfixed32 sfixed32_value = 15; + optional sfixed64 sfixed64_value = 16; + optional sint32 sint32_value = 17; + optional sint64 sint64_value = 18; +} + +// NOTE: This definition must be kept in sync with PackedPrimitiveValue. +message RepeatedPrimitiveValue { + repeated double double_value = 1; + repeated float float_value = 2; + repeated int64 int64_value = 3; + repeated uint64 uint64_value = 4; + repeated int32 int32_value = 5; + repeated fixed64 fixed64_value = 6; + repeated fixed32 fixed32_value = 7; + repeated bool bool_value = 8; + repeated string string_value = 9; + repeated bytes bytes_value = 12; + repeated uint32 uint32_value = 13; + repeated sfixed32 sfixed32_value = 15; + repeated sfixed64 sfixed64_value = 16; + repeated sint32 sint32_value = 17; + repeated sint64 sint64_value = 18; + repeated PrimitiveValue message_value = 19; +} + +// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue +// in the text format, but the binary serializion is different. +// We test the packed representations by loading the same test cases +// using this definition instead of RepeatedPrimitiveValue. +// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue +// in every way except the packed=true declaration. +message PackedPrimitiveValue { + repeated double double_value = 1 [packed = true]; + repeated float float_value = 2 [packed = true]; + repeated int64 int64_value = 3 [packed = true]; + repeated uint64 uint64_value = 4 [packed = true]; + repeated int32 int32_value = 5 [packed = true]; + repeated fixed64 fixed64_value = 6 [packed = true]; + repeated fixed32 fixed32_value = 7 [packed = true]; + repeated bool bool_value = 8 [packed = true]; + repeated string string_value = 9; + repeated bytes bytes_value = 12; + repeated uint32 uint32_value = 13 [packed = true]; + repeated sfixed32 sfixed32_value = 15 [packed = true]; + repeated sfixed64 sfixed64_value = 16 [packed = true]; + repeated sint32 sint32_value = 17 [packed = true]; + repeated sint64 sint64_value = 18 [packed = true]; + repeated PrimitiveValue message_value = 19; +} + +message EnumValue { + enum Color { + RED = 0; + ORANGE = 1; + YELLOW = 2; + GREEN = 3; + BLUE = 4; + INDIGO = 5; + VIOLET = 6; + }; + optional Color enum_value = 14; + repeated Color repeated_enum_value = 15; +} + + +message InnerMessageValue { + optional float float_value = 2; + repeated bytes bytes_values = 8; +} + +message MiddleMessageValue { + repeated int32 int32_values = 5; + optional InnerMessageValue message_value = 11; + optional uint32 uint32_value = 13; +} + +message MessageValue { + optional double double_value = 1; + optional MiddleMessageValue message_value = 11; +} + +message RepeatedMessageValue { + message NestedMessageValue { + optional float float_value = 2; + repeated bytes bytes_values = 8; + } + + repeated NestedMessageValue message_values = 11; +} + +// Message containing fields with field numbers higher than any field above. An +// instance of this message is prepended to each binary message in the test to +// exercise the code path that handles fields encoded out of order of field +// number. +message ExtraFields { + optional string string_value = 1776; + optional bool bool_value = 1777; +} -- cgit v1.2.3