about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/kernel_tests/parsing_ops_test.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2016-08-04 16:32:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-04 17:46:33 -0700
commitd2589654ddc7bffc2f51eb5bebba2eb78f48a9a2 (patch)
treede82106f989305995b2d851a983a87bb7050e794 /tensorflow/python/kernel_tests/parsing_ops_test.py
parent6982c512b6491ff1424e67f16bc1ae68432c11ce (diff)
Fix tf.Example parsing when the Example contains the feature name (key) but no
value for it. In this case, if no default is given, the parser should report the correct "Feature: <name> is required" error. If a default *is* provided, it should be used. Also, reformatted the parsing_ops_test.py file (this doesn't lose meaningful history, since the file hasn't changed much since my original version). Change: 129393762
Diffstat (limited to 'tensorflow/python/kernel_tests/parsing_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py729
1 files changed, 421 insertions, 308 deletions
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 58f6da9f97..52d3c0dde1 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for tensorflow.ops.parsing_ops."""
from __future__ import absolute_import
@@ -46,13 +45,13 @@ def flatten(list_of_lists):
def flatten_values_tensors_or_sparse(tensors_list):
"""Flatten each SparseTensor object into 3 Tensors for session.run()."""
- return list(flatten([[v.indices, v.values, v.shape]
- if isinstance(v, tf.SparseTensor) else [v]
- for v in tensors_list]))
+ return list(
+ flatten([[v.indices, v.values, v.shape] if isinstance(v, tf.SparseTensor)
+ else [v] for v in tensors_list]))
-def _compare_output_to_expected(
- tester, dict_tensors, expected_tensors, flat_output):
+def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
+ flat_output):
tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
i = 0 # Index into the flattened output of session.run()
@@ -74,11 +73,11 @@ def _compare_output_to_expected(
class ParseExampleTest(tf.test.TestCase):
- def _test(
- self, kwargs, expected_values=None, expected_err=None):
+ def _test(self, kwargs, expected_values=None, expected_err=None):
with self.test_session() as sess:
if expected_err:
- with self.assertRaisesRegexp(expected_err[0], expected_err[1]):
+ with self.assertRaisesWithPredicateMatch(
+ expected_err[0], expected_err[1]):
out = tf.parse_example(**kwargs)
sess.run(flatten_values_tensors_or_sparse(out.values()))
else:
@@ -92,9 +91,8 @@ class ParseExampleTest(tf.test.TestCase):
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
serialized = kwargs["serialized"]
- batch_size = (
- serialized.eval().size if isinstance(serialized, tf.Tensor)
- else np.asarray(serialized).size)
+ batch_size = (serialized.eval().size if isinstance(serialized, tf.Tensor)
+ else np.asarray(serialized).size)
for k, f in kwargs["features"].items():
if isinstance(f, tf.FixedLenFeature) and f.shape is not None:
self.assertEqual(
@@ -115,9 +113,12 @@ class ParseExampleTest(tf.test.TestCase):
c_default = np.random.rand(2).astype(np.float32)
expected_st_a = ( # indices, values, shape
- np.empty((0, 2), dtype=np.int64), # indices
- np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
- np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
expected_output = {
sparse_name: expected_st_a,
@@ -126,38 +127,63 @@ class ParseExampleTest(tf.test.TestCase):
c_name: np.array(2 * [c_default]),
}
- self._test({
- "example_names": np.empty((0,), dtype=bytes),
- "serialized": tf.convert_to_tensor(["", ""]),
- "features": {
- sparse_name: tf.VarLenFeature(tf.int64),
- a_name: tf.FixedLenFeature((1, 3), tf.int64, default_value=a_default),
- b_name: tf.FixedLenFeature((3, 3), tf.string, default_value=b_default),
- c_name: tf.FixedLenFeature((2,), tf.float32, default_value=c_default),
- }
- }, expected_output)
+ self._test(
+ {
+ "example_names": np.empty(
+ (0,), dtype=bytes),
+ "serialized": tf.convert_to_tensor(["", ""]),
+ "features": {
+ sparse_name: tf.VarLenFeature(tf.int64),
+ a_name: tf.FixedLenFeature(
+ (1, 3), tf.int64, default_value=a_default),
+ b_name: tf.FixedLenFeature(
+ (3, 3), tf.string, default_value=b_default),
+ c_name: tf.FixedLenFeature(
+ (2,), tf.float32, default_value=c_default),
+ }
+ },
+ expected_output)
def testEmptySerializedWithoutDefaultsShouldFail(self):
- self._test({
- "example_names": ["in1", "in2"],
- "serialized": ["", ""],
- "features": {
- "st_a": tf.VarLenFeature(tf.int64),
- "a": tf.FixedLenFeature((1, 3), tf.int64, default_value=[0, 42, 0]),
- "b": tf.FixedLenFeature(
- (3, 3), tf.string,
- default_value=np.random.rand(3, 3).astype(bytes)),
- # Feature "c" is missing a default, this gap will cause failure.
- "c": tf.FixedLenFeature((2,), dtype=tf.float32),
- }
- }, expected_err=(tf.OpError, "Name: in1, Feature: c is required"))
+ input_features = {
+ "st_a": tf.VarLenFeature(tf.int64),
+ "a": tf.FixedLenFeature(
+ (1, 3), tf.int64, default_value=[0, 42, 0]),
+ "b": tf.FixedLenFeature(
+ (3, 3),
+ tf.string,
+ default_value=np.random.rand(3, 3).astype(bytes)),
+ # Feature "c" is missing a default, this gap will cause failure.
+ "c": tf.FixedLenFeature(
+ (2,), dtype=tf.float32),
+ }
+
+ # Edge case where the key is there but the feature value is empty
+ original = example(features=features({
+ "c": feature()
+ }))
+ self._test(
+ {
+ "example_names": ["in1"],
+ "serialized": [original.SerializeToString()],
+ "features": input_features,
+ },
+ expected_err=(tf.OpError, "Name: in1, Feature: c is required"))
+
+ # Standard case of missing key and value.
+ self._test(
+ {
+ "example_names": ["in1", "in2"],
+ "serialized": ["", ""],
+ "features": input_features,
+ },
+ expected_err=(tf.OpError, "Name: in1, Feature: c is required"))
def testDenseNotMatchingShapeShouldFail(self):
original = [
example(features=features({
"a": float_feature([1, 1, 3]),
- })),
- example(features=features({
+ })), example(features=features({
"a": float_feature([-1, -1]),
}))
]
@@ -165,27 +191,27 @@ class ParseExampleTest(tf.test.TestCase):
names = ["passing", "failing"]
serialized = [m.SerializeToString() for m in original]
- self._test({
- "example_names": names,
- "serialized": tf.convert_to_tensor(serialized),
- "features": {"a": tf.FixedLenFeature((1, 3), tf.float32)}
- }, expected_err=(
- tf.OpError, "Name: failing, Key: a, Index: 1. Number of float val"))
+ self._test(
+ {
+ "example_names": names,
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {"a": tf.FixedLenFeature((1, 3), tf.float32)}
+ },
+ expected_err=(tf.OpError,
+ "Name: failing, Key: a, Index: 1. Number of float val"))
def testDenseDefaultNoShapeShouldFail(self):
- original = [
- example(features=features({
- "a": float_feature([1, 1, 3]),
- })),
- ]
+ original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
serialized = [m.SerializeToString() for m in original]
- self._test({
- "example_names": ["failing"],
- "serialized": tf.convert_to_tensor(serialized),
- "features": {"a": tf.FixedLenFeature(None, tf.float32)}
- }, expected_err=(ValueError, "Missing shape for feature a"))
+ self._test(
+ {
+ "example_names": ["failing"],
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {"a": tf.FixedLenFeature(None, tf.float32)}
+ },
+ expected_err=(ValueError, "Missing shape for feature a"))
def testSerializedContainingSparse(self):
original = [
@@ -207,14 +233,16 @@ class ParseExampleTest(tf.test.TestCase):
serialized = [m.SerializeToString() for m in original]
expected_st_c = ( # indices, values, shape
- np.array([[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64),
- np.array([3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32),
- np.array([4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
+ np.array(
+ [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
+ [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
+ [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
expected_st_d = ( # indices, values, shape
- np.array([[3, 0]], dtype=np.int64),
- np.array(["hi"], dtype=bytes),
- np.array([4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
+ np.array(
+ [[3, 0]], dtype=np.int64), np.array(
+ ["hi"], dtype=bytes), np.array(
+ [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
expected_output = {
"st_c": expected_st_c,
@@ -236,8 +264,7 @@ class ParseExampleTest(tf.test.TestCase):
example(features=features({
aname: float_feature([1, 1]),
bname: bytes_feature([b"b0_str"]),
- })),
- example(features=features({
+ })), example(features=features({
aname: float_feature([-1, -1]),
bname: bytes_feature([b"b1"]),
}))
@@ -248,24 +275,28 @@ class ParseExampleTest(tf.test.TestCase):
expected_output = {
aname: np.array(
[[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
- bname: np.array(["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ bname: np.array(
+ ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
}
# No defaults, values required
- self._test({
- "serialized": tf.convert_to_tensor(serialized),
- "features": {
- aname: tf.FixedLenFeature((1, 2, 1), dtype=tf.float32),
- bname: tf.FixedLenFeature((1, 1, 1, 1), dtype=tf.string),
- }
- }, expected_output)
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ aname: tf.FixedLenFeature(
+ (1, 2, 1), dtype=tf.float32),
+ bname: tf.FixedLenFeature(
+ (1, 1, 1, 1), dtype=tf.string),
+ }
+ },
+ expected_output)
def testSerializedContainingDenseScalar(self):
original = [
example(features=features({
"a": float_feature([1]),
- })),
- example(features=features({}))
+ })), example(features=features({}))
]
serialized = [m.SerializeToString() for m in original]
@@ -274,12 +305,15 @@ class ParseExampleTest(tf.test.TestCase):
"a": np.array([[1], [-1]], dtype=np.float32) # 2x1 (column vector)
}
- self._test({
- "serialized": tf.convert_to_tensor(serialized),
- "features": {
- "a": tf.FixedLenFeature((1,), dtype=tf.float32, default_value=-1),
- }
- }, expected_output)
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ "a": tf.FixedLenFeature(
+ (1,), dtype=tf.float32, default_value=-1),
+ }
+ },
+ expected_output)
def testSerializedContainingDenseWithDefaults(self):
original = [
@@ -288,37 +322,46 @@ class ParseExampleTest(tf.test.TestCase):
})),
example(features=features({
"b": bytes_feature([b"b1"]),
- }))
+ })),
+ example(features=features({
+ "b": feature()
+ })),
]
serialized = [m.SerializeToString() for m in original]
expected_output = {
- "a": np.array([[1, 1], [3, -3]], dtype=np.float32).reshape(2, 1, 2, 1),
- "b": np.array(["tmp_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ "a": np.array(
+ [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2, 1),
+ "b": np.array(
+ ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1, 1),
}
- self._test({
- "serialized": tf.convert_to_tensor(serialized),
- "features": {
- "a": tf.FixedLenFeature(
- (1, 2, 1), dtype=tf.float32, default_value=[3.0, -3.0]),
- "b": tf.FixedLenFeature(
- (1, 1, 1, 1), dtype=tf.string, default_value="tmp_str"),
- }
- }, expected_output)
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ "a": tf.FixedLenFeature(
+ (1, 2, 1), dtype=tf.float32, default_value=[3.0, -3.0]),
+ "b": tf.FixedLenFeature(
+ (1, 1, 1, 1), dtype=tf.string, default_value="tmp_str"),
+ }
+ },
+ expected_output)
def testSerializedContainingSparseAndDenseWithNoDefault(self):
expected_st_a = ( # indices, values, shape
- np.empty((0, 2), dtype=np.int64), # indices
- np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
- np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
original = [
example(features=features({
"c": float_feature([3, 4])
- })),
- example(features=features({
+ })), example(features=features({
"c": float_feature([1, 2])
}))
]
@@ -332,20 +375,25 @@ class ParseExampleTest(tf.test.TestCase):
"st_a": expected_st_a,
"a": np.array(2 * [[a_default]]),
"b": np.array(2 * [b_default]),
- "c": np.array([[3, 4], [1, 2]], dtype=np.float32),
+ "c": np.array(
+ [[3, 4], [1, 2]], dtype=np.float32),
}
- self._test({
- "example_names": names,
- "serialized": tf.convert_to_tensor(serialized),
- "features": {
- "st_a": tf.VarLenFeature(tf.int64),
- "a": tf.FixedLenFeature((1, 3), tf.int64, default_value=a_default),
- "b": tf.FixedLenFeature((3, 3), tf.string, default_value=b_default),
- # Feature "c" must be provided, since it has no default_value.
- "c": tf.FixedLenFeature((2,), tf.float32),
- }
- }, expected_output)
+ self._test(
+ {
+ "example_names": names,
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ "st_a": tf.VarLenFeature(tf.int64),
+ "a": tf.FixedLenFeature(
+ (1, 3), tf.int64, default_value=a_default),
+ "b": tf.FixedLenFeature(
+ (3, 3), tf.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c": tf.FixedLenFeature((2,), tf.float32),
+ }
+ },
+ expected_output)
class ParseSingleExampleTest(tf.test.TestCase):
@@ -353,7 +401,8 @@ class ParseSingleExampleTest(tf.test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
with self.test_session() as sess:
if expected_err:
- with self.assertRaisesRegexp(expected_err[0], expected_err[1]):
+ with self.assertRaisesWithPredicateMatch(
+ expected_err[0], expected_err[1]):
out = tf.parse_single_example(**kwargs)
sess.run(flatten_values_tensors_or_sparse(out.values()))
else:
@@ -374,16 +423,17 @@ class ParseSingleExampleTest(tf.test.TestCase):
self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (1,))
def testSingleExampleWithSparseAndDense(self):
- original = example(features=features(
- {"c": float_feature([3, 4]),
- "st_a": float_feature([3.0, 4.0])}))
+ original = example(features=features({"c": float_feature([3, 4]),
+ "st_a": float_feature([3.0, 4.0])}))
serialized = original.SerializeToString()
- expected_st_a = (
- np.array([[0], [1]], dtype=np.int64), # indices
- np.array([3.0, 4.0], dtype=np.float32), # values
- np.array([2], dtype=np.int64)) # shape: max_values = 2
+ expected_st_a = (np.array(
+ [[0], [1]], dtype=np.int64), # indices
+ np.array(
+ [3.0, 4.0], dtype=np.float32), # values
+ np.array(
+ [2], dtype=np.int64)) # shape: max_values = 2
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
@@ -391,20 +441,25 @@ class ParseSingleExampleTest(tf.test.TestCase):
"st_a": expected_st_a,
"a": [a_default],
"b": b_default,
- "c": np.array([3, 4], dtype=np.float32),
+ "c": np.array(
+ [3, 4], dtype=np.float32),
}
- self._test({
- "example_names": tf.convert_to_tensor("in1"),
- "serialized": tf.convert_to_tensor(serialized),
- "features": {
- "st_a": tf.VarLenFeature(tf.float32),
- "a": tf.FixedLenFeature((1, 3), tf.int64, default_value=a_default),
- "b": tf.FixedLenFeature((3, 3), tf.string, default_value=b_default),
- # Feature "c" must be provided, since it has no default_value.
- "c": tf.FixedLenFeature((2,), tf.float32),
- }
- }, expected_output)
+ self._test(
+ {
+ "example_names": tf.convert_to_tensor("in1"),
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ "st_a": tf.VarLenFeature(tf.float32),
+ "a": tf.FixedLenFeature(
+ (1, 3), tf.int64, default_value=a_default),
+ "b": tf.FixedLenFeature(
+ (3, 3), tf.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c": tf.FixedLenFeature((2,), tf.float32),
+ }
+ },
+ expected_output)
class ParseSequenceExampleTest(tf.test.TestCase):
@@ -413,26 +468,31 @@ class ParseSequenceExampleTest(tf.test.TestCase):
value = sequence_example(
context=features({
"global_feature": float_feature([1, 2, 3]),
- }),
+ }),
feature_lists=feature_lists({
"repeated_feature_2_frames": feature_list([
bytes_feature([b"a", b"b", b"c"]),
- bytes_feature([b"a", b"d", b"e"])]),
+ bytes_feature([b"a", b"d", b"e"])
+ ]),
"repeated_feature_3_frames": feature_list([
- int64_feature([3, 4, 5, 6, 7]),
- int64_feature([-1, 0, 0, 0, 0]),
- int64_feature([1, 2, 3, 4, 5])])
- }))
+ int64_feature([3, 4, 5, 6, 7]), int64_feature([-1, 0, 0, 0, 0]),
+ int64_feature([1, 2, 3, 4, 5])
+ ])
+ }))
value.SerializeToString() # Smoke test
- def _test(self, kwargs, expected_context_values=None,
- expected_feat_list_values=None, expected_err=None):
+ def _test(self,
+ kwargs,
+ expected_context_values=None,
+ expected_feat_list_values=None,
+ expected_err=None):
expected_context_values = expected_context_values or {}
expected_feat_list_values = expected_feat_list_values or {}
with self.test_session() as sess:
if expected_err:
- with self.assertRaisesRegexp(expected_err[0], expected_err[1]):
+ with self.assertRaisesWithPredicateMatch(
+ expected_err[0], expected_err[1]):
c_out, fl_out = tf.parse_single_sequence_example(**kwargs)
if c_out:
sess.run(flatten_values_tensors_or_sparse(c_out.values()))
@@ -442,16 +502,16 @@ class ParseSequenceExampleTest(tf.test.TestCase):
# Returns dicts w/ Tensors and SparseTensors.
context_out, feat_list_out = tf.parse_single_sequence_example(**kwargs)
context_result = sess.run(
- flatten_values_tensors_or_sparse(
- context_out.values())) if context_out else []
+ flatten_values_tensors_or_sparse(context_out.values(
+ ))) if context_out else []
feat_list_result = sess.run(
- flatten_values_tensors_or_sparse(
- feat_list_out.values())) if feat_list_out else []
+ flatten_values_tensors_or_sparse(feat_list_out.values(
+ ))) if feat_list_out else []
# Check values.
- _compare_output_to_expected(
- self, context_out, expected_context_values, context_result)
- _compare_output_to_expected(
- self, feat_list_out, expected_feat_list_values, feat_list_result)
+ _compare_output_to_expected(self, context_out, expected_context_values,
+ context_result)
+ _compare_output_to_expected(self, feat_list_out,
+ expected_feat_list_values, feat_list_result)
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
@@ -469,16 +529,18 @@ class ParseSequenceExampleTest(tf.test.TestCase):
tuple(context_out[k].shape.get_shape().as_list()), (1,))
def testSequenceExampleWithSparseAndDenseContext(self):
- original = sequence_example(context=features(
- {"c": float_feature([3, 4]),
- "st_a": float_feature([3.0, 4.0])}))
+ original = sequence_example(context=features({"c": float_feature([3, 4]),
+ "st_a": float_feature(
+ [3.0, 4.0])}))
serialized = original.SerializeToString()
- expected_st_a = (
- np.array([[0], [1]], dtype=np.int64), # indices
- np.array([3.0, 4.0], dtype=np.float32), # values
- np.array([2], dtype=np.int64)) # shape: num_features = 2
+ expected_st_a = (np.array(
+ [[0], [1]], dtype=np.int64), # indices
+ np.array(
+ [3.0, 4.0], dtype=np.float32), # values
+ np.array(
+ [2], dtype=np.int64)) # shape: num_features = 2
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
@@ -486,20 +548,25 @@ class ParseSequenceExampleTest(tf.test.TestCase):
"st_a": expected_st_a,
"a": [a_default],
"b": b_default,
- "c": np.array([3, 4], dtype=np.float32),
+ "c": np.array(
+ [3, 4], dtype=np.float32),
}
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(serialized),
- "context_features": {
- "st_a": tf.VarLenFeature(tf.float32),
- "a": tf.FixedLenFeature((1, 3), tf.int64, default_value=a_default),
- "b": tf.FixedLenFeature((3, 3), tf.string, default_value=b_default),
- # Feature "c" must be provided, since it has no default_value.
- "c": tf.FixedLenFeature((2,), tf.float32),
- }
- }, expected_context_values=expected_context_output)
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "context_features": {
+ "st_a": tf.VarLenFeature(tf.float32),
+ "a": tf.FixedLenFeature(
+ (1, 3), tf.int64, default_value=a_default),
+ "b": tf.FixedLenFeature(
+ (3, 3), tf.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c": tf.FixedLenFeature((2,), tf.float32),
+ }
+ },
+ expected_context_values=expected_context_output)
def testSequenceExampleWithMultipleSizeFeatureLists(self):
original = sequence_example(feature_lists=feature_lists({
@@ -507,229 +574,274 @@ class ParseSequenceExampleTest(tf.test.TestCase):
int64_feature([-1, 0, 1]),
int64_feature([2, 3, 4]),
int64_feature([5, 6, 7]),
- int64_feature([8, 9, 10]),]),
+ int64_feature([8, 9, 10]),
+ ]),
"b": feature_list([
- bytes_feature([b"r00", b"r01", b"r10", b"r11"])]),
+ bytes_feature([b"r00", b"r01", b"r10", b"r11"])
+ ]),
"c": feature_list([
- float_feature([3, 4]),
- float_feature([-1, 2])]),
- }))
+ float_feature([3, 4]), float_feature([-1, 2])
+ ]),
+ }))
serialized = original.SerializeToString()
expected_feature_list_output = {
- "a": np.array([ # outer dimension is time.
- [[-1, 0, 1]], # inside are 1x3 matrices
- [[2, 3, 4]],
- [[5, 6, 7]],
- [[8, 9, 10]]], dtype=np.int64),
- "b": np.array([ # outer dimension is time, inside are 2x2 matrices
- [[b"r00", b"r01"], [b"r10", b"r11"]]], dtype=bytes),
- "c": np.array([ # outer dimension is time, inside are 2-vectors
- [3, 4],
- [-1, 2]], dtype=np.float32),
- "d": np.empty(shape=(0, 5), dtype=np.float32), # empty_allowed_missing
- }
+ "a": np.array(
+ [ # outer dimension is time.
+ [[-1, 0, 1]], # inside are 1x3 matrices
+ [[2, 3, 4]],
+ [[5, 6, 7]],
+ [[8, 9, 10]]
+ ],
+ dtype=np.int64),
+ "b": np.array(
+ [ # outer dimension is time, inside are 2x2 matrices
+ [[b"r00", b"r01"], [b"r10", b"r11"]]
+ ],
+ dtype=bytes),
+ "c": np.array(
+ [ # outer dimension is time, inside are 2-vectors
+ [3, 4], [-1, 2]
+ ],
+ dtype=np.float32),
+ "d": np.empty(
+ shape=(0, 5), dtype=np.float32), # empty_allowed_missing
+ }
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(serialized),
- "sequence_features": {
- "a": tf.FixedLenSequenceFeature((1, 3), tf.int64),
- "b": tf.FixedLenSequenceFeature((2, 2), tf.string),
- "c": tf.FixedLenSequenceFeature((2,), tf.float32),
- "d": tf.FixedLenSequenceFeature((5,), tf.float32, allow_missing=True),
- }
- }, expected_feat_list_values=expected_feature_list_output)
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "sequence_features": {
+ "a": tf.FixedLenSequenceFeature((1, 3), tf.int64),
+ "b": tf.FixedLenSequenceFeature((2, 2), tf.string),
+ "c": tf.FixedLenSequenceFeature((2,), tf.float32),
+ "d": tf.FixedLenSequenceFeature(
+ (5,), tf.float32, allow_missing=True),
+ }
+ },
+ expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithoutDebugName(self):
original = sequence_example(feature_lists=feature_lists({
"a": feature_list([
- int64_feature([3, 4]),
- int64_feature([1, 0])]),
+ int64_feature([3, 4]), int64_feature([1, 0])
+ ]),
"st_a": feature_list([
- float_feature([3.0, 4.0]),
- float_feature([5.0]),
- float_feature([])]),
+ float_feature([3.0, 4.0]), float_feature([5.0]), float_feature([])
+ ]),
"st_b": feature_list([
- bytes_feature([b"a"]),
- bytes_feature([]),
- bytes_feature([]),
- bytes_feature([b"b", b"c"])])}))
+ bytes_feature([b"a"]), bytes_feature([]), bytes_feature([]),
+ bytes_feature([b"b", b"c"])
+ ])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
- np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array(
+ [[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array(
+ [3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array(
+ [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_st_b = (
- np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
- np.array(["a", "b", "c"], dtype="|S"), # values
- np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+ np.array(
+ [[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(
+ ["a", "b", "c"], dtype="|S"), # values
+ np.array(
+ [4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
expected_st_c = (
- np.empty((0, 2), dtype=np.int64), # indices
- np.empty((0,), dtype=np.int64), # values
- np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # values
+ np.array(
+ [0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
expected_feature_list_output = {
- "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
+ "a": np.array(
+ [[3, 4], [1, 0]], dtype=np.int64),
"st_a": expected_st_a,
"st_b": expected_st_b,
"st_c": expected_st_c,
}
- self._test({
- "serialized": tf.convert_to_tensor(serialized),
- "sequence_features": {
- "st_a": tf.VarLenFeature(tf.float32),
- "st_b": tf.VarLenFeature(tf.string),
- "st_c": tf.VarLenFeature(tf.int64),
- "a": tf.FixedLenSequenceFeature((2,), tf.int64),
- }
- }, expected_feat_list_values=expected_feature_list_output)
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "sequence_features": {
+ "st_a": tf.VarLenFeature(tf.float32),
+ "st_b": tf.VarLenFeature(tf.string),
+ "st_c": tf.VarLenFeature(tf.int64),
+ "a": tf.FixedLenSequenceFeature((2,), tf.int64),
+ }
+ },
+ expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithSparseAndDenseFeatureLists(self):
original = sequence_example(feature_lists=feature_lists({
"a": feature_list([
- int64_feature([3, 4]),
- int64_feature([1, 0])]),
+ int64_feature([3, 4]), int64_feature([1, 0])
+ ]),
"st_a": feature_list([
- float_feature([3.0, 4.0]),
- float_feature([5.0]),
- float_feature([])]),
+ float_feature([3.0, 4.0]), float_feature([5.0]), float_feature([])
+ ]),
"st_b": feature_list([
- bytes_feature([b"a"]),
- bytes_feature([]),
- bytes_feature([]),
- bytes_feature([b"b", b"c"])])}))
+ bytes_feature([b"a"]), bytes_feature([]), bytes_feature([]),
+ bytes_feature([b"b", b"c"])
+ ])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
- np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array(
+ [[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array(
+ [3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array(
+ [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_st_b = (
- np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
- np.array(["a", "b", "c"], dtype="|S"), # values
- np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+ np.array(
+ [[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(
+ ["a", "b", "c"], dtype="|S"), # values
+ np.array(
+ [4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
expected_st_c = (
- np.empty((0, 2), dtype=np.int64), # indices
- np.empty((0,), dtype=np.int64), # values
- np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # values
+ np.array(
+ [0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
expected_feature_list_output = {
- "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
+ "a": np.array(
+ [[3, 4], [1, 0]], dtype=np.int64),
"st_a": expected_st_a,
"st_b": expected_st_b,
"st_c": expected_st_c,
}
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(serialized),
- "sequence_features": {
- "st_a": tf.VarLenFeature(tf.float32),
- "st_b": tf.VarLenFeature(tf.string),
- "st_c": tf.VarLenFeature(tf.int64),
- "a": tf.FixedLenSequenceFeature((2,), tf.int64),
- }
- }, expected_feat_list_values=expected_feature_list_output)
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "sequence_features": {
+ "st_a": tf.VarLenFeature(tf.float32),
+ "st_b": tf.VarLenFeature(tf.string),
+ "st_c": tf.VarLenFeature(tf.int64),
+ "a": tf.FixedLenSequenceFeature((2,), tf.int64),
+ }
+ },
+ expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleListWithInconsistentDataFails(self):
original = sequence_example(feature_lists=feature_lists({
"a": feature_list([
- int64_feature([-1, 0]),
- float_feature([2, 3])])
- }))
+ int64_feature([-1, 0]), float_feature([2, 3])
+ ])
+ }))
serialized = original.SerializeToString()
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(serialized),
- "sequence_features": {"a": tf.FixedLenSequenceFeature((2,), tf.int64)}
- }, expected_err=(
- tf.OpError,
- "Feature list: a, Index: 1."
- " Data types don't match. Expected type: int64"))
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "sequence_features": {"a": tf.FixedLenSequenceFeature(
+ (2,), tf.int64)}
+ },
+ expected_err=(tf.OpError, "Feature list: a, Index: 1."
+ " Data types don't match. Expected type: int64"))
def testSequenceExampleListWithWrongDataTypeFails(self):
original = sequence_example(feature_lists=feature_lists({
"a": feature_list([
- float_feature([2, 3])])
- }))
+ float_feature([2, 3])
+ ])
+ }))
serialized = original.SerializeToString()
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(serialized),
- "sequence_features": {"a": tf.FixedLenSequenceFeature((2,), tf.int64)}
- }, expected_err=(
- tf.OpError,
- "Feature list: a, Index: 0. Data types don't match."
- " Expected type: int64"))
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "sequence_features": {"a": tf.FixedLenSequenceFeature(
+ (2,), tf.int64)}
+ },
+ expected_err=(tf.OpError,
+ "Feature list: a, Index: 0. Data types don't match."
+ " Expected type: int64"))
def testSequenceExampleListWithWrongSparseDataTypeFails(self):
original = sequence_example(feature_lists=feature_lists({
"a": feature_list([
- int64_feature([3, 4]),
- int64_feature([1, 2]),
- float_feature([2.0, 3.0])])
- }))
+ int64_feature([3, 4]), int64_feature([1, 2]),
+ float_feature([2.0, 3.0])
+ ])
+ }))
serialized = original.SerializeToString()
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(serialized),
- "sequence_features": {"a": tf.FixedLenSequenceFeature((2,), tf.int64)}
- }, expected_err=(
- tf.OpError,
- "Name: in1, Feature list: a, Index: 2."
- " Data types don't match. Expected type: int64"
- " Feature is: float_list"))
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "sequence_features": {"a": tf.FixedLenSequenceFeature(
+ (2,), tf.int64)}
+ },
+ expected_err=(tf.OpError, "Name: in1, Feature list: a, Index: 2."
+ " Data types don't match. Expected type: int64"
+ " Feature is: float_list"))
def testSequenceExampleListWithWrongShapeFails(self):
original = sequence_example(feature_lists=feature_lists({
"a": feature_list([
- int64_feature([2, 3]),
- int64_feature([2, 3, 4])]),
- }))
+ int64_feature([2, 3]), int64_feature([2, 3, 4])
+ ]),
+ }))
serialized = original.SerializeToString()
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(serialized),
- "sequence_features": {"a": tf.FixedLenSequenceFeature((2,), tf.int64)}
- }, expected_err=(
- tf.OpError,
- r"Name: in1, Key: a, Index: 1."
- r" Number of int64 values != expected."
- r" values size: 3 but output shape: \[2\]"))
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "sequence_features": {"a": tf.FixedLenSequenceFeature(
+ (2,), tf.int64)}
+ },
+ expected_err=(tf.OpError, r"Name: in1, Key: a, Index: 1."
+ r" Number of int64 values != expected."
+ r" values size: 3 but output shape: \[2\]"))
def testSequenceExampleWithMissingFeatureListFails(self):
original = sequence_example(feature_lists=feature_lists({}))
# Test fails because we didn't add:
# feature_list_dense_defaults = {"a": None}
- self._test({
- "example_name": "in1",
- "serialized": tf.convert_to_tensor(original.SerializeToString()),
- "sequence_features": {"a": tf.FixedLenSequenceFeature((2,), tf.int64)}
- }, expected_err=(
- tf.OpError,
- "Name: in1, Feature list 'a' is required but could not be found."
- " Did you mean to include it in"
- " feature_list_dense_missing_assumed_empty or"
- " feature_list_dense_defaults?"))
+ self._test(
+ {
+ "example_name": "in1",
+ "serialized": tf.convert_to_tensor(original.SerializeToString()),
+ "sequence_features": {"a": tf.FixedLenSequenceFeature(
+ (2,), tf.int64)}
+ },
+ expected_err=(
+ tf.OpError,
+ "Name: in1, Feature list 'a' is required but could not be found."
+ " Did you mean to include it in"
+ " feature_list_dense_missing_assumed_empty or"
+ " feature_list_dense_defaults?"))
class DecodeJSONExampleTest(tf.test.TestCase):
@@ -740,14 +852,15 @@ class DecodeJSONExampleTest(tf.test.TestCase):
json_tensor = tf.constant(
[json_format.MessageToJson(m) for m in examples.flatten()],
- shape=examples.shape, dtype=tf.string)
+ shape=examples.shape,
+ dtype=tf.string)
binary_tensor = tf.decode_json_example(json_tensor)
binary_val = sess.run(binary_tensor)
if examples.shape:
self.assertShapeEqual(binary_val, json_tensor)
- for input_example, output_binary in zip(np.array(examples).flatten(),
- binary_val.flatten()):
+ for input_example, output_binary in zip(
+ np.array(examples).flatten(), binary_val.flatten()):
output_example = tf.train.Example()
output_example.ParseFromString(output_binary)
self.assertProtoEquals(input_example, output_example)