diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-07-11 08:59:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-11 09:03:40 -0700 |
commit | 50d121e2365a4edffe59741df19e0dcc96291576 (patch) | |
tree | 89c4063303cc2ea7e0151056d4c5e38a42727f45 /tensorflow/contrib/proto | |
parent | 1e2438318dd250132572a23458598f8c4c4d9ce5 (diff) |
Improvements to testing of proto decode / encode ops.
PiperOrigin-RevId: 204132868
Diffstat (limited to 'tensorflow/contrib/proto')
16 files changed, 481 insertions, 633 deletions
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD index 3e9b1a0b8d..d45622174f 100644 --- a/tensorflow/contrib/proto/BUILD +++ b/tensorflow/contrib/proto/BUILD @@ -19,9 +19,7 @@ py_library( py_library( name = "proto_pip", - data = [ - "//tensorflow/contrib/proto/python/kernel_tests:test_messages", - ] + if_static( + data = if_static( [], otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"], ), diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD index a380a131f8..3f53ef1707 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/BUILD +++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD @@ -4,33 +4,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -# Much of the work in this BUILD file actually happens in the corresponding -# build_defs.bzl, which creates an individual testcase for each example .pbtxt -# file in this directory. -# -load(":build_defs.bzl", "decode_proto_test_suite") -load(":build_defs.bzl", "encode_proto_test_suite") - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :decode_proto_op_tests. -decode_proto_test_suite( - name = "decode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :encode_proto_op_tests. -encode_proto_test_suite( - name = "encode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# Below here are tests that are not tied to an example text proto. -filegroup( - name = "test_messages", - srcs = glob(["*.pbtxt"]), -) - load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") @@ -56,16 +29,62 @@ tf_py_test( ], ) +tf_py_test( + name = "decode_proto_op_test", + size = "small", + srcs = ["decode_proto_op_test.py"], + additional_deps = [ + ":py_test_deps", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + ], + data = if_static( + [], + otherwise = [":libtestexample.so"], + ), + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], +) + +tf_py_test( + name = "encode_proto_op_test", + size = "small", + srcs = ["encode_proto_op_test.py"], + additional_deps = [ + ":py_test_deps", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", + ], + data = if_static( + [], + otherwise = [":libtestexample.so"], + ), + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], +) + py_library( - name = "test_case", - srcs = ["test_case.py"], - deps = ["//tensorflow/python:client_testlib"], + name = "test_base", + srcs = ["test_base.py"], + deps = [ + ":test_example_proto_py", + "//tensorflow/python:client_testlib", + ], ) py_library( name = "py_test_deps", deps = [ - ":test_case", + ":test_base", ":test_example_proto_py", ], ) diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl deleted file mode 100644 index f425601691..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl +++ /dev/null @@ -1,89 +0,0 @@ -"""BUILD rules for generating file-driven proto test cases. - -The decode_proto_test_suite() and encode_proto_test_suite() rules take a list -of text protos and generates a tf_py_test() for each one. -""" - -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "register_extension_info") -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") - -def _test_name(test, path): - return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0]) - -def decode_proto_test_suite(name, examples): - """Build the decode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("decode_proto", test_filename), - srcs = ["decode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "decode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("decode_proto", test_filename) - for test_filename in examples], - ) - -def encode_proto_test_suite(name, examples): - """Build the encode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("encode_proto", test_filename), - srcs = ["encode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "encode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("encode_proto", test_filename) - for test_filename in examples], - ) - -register_extension_info( - extension_name = "decode_proto_test_suite", - label_regex_map = { - "deps": "deps:decode_example_.*", - }) - -register_extension_info( - extension_name = "encode_proto_test_suite", - label_regex_map = { - "deps": "deps:encode_example_.*", - }) diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py index 5298342ee7..3b982864bc 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py @@ -21,14 +21,14 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.proto.python.kernel_tests import test_case +from tensorflow.contrib.proto.python.kernel_tests import test_base from tensorflow.contrib.proto.python.ops import decode_proto_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test -class DecodeProtoFailTest(test_case.ProtoOpTestCase): +class DecodeProtoFailTest(test_base.ProtoOpTestBase): """Test failure cases for DecodeToProto.""" def _TestCorruptProtobuf(self, sanitize): diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py index d1c13c82bc..2a07794499 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py @@ -23,24 +23,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np + from google.protobuf import text_format -from tensorflow.contrib.proto.python.kernel_tests import test_case +from tensorflow.contrib.proto.python.kernel_tests import test_base from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 from tensorflow.contrib.proto.python.ops import decode_proto_op from tensorflow.python.framework import dtypes -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - -class DecodeProtoOpTest(test_case.ProtoOpTestCase): +class DecodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase): def _compareValues(self, fd, vs, evs): """Compare lists/arrays of field values.""" @@ -203,10 +199,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase): self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields, field_dict) - def testBinary(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinary(self, case): batch = [primitive.SerializeToString() for primitive in case.primitive] self._runDecodeProtoTests( case.field, @@ -217,10 +211,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase): 'binary', sanitize=False) - def testBinaryDisordered(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinaryDisordered(self, case): batch = [primitive.SerializeToString() for primitive in case.primitive] self._runDecodeProtoTests( case.field, @@ -232,10 +224,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase): sanitize=False, force_disordered=True) - def testPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testPacked(self, case): # Now try with the packed serialization. # We test the packed representations by loading the same test cases # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. @@ -261,10 +251,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase): 'binary', sanitize=False) - def testText(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testText(self, case): # Note: float_format='.17g' is necessary to ensure preservation of # doubles and floats in text format. text_batch = [ @@ -281,10 +269,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase): 'text', sanitize=False) - def testSanitizerGood(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testSanitizerGood(self, case): batch = [primitive.SerializeToString() for primitive in case.primitive] self._runDecodeProtoTests( case.field, diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt deleted file mode 100644 index 4e31681907..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt +++ /dev/null @@ -1,94 +0,0 @@ -primitive { - # No fields specified, so we get all defaults -} -shape: 1 -sizes: 0 -field { - name: "double_default" - dtype: DT_DOUBLE - expected { double_value: 1.0 } -} -sizes: 0 -field { - name: "float_default" - dtype: DT_DOUBLE # Try casting the float field to double. - expected { double_value: 2.0 } -} -sizes: 0 -field { - name: "int64_default" - dtype: DT_INT64 - expected { int64_value: 3 } -} -sizes: 0 -field { - name: "uint64_default" - dtype: DT_INT64 - expected { int64_value: 4 } -} -sizes: 0 -field { - name: "int32_default" - dtype: DT_INT32 - expected { int32_value: 5 } -} -sizes: 0 -field { - name: "fixed64_default" - dtype: DT_INT64 - expected { int64_value: 6 } -} -sizes: 0 -field { - name: "fixed32_default" - dtype: DT_INT32 - expected { int32_value: 7 } -} -sizes: 0 -field { - name: "bool_default" - dtype: DT_BOOL - expected { bool_value: true } -} -sizes: 0 -field { - name: "string_default" - dtype: DT_STRING - expected { string_value: "a" } -} -sizes: 0 -field { - name: "bytes_default" - dtype: DT_STRING - expected { string_value: "a longer default string" } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT32 - expected { int32_value: -1 } -} -sizes: 0 -field { - name: "sfixed32_default" - dtype: DT_INT32 - expected { int32_value: 10 } -} -sizes: 0 -field { - name: "sfixed64_default" - dtype: DT_INT64 - expected { int64_value: 11 } -} -sizes: 0 -field { - name: "sint32_default" - dtype: DT_INT32 - expected { int32_value: 12 } -} -sizes: 0 -field { - name: "sint64_default" - dtype: DT_INT64 - expected { int64_value: 13 } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py index 30e58e6336..fb33660554 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py @@ -26,11 +26,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from google.protobuf import text_format -from tensorflow.contrib.proto.python.kernel_tests import test_case +from tensorflow.contrib.proto.python.kernel_tests import test_base from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 from tensorflow.contrib.proto.python.ops import decode_proto_op from tensorflow.contrib.proto.python.ops import encode_proto_op @@ -45,7 +46,7 @@ flags.DEFINE_string('message_text_file', None, 'A file containing a text serialized TestCase protobuf.') -class EncodeProtoOpTest(test_case.ProtoOpTestCase): +class EncodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase): def testBadInputs(self): # Invalid field name @@ -139,10 +140,8 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase): # loss of packing in the encoding). self.assertEqual(in_buf, out_buf) - def testRoundtrip(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtrip(self, case): in_bufs = [primitive.SerializeToString() for primitive in case.primitive] # np.array silently truncates strings if you don't specify dtype=object. @@ -150,10 +149,8 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase): return self._testRoundtrip( in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field) - def testRoundtripPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtripPacked(self, case): # Now try with the packed serialization. # We test the packed representations by loading the same test cases # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt deleted file mode 100644 index b170f89c0f..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt +++ /dev/null @@ -1,161 +0,0 @@ -primitive { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - uint64_value: 0 - uint64_value: 18446744073709551615 - int32_value: -2147483648 - int32_value: 2147483647 - fixed64_value: 0 - fixed64_value: 18446744073709551615 - fixed32_value: 0 - fixed32_value: 4294967295 - bool_value: false - bool_value: true - string_value: "" - string_value: "I refer to the infinite." - uint32_value: 0 - uint32_value: 4294967295 - sfixed32_value: -2147483648 - sfixed32_value: 2147483647 - sfixed64_value: -9223372036854775808 - sfixed64_value: 9223372036854775807 - sint32_value: -2147483648 - sint32_value: 2147483647 - sint64_value: -9223372036854775808 - sint64_value: 9223372036854775807 -} -shape: 1 -sizes: 3 -sizes: 3 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - } -} -field { - name: "float_value" - dtype: DT_FLOAT - expected { - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - } -} -field { - name: "int64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "uint64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 - } -} -field { - name: "int32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "fixed64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 # unsigned is 18446744073709551615 - } -} -field { - name: "fixed32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: false - bool_value: true - } -} -field { - name: "string_value" - dtype: DT_STRING - expected { - string_value: "" - string_value: "I refer to the infinite." - } -} -field { - name: "uint32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "sfixed32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sfixed64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "sint32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sint64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt deleted file mode 100644 index c664e52851..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt +++ /dev/null @@ -1,16 +0,0 @@ -primitive { - message_value { - double_value: 23.5 - } -} -shape: 1 -sizes: 1 -field { - name: "message_value" - dtype: DT_STRING - expected { - message_value { - double_value: 23.5 - } - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt deleted file mode 100644 index 125651d7ea..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt +++ /dev/null @@ -1,20 +0,0 @@ -primitive { - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 0 -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 0.0 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt deleted file mode 100644 index bc07efc8f3..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt +++ /dev/null @@ -1,29 +0,0 @@ -primitive { - fixed32_value: 4294967295 - uint32_value: 4294967295 -} -shape: 1 -sizes: 1 -field { - name: "fixed32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 1 -field { - name: "uint32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT64 - expected { - int64_value: 4294967295 # Comes from an explicitly-specified default - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt deleted file mode 100644 index 61c7ac53f7..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt +++ /dev/null @@ -1,32 +0,0 @@ -primitive { - double_value: 23.5 - double_value: 123.0 - bool_value: true -} -primitive { - double_value: 3.1 - bool_value: false -} -shape: 2 -sizes: 2 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 123.0 - double_value: 3.1 - double_value: 0.0 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt deleted file mode 100644 index f4828076d5..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt +++ /dev/null @@ -1,62 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -primitive { - double_value: 44.0 - bool_value: false -} -primitive { - double_value: 3.14159 - bool_value: true -} -primitive { - double_value: 1.414 - bool_value: true -} -primitive { - double_value: -32.2 - bool_value: false -} -primitive { - double_value: 0.0001 - bool_value: true -} -shape: 3 -shape: 2 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 44.0 - double_value: 3.14159 - double_value: 1.414 - double_value: -32.2 - double_value: 0.0001 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - bool_value: true - bool_value: true - bool_value: false - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt deleted file mode 100644 index dc20ac147b..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt +++ /dev/null @@ -1,21 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_base.py b/tensorflow/contrib/proto/python/kernel_tests/test_base.py new file mode 100644 index 0000000000..1fc8c16786 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/test_base.py @@ -0,0 +1,407 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Test case base for testing proto operations.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ctypes as ct +import os + +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import test + + +class ProtoOpTestBase(test.TestCase): + """Base class for testing proto decoding and encoding ops.""" + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(ProtoOpTestBase, self).__init__(methodName) + lib = os.path.join(os.path.dirname(__file__), "libtestexample.so") + if os.path.isfile(lib): + ct.cdll.LoadLibrary(lib) + + @staticmethod + def named_parameters(): + return ( + ("defaults", ProtoOpTestBase.defaults_test_case()), + ("minmax", ProtoOpTestBase.minmax_test_case()), + ("nested", ProtoOpTestBase.nested_test_case()), + ("optional", ProtoOpTestBase.optional_test_case()), + ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()), + ("ragged", ProtoOpTestBase.ragged_test_case()), + ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), + ("simple", ProtoOpTestBase.simple_test_case()), + ) + + @staticmethod + def defaults_test_case(): + test_case = test_example_pb2.TestCase() + test_case.primitive.add() # No fields specified, so we get all defaults. + test_case.shape.append(1) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "double_default" + field.dtype = types_pb2.DT_DOUBLE + field.expected.double_value.append(1.0) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "float_default" + field.dtype = types_pb2.DT_FLOAT + field.expected.float_value.append(2.0) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "int64_default" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(3) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "sfixed64_default" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(11) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "sint64_default" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(13) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "uint64_default" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(4) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "fixed64_default" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(6) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "int32_default" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(5) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "sfixed32_default" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(10) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "sint32_default" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(12) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "uint32_default" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(-1) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "fixed32_default" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(7) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "bool_default" + field.dtype = types_pb2.DT_BOOL + field.expected.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "string_default" + field.dtype = types_pb2.DT_STRING + field.expected.string_value.append("a") + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "bytes_default" + field.dtype = types_pb2.DT_STRING + field.expected.string_value.append("a longer default string") + return test_case + + @staticmethod + def minmax_test_case(): + test_case = test_example_pb2.TestCase() + primitive = test_case.primitive.add() + primitive.double_value.append(-1.7976931348623158e+308) + primitive.double_value.append(2.2250738585072014e-308) + primitive.double_value.append(1.7976931348623158e+308) + primitive.float_value.append(-3.402823466e+38) + primitive.float_value.append(1.175494351e-38) + primitive.float_value.append(3.402823466e+38) + primitive.int64_value.append(-9223372036854775808) + primitive.int64_value.append(9223372036854775807) + primitive.sfixed64_value.append(-9223372036854775808) + primitive.sfixed64_value.append(9223372036854775807) + primitive.sint64_value.append(-9223372036854775808) + primitive.sint64_value.append(9223372036854775807) + primitive.uint64_value.append(0) + primitive.uint64_value.append(18446744073709551615) + primitive.fixed64_value.append(0) + primitive.fixed64_value.append(18446744073709551615) + primitive.int32_value.append(-2147483648) + primitive.int32_value.append(2147483647) + primitive.sfixed32_value.append(-2147483648) + primitive.sfixed32_value.append(2147483647) + primitive.sint32_value.append(-2147483648) + primitive.sint32_value.append(2147483647) + primitive.uint32_value.append(0) + primitive.uint32_value.append(4294967295) + primitive.fixed32_value.append(0) + primitive.fixed32_value.append(4294967295) + primitive.bool_value.append(False) + primitive.bool_value.append(True) + primitive.string_value.append("") + primitive.string_value.append("I refer to the infinite.") + test_case.shape.append(1) + test_case.sizes.append(3) + field = test_case.field.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.expected.double_value.append(-1.7976931348623158e+308) + field.expected.double_value.append(2.2250738585072014e-308) + field.expected.double_value.append(1.7976931348623158e+308) + test_case.sizes.append(3) + field = test_case.field.add() + field.name = "float_value" + field.dtype = types_pb2.DT_FLOAT + field.expected.float_value.append(-3.402823466e+38) + field.expected.float_value.append(1.175494351e-38) + field.expected.float_value.append(3.402823466e+38) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "int64_value" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(-9223372036854775808) + field.expected.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "sfixed64_value" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(-9223372036854775808) + field.expected.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "sint64_value" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(-9223372036854775808) + field.expected.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "uint64_value" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(0) + field.expected.int64_value.append(-1) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "fixed64_value" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(0) + field.expected.int64_value.append(-1) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "int32_value" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(-2147483648) + field.expected.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "sfixed32_value" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(-2147483648) + field.expected.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "sint32_value" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(-2147483648) + field.expected.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(0) + field.expected.int32_value.append(-1) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_INT32 + field.expected.int32_value.append(0) + field.expected.int32_value.append(-1) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.expected.bool_value.append(False) + field.expected.bool_value.append(True) + test_case.sizes.append(2) + field = test_case.field.add() + field.name = "string_value" + field.dtype = types_pb2.DT_STRING + field.expected.string_value.append("") + field.expected.string_value.append("I refer to the infinite.") + return test_case + + @staticmethod + def nested_test_case(): + test_case = test_example_pb2.TestCase() + primitive = test_case.primitive.add() + message_value = primitive.message_value.add() + message_value.double_value = 23.5 + test_case.shape.append(1) + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "message_value" + field.dtype = types_pb2.DT_STRING + message_value = field.expected.message_value.add() + message_value.double_value = 23.5 + return test_case + + @staticmethod + def optional_test_case(): + test_case = test_example_pb2.TestCase() + primitive = test_case.primitive.add() + primitive.bool_value.append(True) + test_case.shape.append(1) + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.expected.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.expected.double_value.append(0.0) + return test_case + + @staticmethod + def promote_unsigned_test_case(): + test_case = test_example_pb2.TestCase() + primitive = test_case.primitive.add() + primitive.fixed32_value.append(4294967295) + primitive.uint32_value.append(4294967295) + test_case.shape.append(1) + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(4294967295) + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(4294967295) + # Comes from an explicitly-specified default + test_case.sizes.append(0) + field = test_case.field.add() + field.name = "uint32_default" + field.dtype = types_pb2.DT_INT64 + field.expected.int64_value.append(4294967295) + return test_case + + @staticmethod + def ragged_test_case(): + test_case = test_example_pb2.TestCase() + primitive = test_case.primitive.add() + primitive.double_value.append(23.5) + primitive.double_value.append(123.0) + primitive.bool_value.append(True) + primitive = test_case.primitive.add() + primitive.double_value.append(3.1) + primitive.bool_value.append(False) + test_case.shape.append(2) + test_case.sizes.append(2) + test_case.sizes.append(1) + test_case.sizes.append(1) + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.expected.double_value.append(23.5) + field.expected.double_value.append(123.0) + field.expected.double_value.append(3.1) + field.expected.double_value.append(0.0) + field = test_case.field.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.expected.bool_value.append(True) + field.expected.bool_value.append(False) + return test_case + + @staticmethod + def shaped_batch_test_case(): + test_case = test_example_pb2.TestCase() + primitive = test_case.primitive.add() + primitive.double_value.append(23.5) + primitive.bool_value.append(True) + primitive = test_case.primitive.add() + primitive.double_value.append(44.0) + primitive.bool_value.append(False) + primitive = test_case.primitive.add() + primitive.double_value.append(3.14159) + primitive.bool_value.append(True) + primitive = test_case.primitive.add() + primitive.double_value.append(1.414) + primitive.bool_value.append(True) + primitive = test_case.primitive.add() + primitive.double_value.append(-32.2) + primitive.bool_value.append(False) + primitive = test_case.primitive.add() + primitive.double_value.append(0.0001) + primitive.bool_value.append(True) + test_case.shape.append(3) + test_case.shape.append(2) + for _ in range(12): + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.expected.double_value.append(23.5) + field.expected.double_value.append(44.0) + field.expected.double_value.append(3.14159) + field.expected.double_value.append(1.414) + field.expected.double_value.append(-32.2) + field.expected.double_value.append(0.0001) + field = test_case.field.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.expected.bool_value.append(True) + field.expected.bool_value.append(False) + field.expected.bool_value.append(True) + field.expected.bool_value.append(True) + field.expected.bool_value.append(False) + field.expected.bool_value.append(True) + return test_case + + @staticmethod + def simple_test_case(): + test_case = test_example_pb2.TestCase() + primitive = test_case.primitive.add() + primitive.double_value.append(23.5) + primitive.bool_value.append(True) + test_case.shape.append(1) + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.expected.double_value.append(23.5) + test_case.sizes.append(1) + field = test_case.field.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.expected.bool_value.append(True) + return test_case diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/test_case.py deleted file mode 100644 index b95202c5df..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py +++ /dev/null @@ -1,35 +0,0 @@ -# ============================================================================= -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= -"""Test case base for testing proto operations.""" - -# Python3 preparedness imports. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ctypes as ct -import os - -from tensorflow.python.platform import test - - -class ProtoOpTestCase(test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super(ProtoOpTestCase, self).__init__(methodName) - lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so') - if os.path.isfile(lib): - ct.cdll.LoadLibrary(lib) |