aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/proto
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-07-11 08:59:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 09:03:40 -0700
commit50d121e2365a4edffe59741df19e0dcc96291576 (patch)
tree89c4063303cc2ea7e0151056d4c5e38a42727f45 /tensorflow/contrib/proto
parent1e2438318dd250132572a23458598f8c4c4d9ce5 (diff)
Improvements to testing of proto decode / encode ops.
PiperOrigin-RevId: 204132868
Diffstat (limited to 'tensorflow/contrib/proto')
-rw-r--r--tensorflow/contrib/proto/BUILD4
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/BUILD81
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl89
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py4
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py42
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt94
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py17
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt161
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt16
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt20
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt29
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt32
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt62
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt21
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/test_base.py407
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/test_case.py35
16 files changed, 481 insertions, 633 deletions
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
index 3e9b1a0b8d..d45622174f 100644
--- a/tensorflow/contrib/proto/BUILD
+++ b/tensorflow/contrib/proto/BUILD
@@ -19,9 +19,7 @@ py_library(
py_library(
name = "proto_pip",
- data = [
- "//tensorflow/contrib/proto/python/kernel_tests:test_messages",
- ] + if_static(
+ data = if_static(
[],
otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"],
),
diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD
index a380a131f8..3f53ef1707 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD
@@ -4,33 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-# Much of the work in this BUILD file actually happens in the corresponding
-# build_defs.bzl, which creates an individual testcase for each example .pbtxt
-# file in this directory.
-#
-load(":build_defs.bzl", "decode_proto_test_suite")
-load(":build_defs.bzl", "encode_proto_test_suite")
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :decode_proto_op_tests.
-decode_proto_test_suite(
- name = "decode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :encode_proto_op_tests.
-encode_proto_test_suite(
- name = "encode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# Below here are tests that are not tied to an example text proto.
-filegroup(
- name = "test_messages",
- srcs = glob(["*.pbtxt"]),
-)
-
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
@@ -56,16 +29,62 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "decode_proto_op_test",
+ size = "small",
+ srcs = ["decode_proto_op_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
+
+tf_py_test(
+ name = "encode_proto_op_test",
+ size = "small",
+ srcs = ["encode_proto_op_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
+
py_library(
- name = "test_case",
- srcs = ["test_case.py"],
- deps = ["//tensorflow/python:client_testlib"],
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
+ ":test_example_proto_py",
+ "//tensorflow/python:client_testlib",
+ ],
)
py_library(
name = "py_test_deps",
deps = [
- ":test_case",
+ ":test_base",
":test_example_proto_py",
],
)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
deleted file mode 100644
index f425601691..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
+++ /dev/null
@@ -1,89 +0,0 @@
-"""BUILD rules for generating file-driven proto test cases.
-
-The decode_proto_test_suite() and encode_proto_test_suite() rules take a list
-of text protos and generates a tf_py_test() for each one.
-"""
-
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "register_extension_info")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
-
-def _test_name(test, path):
- return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
-
-def decode_proto_test_suite(name, examples):
- """Build the decode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("decode_proto", test_filename),
- srcs = ["decode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "decode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("decode_proto", test_filename)
- for test_filename in examples],
- )
-
-def encode_proto_test_suite(name, examples):
- """Build the encode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("encode_proto", test_filename),
- srcs = ["encode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "encode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("encode_proto", test_filename)
- for test_filename in examples],
- )
-
-register_extension_info(
- extension_name = "decode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:decode_example_.*",
- })
-
-register_extension_info(
- extension_name = "encode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:encode_example_.*",
- })
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
index 5298342ee7..3b982864bc 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
@@ -21,14 +21,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_base
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class DecodeProtoFailTest(test_case.ProtoOpTestCase):
+class DecodeProtoFailTest(test_base.ProtoOpTestBase):
"""Test failure cases for DecodeToProto."""
def _TestCorruptProtobuf(self, sanitize):
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
index d1c13c82bc..2a07794499 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
@@ -23,24 +23,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
+
from google.protobuf import text_format
-from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_base
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import flags
from tensorflow.python.platform import test
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('message_text_file', None,
- 'A file containing a text serialized TestCase protobuf.')
-
-class DecodeProtoOpTest(test_case.ProtoOpTestCase):
+class DecodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase):
def _compareValues(self, fd, vs, evs):
"""Compare lists/arrays of field values."""
@@ -203,10 +199,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
field_dict)
- def testBinary(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinary(self, case):
batch = [primitive.SerializeToString() for primitive in case.primitive]
self._runDecodeProtoTests(
case.field,
@@ -217,10 +211,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
'binary',
sanitize=False)
- def testBinaryDisordered(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinaryDisordered(self, case):
batch = [primitive.SerializeToString() for primitive in case.primitive]
self._runDecodeProtoTests(
case.field,
@@ -232,10 +224,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
sanitize=False,
force_disordered=True)
- def testPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testPacked(self, case):
# Now try with the packed serialization.
# We test the packed representations by loading the same test cases
# using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
@@ -261,10 +251,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
'binary',
sanitize=False)
- def testText(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testText(self, case):
# Note: float_format='.17g' is necessary to ensure preservation of
# doubles and floats in text format.
text_batch = [
@@ -281,10 +269,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
'text',
sanitize=False)
- def testSanitizerGood(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testSanitizerGood(self, case):
batch = [primitive.SerializeToString() for primitive in case.primitive]
self._runDecodeProtoTests(
case.field,
diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
deleted file mode 100644
index 4e31681907..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
+++ /dev/null
@@ -1,94 +0,0 @@
-primitive {
- # No fields specified, so we get all defaults
-}
-shape: 1
-sizes: 0
-field {
- name: "double_default"
- dtype: DT_DOUBLE
- expected { double_value: 1.0 }
-}
-sizes: 0
-field {
- name: "float_default"
- dtype: DT_DOUBLE # Try casting the float field to double.
- expected { double_value: 2.0 }
-}
-sizes: 0
-field {
- name: "int64_default"
- dtype: DT_INT64
- expected { int64_value: 3 }
-}
-sizes: 0
-field {
- name: "uint64_default"
- dtype: DT_INT64
- expected { int64_value: 4 }
-}
-sizes: 0
-field {
- name: "int32_default"
- dtype: DT_INT32
- expected { int32_value: 5 }
-}
-sizes: 0
-field {
- name: "fixed64_default"
- dtype: DT_INT64
- expected { int64_value: 6 }
-}
-sizes: 0
-field {
- name: "fixed32_default"
- dtype: DT_INT32
- expected { int32_value: 7 }
-}
-sizes: 0
-field {
- name: "bool_default"
- dtype: DT_BOOL
- expected { bool_value: true }
-}
-sizes: 0
-field {
- name: "string_default"
- dtype: DT_STRING
- expected { string_value: "a" }
-}
-sizes: 0
-field {
- name: "bytes_default"
- dtype: DT_STRING
- expected { string_value: "a longer default string" }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT32
- expected { int32_value: -1 }
-}
-sizes: 0
-field {
- name: "sfixed32_default"
- dtype: DT_INT32
- expected { int32_value: 10 }
-}
-sizes: 0
-field {
- name: "sfixed64_default"
- dtype: DT_INT64
- expected { int64_value: 11 }
-}
-sizes: 0
-field {
- name: "sint32_default"
- dtype: DT_INT32
- expected { int32_value: 12 }
-}
-sizes: 0
-field {
- name: "sint64_default"
- dtype: DT_INT64
- expected { int64_value: 13 }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
index 30e58e6336..fb33660554 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
@@ -26,11 +26,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from google.protobuf import text_format
-from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_base
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.contrib.proto.python.ops import encode_proto_op
@@ -45,7 +46,7 @@ flags.DEFINE_string('message_text_file', None,
'A file containing a text serialized TestCase protobuf.')
-class EncodeProtoOpTest(test_case.ProtoOpTestCase):
+class EncodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase):
def testBadInputs(self):
# Invalid field name
@@ -139,10 +140,8 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
# loss of packing in the encoding).
self.assertEqual(in_buf, out_buf)
- def testRoundtrip(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtrip(self, case):
in_bufs = [primitive.SerializeToString() for primitive in case.primitive]
# np.array silently truncates strings if you don't specify dtype=object.
@@ -150,10 +149,8 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
return self._testRoundtrip(
in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field)
- def testRoundtripPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtripPacked(self, case):
# Now try with the packed serialization.
# We test the packed representations by loading the same test cases
# using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
deleted file mode 100644
index b170f89c0f..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
+++ /dev/null
@@ -1,161 +0,0 @@
-primitive {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- uint64_value: 0
- uint64_value: 18446744073709551615
- int32_value: -2147483648
- int32_value: 2147483647
- fixed64_value: 0
- fixed64_value: 18446744073709551615
- fixed32_value: 0
- fixed32_value: 4294967295
- bool_value: false
- bool_value: true
- string_value: ""
- string_value: "I refer to the infinite."
- uint32_value: 0
- uint32_value: 4294967295
- sfixed32_value: -2147483648
- sfixed32_value: 2147483647
- sfixed64_value: -9223372036854775808
- sfixed64_value: 9223372036854775807
- sint32_value: -2147483648
- sint32_value: 2147483647
- sint64_value: -9223372036854775808
- sint64_value: 9223372036854775807
-}
-shape: 1
-sizes: 3
-sizes: 3
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- }
-}
-field {
- name: "float_value"
- dtype: DT_FLOAT
- expected {
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- }
-}
-field {
- name: "int64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "uint64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1
- }
-}
-field {
- name: "int32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "fixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1 # unsigned is 18446744073709551615
- }
-}
-field {
- name: "fixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: false
- bool_value: true
- }
-}
-field {
- name: "string_value"
- dtype: DT_STRING
- expected {
- string_value: ""
- string_value: "I refer to the infinite."
- }
-}
-field {
- name: "uint32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "sfixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sfixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "sint32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sint64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
deleted file mode 100644
index c664e52851..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
+++ /dev/null
@@ -1,16 +0,0 @@
-primitive {
- message_value {
- double_value: 23.5
- }
-}
-shape: 1
-sizes: 1
-field {
- name: "message_value"
- dtype: DT_STRING
- expected {
- message_value {
- double_value: 23.5
- }
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
deleted file mode 100644
index 125651d7ea..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
+++ /dev/null
@@ -1,20 +0,0 @@
-primitive {
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 0
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 0.0
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
deleted file mode 100644
index bc07efc8f3..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
+++ /dev/null
@@ -1,29 +0,0 @@
-primitive {
- fixed32_value: 4294967295
- uint32_value: 4294967295
-}
-shape: 1
-sizes: 1
-field {
- name: "fixed32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 1
-field {
- name: "uint32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295 # Comes from an explicitly-specified default
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
deleted file mode 100644
index 61c7ac53f7..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
+++ /dev/null
@@ -1,32 +0,0 @@
-primitive {
- double_value: 23.5
- double_value: 123.0
- bool_value: true
-}
-primitive {
- double_value: 3.1
- bool_value: false
-}
-shape: 2
-sizes: 2
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 123.0
- double_value: 3.1
- double_value: 0.0
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
deleted file mode 100644
index f4828076d5..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
+++ /dev/null
@@ -1,62 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-primitive {
- double_value: 44.0
- bool_value: false
-}
-primitive {
- double_value: 3.14159
- bool_value: true
-}
-primitive {
- double_value: 1.414
- bool_value: true
-}
-primitive {
- double_value: -32.2
- bool_value: false
-}
-primitive {
- double_value: 0.0001
- bool_value: true
-}
-shape: 3
-shape: 2
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 44.0
- double_value: 3.14159
- double_value: 1.414
- double_value: -32.2
- double_value: 0.0001
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- bool_value: true
- bool_value: true
- bool_value: false
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
deleted file mode 100644
index dc20ac147b..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
+++ /dev/null
@@ -1,21 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_base.py b/tensorflow/contrib/proto/python/kernel_tests/test_base.py
new file mode 100644
index 0000000000..1fc8c16786
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_base.py
@@ -0,0 +1,407 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Test case base for testing proto operations."""
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.platform import test
+
+
+class ProtoOpTestBase(test.TestCase):
+ """Base class for testing proto decoding and encoding ops."""
+
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(ProtoOpTestBase, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), "libtestexample.so")
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
+
+ @staticmethod
+ def named_parameters():
+ return (
+ ("defaults", ProtoOpTestBase.defaults_test_case()),
+ ("minmax", ProtoOpTestBase.minmax_test_case()),
+ ("nested", ProtoOpTestBase.nested_test_case()),
+ ("optional", ProtoOpTestBase.optional_test_case()),
+ ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()),
+ ("ragged", ProtoOpTestBase.ragged_test_case()),
+ ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
+ ("simple", ProtoOpTestBase.simple_test_case()),
+ )
+
+ @staticmethod
+ def defaults_test_case():
+ test_case = test_example_pb2.TestCase()
+ test_case.primitive.add() # No fields specified, so we get all defaults.
+ test_case.shape.append(1)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "double_default"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(1.0)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "float_default"
+ field.dtype = types_pb2.DT_FLOAT
+ field.expected.float_value.append(2.0)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "int64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(3)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sfixed64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(11)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sint64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(13)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "uint64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "fixed64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(6)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "int32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(5)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sfixed32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(10)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sint32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(12)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "uint32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-1)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "fixed32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(7)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "bool_default"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "string_default"
+ field.dtype = types_pb2.DT_STRING
+ field.expected.string_value.append("a")
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "bytes_default"
+ field.dtype = types_pb2.DT_STRING
+ field.expected.string_value.append("a longer default string")
+ return test_case
+
+ @staticmethod
+ def minmax_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(-1.7976931348623158e+308)
+ primitive.double_value.append(2.2250738585072014e-308)
+ primitive.double_value.append(1.7976931348623158e+308)
+ primitive.float_value.append(-3.402823466e+38)
+ primitive.float_value.append(1.175494351e-38)
+ primitive.float_value.append(3.402823466e+38)
+ primitive.int64_value.append(-9223372036854775808)
+ primitive.int64_value.append(9223372036854775807)
+ primitive.sfixed64_value.append(-9223372036854775808)
+ primitive.sfixed64_value.append(9223372036854775807)
+ primitive.sint64_value.append(-9223372036854775808)
+ primitive.sint64_value.append(9223372036854775807)
+ primitive.uint64_value.append(0)
+ primitive.uint64_value.append(18446744073709551615)
+ primitive.fixed64_value.append(0)
+ primitive.fixed64_value.append(18446744073709551615)
+ primitive.int32_value.append(-2147483648)
+ primitive.int32_value.append(2147483647)
+ primitive.sfixed32_value.append(-2147483648)
+ primitive.sfixed32_value.append(2147483647)
+ primitive.sint32_value.append(-2147483648)
+ primitive.sint32_value.append(2147483647)
+ primitive.uint32_value.append(0)
+ primitive.uint32_value.append(4294967295)
+ primitive.fixed32_value.append(0)
+ primitive.fixed32_value.append(4294967295)
+ primitive.bool_value.append(False)
+ primitive.bool_value.append(True)
+ primitive.string_value.append("")
+ primitive.string_value.append("I refer to the infinite.")
+ test_case.shape.append(1)
+ test_case.sizes.append(3)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(-1.7976931348623158e+308)
+ field.expected.double_value.append(2.2250738585072014e-308)
+ field.expected.double_value.append(1.7976931348623158e+308)
+ test_case.sizes.append(3)
+ field = test_case.field.add()
+ field.name = "float_value"
+ field.dtype = types_pb2.DT_FLOAT
+ field.expected.float_value.append(-3.402823466e+38)
+ field.expected.float_value.append(1.175494351e-38)
+ field.expected.float_value.append(3.402823466e+38)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "int64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(-9223372036854775808)
+ field.expected.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sfixed64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(-9223372036854775808)
+ field.expected.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sint64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(-9223372036854775808)
+ field.expected.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "uint64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(0)
+ field.expected.int64_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "fixed64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(0)
+ field.expected.int64_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "int32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-2147483648)
+ field.expected.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sfixed32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-2147483648)
+ field.expected.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sint32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-2147483648)
+ field.expected.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(0)
+ field.expected.int32_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(0)
+ field.expected.int32_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(False)
+ field.expected.bool_value.append(True)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "string_value"
+ field.dtype = types_pb2.DT_STRING
+ field.expected.string_value.append("")
+ field.expected.string_value.append("I refer to the infinite.")
+ return test_case
+
+ @staticmethod
+ def nested_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ message_value = primitive.message_value.add()
+ message_value.double_value = 23.5
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "message_value"
+ field.dtype = types_pb2.DT_STRING
+ message_value = field.expected.message_value.add()
+ message_value.double_value = 23.5
+ return test_case
+
+ @staticmethod
+ def optional_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.bool_value.append(True)
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(0.0)
+ return test_case
+
+ @staticmethod
+ def promote_unsigned_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.fixed32_value.append(4294967295)
+ primitive.uint32_value.append(4294967295)
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4294967295)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4294967295)
+ # Comes from an explicitly-specified default
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "uint32_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4294967295)
+ return test_case
+
+ @staticmethod
+ def ragged_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(23.5)
+ primitive.double_value.append(123.0)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(3.1)
+ primitive.bool_value.append(False)
+ test_case.shape.append(2)
+ test_case.sizes.append(2)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(23.5)
+ field.expected.double_value.append(123.0)
+ field.expected.double_value.append(3.1)
+ field.expected.double_value.append(0.0)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(False)
+ return test_case
+
+ @staticmethod
+ def shaped_batch_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(23.5)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(44.0)
+ primitive.bool_value.append(False)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(3.14159)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(1.414)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(-32.2)
+ primitive.bool_value.append(False)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(0.0001)
+ primitive.bool_value.append(True)
+ test_case.shape.append(3)
+ test_case.shape.append(2)
+ for _ in range(12):
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(23.5)
+ field.expected.double_value.append(44.0)
+ field.expected.double_value.append(3.14159)
+ field.expected.double_value.append(1.414)
+ field.expected.double_value.append(-32.2)
+ field.expected.double_value.append(0.0001)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(False)
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(False)
+ field.expected.bool_value.append(True)
+ return test_case
+
+ @staticmethod
+ def simple_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(23.5)
+ primitive.bool_value.append(True)
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(23.5)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ return test_case
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/test_case.py
deleted file mode 100644
index b95202c5df..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# =============================================================================
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-"""Test case base for testing proto operations."""
-
-# Python3 preparedness imports.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import ctypes as ct
-import os
-
-from tensorflow.python.platform import test
-
-
-class ProtoOpTestCase(test.TestCase):
-
- def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
- super(ProtoOpTestCase, self).__init__(methodName)
- lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
- if os.path.isfile(lib):
- ct.cdll.LoadLibrary(lib)