author    Shivani Agrawal <shivaniagrawal@google.com>    2018-08-22 15:00:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-22 15:09:36 -0700
commit    fb3bde1994d4ed7d6cb928326e8e2a1777930e5e (patch)
tree      2232beae78bbeaf016f2f165e9f34e4e5a23bf3e
parent    b56e4377687b95014fa8dadc8e99192484fa79a0 (diff)
[tf.data] Implements the dataset transformation `parse_example_dataset(..)`, which will replace `dataset.map(parsing_ops.parse_example(..))`.
PiperOrigin-RevId: 209836033
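For orientation, a minimal before/after sketch of the migration this commit enables. The feature spec, file name, and parallelism below are illustrative, not part of the change:

    # Hypothetical input pipeline; only the parsing step differs.
    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops

    features = {  # illustrative feature spec
        "label": tf.FixedLenFeature((), tf.int64, default_value=0),
        "tokens": tf.VarLenFeature(tf.string),
    }
    dataset = tf.data.TFRecordDataset(["input.tfrecord"]).batch(32)

    # Before: parse each batch of serialized Examples via map().
    parsed = dataset.map(lambda x: tf.parse_example(x, features))

    # After: apply the fused transformation introduced here.
    parsed = dataset.apply(
        contrib_parsing_ops.parse_example_dataset(features, num_parallel_calls=4))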
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD                       22
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py        850
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD                                20
-rw-r--r--  tensorflow/contrib/data/python/ops/parsing_ops.py                      152
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py                            8
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt      69
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt     4
-rw-r--r--  tensorflow/core/kernels/data/BUILD                                       10
-rw-r--r--  tensorflow/core/kernels/data/parse_example_dataset_op.cc               347
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc                                       17
-rw-r--r--  tensorflow/python/ops/parsing_ops.py                                    178
11 files changed, 1602 insertions(+), 75 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 220f1adf7f..a673c4b6f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -237,6 +237,27 @@ py_test(
],
)
+py_test(
+ name = "parsing_ops_test",
+ size = "small",
+ srcs = ["parsing_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:parsing_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
cuda_py_test(
name = "prefetching_ops_test",
size = "small",
@@ -323,6 +344,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
new file mode 100644
index 0000000000..f6c4a984b8
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -0,0 +1,850 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.parsing_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+# Helpers for creating Example objects
+example = example_pb2.Example
+feature = feature_pb2.Feature
+features = lambda d: feature_pb2.Features(feature=d)
+bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v))
+int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v))
+float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v))
+# Helpers for creating SequenceExample objects
+feature_list = lambda l: feature_pb2.FeatureList(feature=l)
+feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
+sequence_example = example_pb2.SequenceExample
+
+
+def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
+ flat_output):
+ tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
+
+ i = 0 # Index into the flattened output of session.run()
+ for k, v in sorted(dict_tensors.items()):
+    # TODO(shivaniagrawal): flat_output is the same as v.
+ expected_v = expected_tensors[k]
+ tf_logging.info("Comparing key: %s", k)
+ print("i", i, "flat_output", flat_output[i], "expected_v", expected_v)
+ if sparse_tensor.is_sparse(v):
+      # Three outputs for SparseTensor: indices, values, shape.
+ tester.assertEqual([k, len(expected_v)], [k, 3])
+ print("i", i, "flat_output", flat_output[i].indices, "expected_v",
+ expected_v[0])
+ tester.assertAllEqual(expected_v[0], flat_output[i].indices)
+ tester.assertAllEqual(expected_v[1], flat_output[i].values)
+ tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape)
+ else:
+ # One output for standard Tensor.
+ tester.assertAllEqual(expected_v, flat_output[i])
+ i += 1
+
+
+class ParseExampleTest(test.TestCase):
+
+ def _test(self,
+ input_tensor,
+ feature_val,
+ expected_values=None,
+ expected_err=None):
+
+ with self.test_session() as sess:
+ if expected_err:
+ with self.assertRaisesWithPredicateMatch(expected_err[0],
+ expected_err[1]):
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ sess.run(get_next)
+ return
+ else:
+ # Returns dict w/ Tensors and SparseTensors.
+ # Check values.
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ result = sess.run(get_next)
+ flattened = nest.flatten(result)
+ print("result", result, "expected_values", expected_values)
+ _compare_output_to_expected(self, result, expected_values, flattened)
+
+ # Check shapes; if serialized is a Tensor we need its size to
+ # properly check.
+ batch_size = (
+ input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else
+ np.asarray(input_tensor).size)
+ for k, f in feature_val.items():
+ print("output_shapes as list ",
+ tuple(dataset.output_shapes[k].as_list()))
+ if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
+ self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size)
+ elif isinstance(f, parsing_ops.VarLenFeature):
+ self.assertEqual(dataset.output_shapes[k].as_list()[1], None)
+
+ def testEmptySerializedWithAllDefaults(self):
+ sparse_name = "st_a"
+ a_name = "a"
+ b_name = "b"
+ c_name = "c:has_a_tricky_name"
+ a_default = [0, 42, 0]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ c_default = np.random.rand(2).astype(np.float32)
+
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+
+ expected_output = {
+ sparse_name: expected_st_a,
+ a_name: np.array(2 * [[a_default]]),
+ b_name: np.array(2 * [b_default]),
+ c_name: np.array(2 * [c_default]),
+ }
+
+ self._test(
+ ops.convert_to_tensor(["", ""]), {
+ sparse_name:
+ parsing_ops.VarLenFeature(dtypes.int64),
+ a_name:
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ b_name:
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ c_name:
+ parsing_ops.FixedLenFeature(
+ (2,), dtypes.float32, default_value=c_default),
+ },
+ expected_values=expected_output)
+
+ def testEmptySerializedWithoutDefaultsShouldFail(self):
+ input_features = {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=[0, 42, 0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3),
+ dtypes.string,
+ default_value=np.random.rand(3, 3).astype(bytes)),
+ # Feature "c" is missing a default, this gap will cause failure.
+ "c":
+ parsing_ops.FixedLenFeature(
+ (2,), dtype=dtypes.float32),
+ }
+
+ # Edge case where the key is there but the feature value is empty
+ original = example(features=features({"c": feature()}))
+ self._test(
+ [original.SerializeToString()],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ # Standard case of missing key and value.
+ self._test(
+ ["", ""],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ def testDenseNotMatchingShapeShouldFail(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })), example(features=features({
+ "a": float_feature([-1, -1]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Key: a, Index: 1. Number of float values"))
+
+ def testDenseDefaultNoShapeShouldFail(self):
+ original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
+ expected_err=(ValueError, "Missing shape for feature a"))
+
+ def testSerializedContainingSparse(self):
+ original = [
+ example(features=features({
+ "st_c": float_feature([3, 4])
+ })),
+ example(features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_st_c = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
+ [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
+            [4, 3], dtype=np.int64))  # batch == 4, max_elems = 3
+
+ expected_st_d = ( # indices, values, shape
+ np.array(
+ [[3, 0]], dtype=np.int64), np.array(
+ ["hi"], dtype=bytes), np.array(
+            [4, 1], dtype=np.int64))  # batch == 4, max_elems = 1
+
+ expected_output = {
+ "st_c": expected_st_c,
+ "st_d": expected_st_d,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "st_c": parsing_ops.VarLenFeature(dtypes.float32),
+ "st_d": parsing_ops.VarLenFeature(dtypes.string)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx":
+ int64_feature([0, 9, 3]) # unsorted
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array(
+ [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
+ [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeatureReuse(self):
+ original = [
+ example(features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp1 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [3.0, 4.0], dtype=np.float32), np.array(
+ [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_sp2 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [5.0, 6.0], dtype=np.float32), np.array(
+            [2, 7], dtype=np.int64))  # batch == 2, max_elems = 7
+
+ expected_output = {
+ "sp1": expected_sp1,
+ "sp2": expected_sp2,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp1":
+ parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
+ "sp2":
+ parsing_ops.SparseFeature(
+ "idx", "val2", dtypes.float32, size=7, already_sorted=True)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContaining3DSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx0": int64_feature([5, 10]),
+ "idx1": int64_feature([0, 2]),
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx0": int64_feature([]),
+ "idx1": int64_feature([]),
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx0": int64_feature([0, 9, 3]), # unsorted
+ "idx1": int64_feature([1, 0, 2]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = (
+ # indices
+ np.array(
+ [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
+ dtype=np.int64),
+ # values
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ # shape batch == 4, max_elems = 13
+ np.array([4, 13, 3], dtype=np.int64))
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp":
+ parsing_ops.SparseFeature(["idx0", "idx1"], "val",
+ dtypes.float32, [13, 3])
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDense(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ original = [
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ })), example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b""]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+  # This test is identical to the previous one except
+  # for the creation of 'serialized'.
+ def testSerializedContainingDenseWithConcat(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ # TODO(lew): Feature appearing twice should be an error in future.
+ original = [
+ (example(features=features({
+ aname: float_feature([10, 10]),
+ })), example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ }))),
+ (
+ example(features=features({
+ bname: bytes_feature([b"b100"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"]),
+ })),),
+ ]
+
+ serialized = [
+ m.SerializeToString() + n.SerializeToString() for (m, n) in original
+ ]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseScalar(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1]),
+ })), example(features=features({}))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1,), dtype=dtypes.float32, default_value=-1),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseWithDefaults(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1]),
+ })),
+ example(features=features({
+ "b": bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ "b": feature()
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
+ 1),
+ "b":
+ np.array(
+ ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
+ 1),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "c"], dtype="|S"), np.array(
+            [2, 13], dtype=np.int64))  # batch == 2, max_elems = 13
+
+ original = [
+ example(features=features({
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ a_default = [1, 2, 3]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ expected_output = {
+ "st_a": expected_st_a,
+ "sp": expected_sp,
+ "a": np.array(2 * [[a_default]]),
+ "b": np.array(2 * [b_default]),
+ "c": np.array(
+ [[3, 4], [1, 2]], dtype=np.float32),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c":
+ parsing_ops.FixedLenFeature((2,), dtypes.float32),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
+ expected_idx = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]), np.array(
+        [2, 2], dtype=np.int64))  # batch == 2, max_elems = 2
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "d", "c"], dtype="|S"), np.array(
+            [2, 13], dtype=np.int64))  # batch == 2, max_elems = 13
+
+ original = [
+ example(features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "idx": expected_idx,
+ "sp": expected_sp,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "idx":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
+ },
+ expected_values=expected_output)
+
+ def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
+ # During parsing, data read from the serialized proto is stored in buffers.
+ # For small batch sizes, a buffer will contain one minibatch entry.
+ # For larger batch sizes, a buffer may contain several minibatch
+ # entries. This test identified a bug where the code that copied
+ # data out of the buffers and into the output tensors assumed each
+ # buffer only contained one minibatch entry. The bug has since been fixed.
+ truth_int = [i for i in range(batch_size)]
+ truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
+ for i in range(batch_size)]
+
+ expected_str = copy.deepcopy(truth_str)
+
+ # Delete some intermediate entries
+ for i in range(batch_size):
+ col = 1
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry
+ expected_str[i][col] = b"default"
+ col -= 1
+ truth_str[i].pop()
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry (possibly again)
+ expected_str[i][col] = b"default"
+ truth_str[i].pop()
+
+ expected_output = {
+ # Batch size batch_size, 1 time step.
+ "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
+ # Batch size batch_size, 2 time steps.
+ "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
+ }
+
+ original = [
+ example(features=features(
+ {"a": int64_feature([truth_int[i]]),
+ "b": bytes_feature(truth_str[i])}))
+ for i in range(batch_size)
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized, dtype=dtypes.string), {
+ "a":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=(),
+ dtype=dtypes.int64,
+ allow_missing=True,
+ default_value=-1),
+ "b":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.string,
+ allow_missing=True,
+ default_value="default"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingVarLenDenseLargerBatch(self):
+ np.random.seed(3456)
+ for batch_size in (1, 10, 20, 100, 256):
+ self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
+
+ def testSerializedContainingVarLenDense(self):
+ aname = "a"
+ bname = "b"
+ cname = "c"
+ dname = "d"
+ original = [
+ example(features=features({
+ cname: int64_feature([2]),
+ })),
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str", b"b1_str"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1, 2, 2]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ aname: float_feature([]),
+ cname: int64_feature([3]),
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [
+ [0, 0, 0, 0],
+ [1, 1, 0, 0],
+ [-1, -1, 2, 2],
+ [0, 0, 0, 0],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1),
+ bname:
+ np.array(
+ [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
+ dtype=bytes).reshape(4, 2, 1, 1, 1),
+ cname:
+ np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
+ dname:
+ np.empty(shape=(4, 0), dtype=bytes),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_values=expected_output)
+
+ # Test with padding values.
+ expected_output_custom_padding = dict(expected_output)
+ expected_output_custom_padding[aname] = np.array(
+ [
+ [-2, -2, -2, -2],
+ [1, 1, -2, -2],
+ [-1, -1, 2, 2],
+ [-2, -2, -2, -2],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1)
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=-2.0),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ }, expected_output_custom_padding)
+
+ # Change number of required values so the inputs are not a
+ # multiple of this size.
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(
+ errors_impl.OpError, "Key: b, Index: 2. "
+ "Number of bytes values is not a multiple of stride length."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=[]),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Cannot reshape a tensor with 0 elements to shape"))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "First dimension of shape for feature a unknown. "
+ "Consider using FixedLenSequenceFeature."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ cname:
+ parsing_ops.FixedLenFeature(
+ (1, None), dtype=dtypes.int64, default_value=[[1]]),
+ },
+ expected_err=(ValueError,
+ "All dimensions of shape for feature c need to be known "
+ r"but received \(1, None\)."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=False),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Unsupported: FixedLenSequenceFeature requires "
+ "allow_missing to be True."))
+
+
+
+if __name__ == "__main__":
+ test.main()
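As a standalone illustration of the proto helpers defined at the top of this file, a minimal sketch that runs in this test module's namespace (feature names and values are arbitrary):

    # Build one serialized Example the way the cases above do, then parse it
    # with the new transformation.
    serialized = example(features=features({
        "age": int64_feature([39]),
        "gender": bytes_feature([b"f"]),
    })).SerializeToString()

    dataset = dataset_ops.Dataset.from_tensors([serialized]).apply(
        contrib_parsing_ops.parse_example_dataset({
            "age": parsing_ops.FixedLenFeature((1,), dtypes.int64),
            "gender": parsing_ops.VarLenFeature(dtypes.string),
        }))
    # Each element is a dict: "age" -> Tensor, "gender" -> SparseTensor.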
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index ad9378dfb9..d540ba470a 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,6 +80,7 @@ py_library(
":batching",
":gen_dataset_ops",
":interleave_ops",
+ ":parsing_ops",
":shuffle_ops",
":stats_ops",
"//tensorflow/python:constant_op",
@@ -87,10 +88,7 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
- "//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
@@ -211,6 +209,22 @@ py_library(
)
py_library(
+ name = "parsing_ops",
+ srcs = ["parsing_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
name = "map_defun",
srcs = ["map_defun.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
new file mode 100644
index 0000000000..f868653554
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental `dataset` API for parsing example."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import parsing_ops
+
+
+class _ParseExampleDataset(dataset_ops.Dataset):
+  """A `Dataset` that parses serialized `Example` protos into `dict`s."""
+
+ def __init__(self, input_dataset, features, num_parallel_calls):
+ super(_ParseExampleDataset, self).__init__()
+ self._input_dataset = input_dataset
+ if not all(types == dtypes.string
+ for types in nest.flatten(input_dataset.output_types)):
+ raise TypeError("Input dataset should be a dataset of vectors of strings")
+ self._num_parallel_calls = num_parallel_calls
+ # pylint: disable=protected-access
+ self._features = parsing_ops._prepend_none_dimension(features)
+ # sparse_keys and dense_keys come back sorted here.
+ (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
+ dense_shapes) = parsing_ops._features_to_raw_params(
+ self._features, [
+ parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
+ parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
+ ])
+ # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
+ dense_shape_as_shape) = parsing_ops._process_raw_parameters(
+ None, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
+ # pylint: enable=protected-access
+ self._sparse_keys = sparse_keys
+ self._sparse_types = sparse_types
+ self._dense_keys = dense_keys
+ self._dense_defaults = dense_defaults_vec
+ self._dense_shapes = dense_shapes
+ self._dense_types = dense_types
+ dense_output_shapes = [
+ self._input_dataset.output_shapes.concatenate(shape)
+ for shape in dense_shape_as_shape
+ ]
+ sparse_output_shapes = [
+ self._input_dataset.output_shapes.concatenate([None])
+ for _ in range(len(sparse_keys))
+ ]
+
+ self._output_shapes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ dense_output_shapes + sparse_output_shapes))
+ self._output_types = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ self._dense_types + self._sparse_types))
+ self._output_classes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ [ops.Tensor for _ in range(len(self._dense_defaults))] +
+ [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
+ ]))
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parse_example_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_parallel_calls,
+ self._dense_defaults,
+ self._sparse_keys,
+ self._dense_keys,
+ self._sparse_types,
+ self._dense_shapes,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
+# TODO(b/111553342): add arguments names and example names as well.
+def parse_example_dataset(features, num_parallel_calls=1):
+ """A transformation that parses `Example` protos into a `dict` of tensors.
+
+  Parses serialized `Example` protos from the elements of the input dataset.
+  Each input element must be a vector of serialized protos, which is treated
+  as a batch with one entry per individual `Example` proto.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
+ details about feature dictionaries.
+
+ Args:
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+      representing the number of parsing calls to run in parallel.
+
+ Returns:
+ A dataset transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if features argument is None.
+ """
+ if features is None:
+    raise ValueError("Argument `features` must not be None.")
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+    if any(
+        isinstance(feature, parsing_ops.SparseFeature)
+        for feature in features.values()):
+ # pylint: disable=protected-access
+ # pylint: disable=g-long-lambda
+ out_dataset = out_dataset.map(
+ lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
+ features, x), num_parallel_calls=num_parallel_calls)
+ return out_dataset
+
+ return _apply_fn
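The closure returned here follows the standard `Dataset.apply` contract: `apply` invokes it with the upstream dataset and expects a dataset back, so `dataset.apply(parse_example_dataset(features))` is equivalent to `parse_example_dataset(features)(dataset)`. A minimal sketch of the same pattern for a hypothetical transformation:

    def scale_elements(factor):
      """Hypothetical transformation showing the Dataset-in/Dataset-out contract."""

      def _apply_fn(dataset):
        # Any Dataset -> Dataset logic can go here; parse_example_dataset
        # builds a _ParseExampleDataset (plus an optional map) the same way.
        return dataset.map(lambda x: x * factor)

      return _apply_fn

    # dataset.apply(scale_elements(2)) is equivalent to scale_elements(2)(dataset).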
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 3882d4bfdb..151f12b082 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
@@ -37,7 +38,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
@@ -788,9 +788,9 @@ def make_batched_features_dataset(file_pattern,
batch_size, drop_remainder=drop_final_batch or num_epochs is None)
# Parse `Example` tensors to a dictionary of `Feature` tensors.
- dataset = dataset.map(
- lambda x: parsing_ops.parse_example(x, features),
- num_parallel_calls=parser_num_threads)
+ dataset = dataset.apply(
+ parsing_ops.parse_example_dataset(
+ features, num_parallel_calls=parser_num_threads))
# TODO(rachelim): Add an optional label_name argument for extracting the label
# from the features dictionary, to comply with the type expected by the
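With this change, `make_batched_features_dataset` parses via the fused transformation internally; its signature and output are unchanged for callers. A usage sketch (the file pattern and feature spec are illustrative, not from this commit):

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import readers

    dataset = readers.make_batched_features_dataset(
        file_pattern="/tmp/data-*.tfrecord",  # illustrative path
        batch_size=32,
        features={
            "label": tf.FixedLenFeature((), tf.int64, default_value=0),
            "text": tf.VarLenFeature(tf.string),
        },
        num_epochs=1)
    # Elements are dicts of Tensors/SparseTensors, now parsed by
    # parse_example_dataset rather than a map() over parse_example.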
diff --git a/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt
new file mode 100644
index 0000000000..3de2f18fc2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "ParseExampleDataset"
+ in_arg {
+ name: "dense_defaults"
+ description: <<END
+A list of `Tensor`s (some may be empty), one per key in `dense_keys`.
+These provide the default values to use when a dense feature is missing
+from an example.
+END
+ }
+ attr {
+ name: "sparse_keys"
+ description: <<END
+A list of string keys in the examples' features.
+The results for these keys will be returned as `SparseTensor` objects.
+END
+ }
+ attr {
+ name: "dense_keys"
+ description: <<END
+A list of string keys in the examples' features.
+The keys expected in the `Example` features associated with dense values.
+END
+ }
+ attr {
+ name: "sparse_types"
+ description: <<END
+A list of `DTypes` of the same length as `sparse_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+END
+ }
+ attr {
+ name: "Tdense"
+ description: <<END
+A list of DTypes of the same length as `dense_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+END
+ }
+ attr {
+ name: "dense_shapes"
+ description: <<END
+List of tuples with the same length as `dense_keys`.
+The shape of the data for each dense feature referenced by `dense_keys`.
+Required for any input tensors identified by `dense_keys`. Must be
+either fully defined, or may contain an unknown first dimension.
+An unknown first dimension means the feature is treated as having
+a variable number of blocks, and the output shape along this dimension
+is considered unknown at graph build time. Padding is applied for
+minibatch elements smaller than the maximum number of blocks for the
+given feature along this dimension.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features."
+}
+
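As context for how these attrs are populated, a sketch of how a Python-level feature spec decomposes into them (keys, shapes, and values are illustrative):

    import tensorflow as tf

    # A features dict on the Python side ...
    features = {
        "s": tf.VarLenFeature(tf.string),           # sparse
        "d": tf.FixedLenFeature((2,), tf.float32),  # dense
    }
    # ... roughly becomes these op attrs/inputs:
    #   sparse_keys  = ["s"],   sparse_types = [DT_STRING]
    #   dense_keys   = ["d"],   Tdense       = [DT_FLOAT]
    #   dense_shapes = [(2,)],  dense_defaults supplied as input tensors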
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt
new file mode 100644
index 0000000000..45826b6fdc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ParseExampleDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 607a694dba..82ff2a365d 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -232,6 +232,15 @@ cc_library(
],
)
+cc_library(
+ name = "parse_example_dataset_op",
+ srcs = ["parse_example_dataset_op.cc"],
+ deps = [
+ ":parallel_map_iterator",
+ "//tensorflow/core:framework",
+ ],
+)
+
tf_kernel_library(
name = "parallel_map_dataset_op",
srcs = ["parallel_map_dataset_op.cc"],
@@ -668,6 +677,7 @@ tf_kernel_library(
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
+ ":parse_example_dataset_op",
":prefetch_dataset_op",
":random_dataset_op",
":range_dataset_op",
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
new file mode 100644
index 0000000000..1ab2af3e92
--- /dev/null
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -0,0 +1,347 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ParseExampleDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tdense", &dense_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ for (int i = 0; i < dense_shapes_.size(); ++i) {
+ bool shape_ok = true;
+ if (dense_shapes_[i].dims() == -1) {
+ shape_ok = false;
+ } else {
+ for (int d = 1; d < dense_shapes_[i].dims(); ++d) {
+ if (dense_shapes_[i].dim_size(d) == -1) {
+ shape_ok = false;
+ }
+ }
+ }
+ OP_REQUIRES(ctx, shape_ok,
+ errors::InvalidArgument(
+ "dense_shapes[", i,
+ "] has unknown rank or unknown inner dimensions: ",
+ dense_shapes_[i].DebugString()));
+ TensorShape dense_shape;
+ if (dense_shapes_[i].dims() > 0 && dense_shapes_[i].dim_size(0) == -1) {
+ variable_length_.push_back(true);
+ for (int d = 1; d < dense_shapes_[i].dims(); ++d) {
+ dense_shape.AddDim(dense_shapes_[i].dim_size(d));
+ }
+ } else {
+ variable_length_.push_back(false);
+ dense_shapes_[i].AsTensorShape(&dense_shape);
+ }
+ elements_per_stride_.push_back(dense_shape.num_elements());
+ }
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+
+ OpInputList dense_default_tensors;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("dense_defaults", &dense_default_tensors));
+
+ OP_REQUIRES(ctx, dense_default_tensors.size() == dense_keys_.size(),
+ errors::InvalidArgument(
+ "Expected len(dense_defaults) == len(dense_keys) but got: ",
+ dense_default_tensors.size(), " vs. ", dense_keys_.size()));
+
+ std::vector<Tensor> dense_defaults;
+ dense_defaults.reserve(dense_default_tensors.size());
+ for (const Tensor& dense_default_t : dense_default_tensors) {
+ dense_defaults.push_back(dense_default_t);
+ }
+
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ const Tensor& def_value = dense_defaults[d];
+ if (variable_length_[d]) {
+ OP_REQUIRES(ctx, def_value.NumElements() == 1,
+ errors::InvalidArgument(
+ "dense_shape[", d, "] is a variable length shape: ",
+ dense_shapes_[d].DebugString(),
+ ", therefore "
+ "def_value[",
+ d,
+ "] must contain a single element ("
+ "the padding element). But its shape is: ",
+ def_value.shape().DebugString()));
+ } else if (def_value.NumElements() > 0) {
+ OP_REQUIRES(ctx, dense_shapes_[d].IsCompatibleWith(def_value.shape()),
+ errors::InvalidArgument(
+ "def_value[", d,
+ "].shape() == ", def_value.shape().DebugString(),
+ " is not compatible with dense_shapes_[", d,
+ "] == ", dense_shapes_[d].DebugString()));
+ }
+ OP_REQUIRES(ctx, def_value.dtype() == dense_types_[d],
+ errors::InvalidArgument(
+ "dense_defaults[", d, "].dtype() == ",
+ DataTypeString(def_value.dtype()), " != dense_types_[", d,
+ "] == ", DataTypeString(dense_types_[d])));
+ }
+
+ example::FastParseExampleConfig config;
+ std::map<string, int> key_to_output_index;
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ config.dense.push_back({dense_keys_[d], dense_types_[d], dense_shapes_[d],
+ dense_default_tensors[d], variable_length_[d],
+ elements_per_stride_[d]});
+ auto result = key_to_output_index.insert({dense_keys_[d], 0});
+ OP_REQUIRES(ctx, result.second,
+ errors::InvalidArgument("Duplicate key not allowed: ",
+ dense_keys_[d]));
+ }
+ for (int d = 0; d < sparse_keys_.size(); ++d) {
+ config.sparse.push_back({sparse_keys_[d], sparse_types_[d]});
+ auto result = key_to_output_index.insert({sparse_keys_[d], 0});
+ OP_REQUIRES(ctx, result.second,
+ errors::InvalidArgument("Duplicate key not allowed: ",
+ sparse_keys_[d]));
+ }
+ int i = 0;
+ for (auto it = key_to_output_index.begin(); it != key_to_output_index.end();
+ it++) {
+ it->second = i++;
+ }
+
+ *output = new Dataset(ctx, input, std::move(dense_defaults),
+ std::move(sparse_keys_), std::move(dense_keys_),
+ std::move(key_to_output_index), std::move(config),
+ num_parallel_calls, sparse_types_, dense_types_,
+ dense_shapes_, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ std::vector<Tensor> dense_defaults, std::vector<string> sparse_keys,
+ std::vector<string> dense_keys,
+ std::map<string, int> key_to_output_index,
+ example::FastParseExampleConfig config, int32 num_parallel_calls,
+ const DataTypeVector& sparse_types,
+ const DataTypeVector& dense_types,
+ const std::vector<PartialTensorShape>& dense_shapes,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ device_threadpool_(
+ ctx->device()->tensorflow_cpu_worker_threads()->workers),
+ dense_defaults_(std::move(dense_defaults)),
+ sparse_keys_(std::move(sparse_keys)),
+ dense_keys_(std::move(dense_keys)),
+ key_to_output_index_(std::move(key_to_output_index)),
+ config_(std::move(config)),
+ num_parallel_calls_(num_parallel_calls),
+ sparse_types_(sparse_types),
+ dense_types_(dense_types),
+ dense_shapes_(dense_shapes),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ auto map_fn = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())([this, input_element, result, done]() {
+ std::vector<string> slice_vec;
+          for (const Tensor& t : input_element) {
+ auto serialized_t = t.flat<string>();
+ gtl::ArraySlice<string> slice(serialized_t.data(),
+ serialized_t.size());
+ for (auto it = slice.begin(); it != slice.end(); it++)
+ slice_vec.push_back(*it);
+ }
+ example::Result example_result;
+ // TODO(b/111553342): Add stats collection logic here.
+ Status s = FastParseExample(config_, slice_vec, {},
+ device_threadpool_, &example_result);
+ if (s.ok()) {
+ (*result).resize(key_to_output_index_.size());
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ int output_index = key_to_output_index_.at(dense_keys_[d]);
+ CHECK(example_result.dense_values[d].dtype() ==
+ output_dtypes()[output_index])
+ << "Got wrong type for FastParseExample return value " << d
+ << " (expected "
+ << DataTypeString(output_dtypes()[output_index]) << ", got "
+ << DataTypeString(example_result.dense_values[d].dtype())
+ << ").";
+ CHECK(output_shapes()[output_index].IsCompatibleWith(
+ example_result.dense_values[d].shape()))
+ << "Got wrong shape for FastParseExample return value " << d
+ << " (expected "
+ << output_shapes()[output_index].DebugString() << ", got "
+ << example_result.dense_values[d].shape().DebugString()
+ << ").";
+ (*result)[output_index] = example_result.dense_values[d];
+ }
+ for (int d = 0; d < sparse_keys_.size(); ++d) {
+ Tensor serialized_sparse = Tensor(DT_VARIANT, TensorShape({3}));
+ auto serialized_sparse_t = serialized_sparse.vec<Variant>();
+ serialized_sparse_t(0) = example_result.sparse_indices[d];
+ serialized_sparse_t(1) = example_result.sparse_values[d];
+ serialized_sparse_t(2) = example_result.sparse_shapes[d];
+ int output_index = key_to_output_index_.at(sparse_keys_[d]);
+ CHECK(serialized_sparse.dtype() == output_dtypes()[output_index])
+ << "Got wrong type for FastParseExample return value " << d
+ << " (expected "
+ << DataTypeString(output_dtypes()[output_index]) << ", got "
+ << DataTypeString(serialized_sparse.dtype()) << ").";
+ CHECK(output_shapes()[output_index].IsCompatibleWith(
+ serialized_sparse.shape()))
+ << "Got wrong shape for FastParseExample return value " << d
+ << " (expected "
+ << output_shapes()[output_index].DebugString() << ", got "
+ << serialized_sparse.shape().DebugString() << ").";
+ (*result)[output_index] = serialized_sparse;
+ }
+ }
+ done(s);
+ });
+ };
+
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParseExample")}, input_,
+ std::move(map_fn), num_parallel_calls_);
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "ParseExampleDatasetOp::Dataset";
+ }
+
+ // TODO(b/111553342): Add/Check support for checkpointing.
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+
+      Node* num_parallel_calls_node;
+ std::vector<Node*> dense_defaults_nodes;
+ dense_defaults_nodes.reserve(dense_defaults_.size());
+
+ TF_RETURN_IF_ERROR(
+          b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+
+ for (const Tensor& dense_default : dense_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(dense_default, &node));
+ dense_defaults_nodes.emplace_back(node);
+ }
+
+ AttrValue sparse_keys_attr;
+ AttrValue dense_keys_attr;
+ AttrValue sparse_types_attr;
+ AttrValue dense_attr;
+ AttrValue dense_shapes_attr;
+
+ b->BuildAttrValue(sparse_keys_, &sparse_keys_attr);
+ b->BuildAttrValue(dense_keys_, &dense_keys_attr);
+ b->BuildAttrValue(sparse_types_, &sparse_types_attr);
+ b->BuildAttrValue(dense_types_, &dense_attr);
+ b->BuildAttrValue(dense_shapes_, &dense_shapes_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(this,
+ {
+ {0, input_graph_node},
+                                           {1, num_parallel_calls_node},
+ },
+ {{2, dense_defaults_nodes}},
+ {{"sparse_keys", sparse_keys_attr},
+ {"dense_keys", dense_keys_attr},
+ {"sparse_types", sparse_types_attr},
+ {"Tdense", dense_attr},
+ {"dense_shapes", dense_shapes_attr}},
+ output));
+ return Status::OK();
+ }
+
+ private:
+ const DatasetBase* const input_;
+ thread::ThreadPool* const device_threadpool_;
+ const std::vector<Tensor> dense_defaults_;
+ const std::vector<string> sparse_keys_;
+ const std::vector<string> dense_keys_;
+ const std::map<string, int> key_to_output_index_;
+ const example::FastParseExampleConfig config_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector sparse_types_;
+ const DataTypeVector dense_types_;
+ const std::vector<PartialTensorShape> dense_shapes_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ std::vector<string> sparse_keys_;
+ std::vector<string> dense_keys_;
+ DataTypeVector sparse_types_;
+ DataTypeVector dense_types_;
+ std::vector<PartialTensorShape> dense_shapes_;
+ std::vector<bool> variable_length_;
+ std::vector<std::size_t> elements_per_stride_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
+ ParseExampleDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 13733d48f0..07e735c7cb 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -166,6 +166,23 @@ REGISTER_OP("LatencyStatsDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("ParseExampleDataset")
+ .Input("input_dataset: variant")
+ .Input("num_parallel_calls: int64")
+ .Input("dense_defaults: Tdense")
+ .Output("handle: variant")
+ .Attr("sparse_keys: list(string) >= 0")
+ .Attr("dense_keys: list(string) >= 0")
+ .Attr("sparse_types: list({float,int64,string}) >= 0")
+ .Attr("Tdense: list({float,int64,string}) >= 0")
+ .Attr("dense_shapes: list(shape) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1") // Output components will be
+ // sorted by key (dense_keys and
+ // sparse_keys combined) here.
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("FeatureStatsDataset")
.Input("input_dataset: variant")
.Input("tag: string")
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index d8d9af545f..6041e2a0c5 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -629,76 +629,12 @@ def _parse_example_raw(serialized,
Returns:
A `dict` mapping keys to `Tensor`s and `SparseTensor`s.
- Raises:
- ValueError: If sparse and dense key sets intersect, or input lengths do not
- match up.
"""
with ops.name_scope(name, "ParseExample", [serialized, names]):
- names = [] if names is None else names
- dense_defaults = collections.OrderedDict(
- ) if dense_defaults is None else dense_defaults
- sparse_keys = [] if sparse_keys is None else sparse_keys
- sparse_types = [] if sparse_types is None else sparse_types
- dense_keys = [] if dense_keys is None else dense_keys
- dense_types = [] if dense_types is None else dense_types
- dense_shapes = (
- [[]] * len(dense_keys) if dense_shapes is None else dense_shapes)
-
- num_dense = len(dense_keys)
- num_sparse = len(sparse_keys)
-
- if len(dense_shapes) != num_dense:
- raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
- % (len(dense_shapes), num_dense))
- if len(dense_types) != num_dense:
- raise ValueError("len(dense_types) != len(num_dense): %d vs. %d"
- % (len(dense_types), num_dense))
- if len(sparse_types) != num_sparse:
- raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
- % (len(sparse_types), num_sparse))
- if num_dense + num_sparse == 0:
- raise ValueError("Must provide at least one sparse key or dense key")
- if not set(dense_keys).isdisjoint(set(sparse_keys)):
- raise ValueError(
- "Dense and sparse keys must not intersect; intersection: %s" %
- set(dense_keys).intersection(set(sparse_keys)))
-
- # Convert dense_shapes to TensorShape object.
- dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
-
- dense_defaults_vec = []
- for i, key in enumerate(dense_keys):
- default_value = dense_defaults.get(key)
- dense_shape = dense_shapes[i]
- if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
- dense_shape[0].value is None):
- # Variable stride dense shape, the default value should be a
- # scalar padding value
- if default_value is None:
- default_value = ops.convert_to_tensor(
- "" if dense_types[i] == dtypes.string else 0,
- dtype=dense_types[i])
- else:
- # Reshape to a scalar to ensure user gets an error if they
- # provide a tensor that's not intended to be a padding value
- # (0 or 2+ elements).
- key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
- default_value = ops.convert_to_tensor(
- default_value, dtype=dense_types[i], name=key_name)
- default_value = array_ops.reshape(default_value, [])
- else:
- if default_value is None:
- default_value = constant_op.constant([], dtype=dense_types[i])
- elif not isinstance(default_value, ops.Tensor):
- key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
- default_value = ops.convert_to_tensor(
- default_value, dtype=dense_types[i], name=key_name)
- default_value = array_ops.reshape(default_value, dense_shape)
-
- dense_defaults_vec.append(default_value)
-
- # Finally, convert dense_shapes to TensorShapeProto
- dense_shapes = [shape.as_proto() for shape in dense_shapes]
+ (names, dense_defaults_vec, sparse_keys, sparse_types,
+ dense_keys, dense_shapes, _) = _process_raw_parameters(
+ names, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
outputs = gen_parsing_ops.parse_example(
serialized=serialized,
@@ -719,6 +655,112 @@ def _parse_example_raw(serialized,
return dict(zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
+def _process_raw_parameters(names, dense_defaults, sparse_keys, sparse_types,
+ dense_keys, dense_types, dense_shapes):
+ """Process raw parameters to params used by `gen_parsing_ops`.
+
+ Args:
+ names: A vector (1-D Tensor) of strings (optional), the names of
+ the serialized protos.
+ dense_defaults: A dict mapping string keys to `Tensor`s.
+ The keys of the dict must match the dense_keys of the feature.
+ sparse_keys: A list of string keys in the examples' features.
+ The results for these keys will be returned as `SparseTensor` objects.
+ sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+ and `tf.string` (`BytesList`) are supported.
+ dense_keys: A list of string keys in the examples' features.
+ The results for these keys will be returned as `Tensor`s
+ dense_types: A list of DTypes of the same length as `dense_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+ and `tf.string` (`BytesList`) are supported.
+ dense_shapes: A list of tuples with the same length as `dense_keys`.
+ The shape of the data for each dense feature referenced by `dense_keys`.
+ Required for any input tensors identified by `dense_keys`. Must be
+ either fully defined, or may contain an unknown first dimension.
+ An unknown first dimension means the feature is treated as having
+ a variable number of blocks, and the output shape along this dimension
+ is considered unknown at graph build time. Padding is applied for
+ minibatch elements smaller than the maximum number of blocks for the
+ given feature along this dimension.
+
+ Returns:
+    Tuple of `names`, `dense_defaults_vec`, `sparse_keys`, `sparse_types`,
+    `dense_keys`, `dense_shapes_as_proto` (the dense shapes as
+    `TensorShapeProto`s) and `dense_shapes` (the same shapes as
+    `TensorShape`s).
+
+ Raises:
+ ValueError: If sparse and dense key sets intersect, or input lengths do not
+ match up.
+ """
+ names = [] if names is None else names
+ dense_defaults = collections.OrderedDict(
+ ) if dense_defaults is None else dense_defaults
+ sparse_keys = [] if sparse_keys is None else sparse_keys
+ sparse_types = [] if sparse_types is None else sparse_types
+ dense_keys = [] if dense_keys is None else dense_keys
+ dense_types = [] if dense_types is None else dense_types
+ dense_shapes = ([[]] * len(dense_keys)
+ if dense_shapes is None else dense_shapes)
+
+ num_dense = len(dense_keys)
+ num_sparse = len(sparse_keys)
+
+ if len(dense_shapes) != num_dense:
+ raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d" %
+ (len(dense_shapes), num_dense))
+ if len(dense_types) != num_dense:
+    raise ValueError("len(dense_types) != len(dense_keys): %d vs. %d" %
+ (len(dense_types), num_dense))
+ if len(sparse_types) != num_sparse:
+ raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d" %
+ (len(sparse_types), num_sparse))
+ if num_dense + num_sparse == 0:
+ raise ValueError("Must provide at least one sparse key or dense key")
+ if not set(dense_keys).isdisjoint(set(sparse_keys)):
+ raise ValueError(
+ "Dense and sparse keys must not intersect; intersection: %s" %
+ set(dense_keys).intersection(set(sparse_keys)))
+
+ # Convert dense_shapes to TensorShape object.
+ dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
+
+ dense_defaults_vec = []
+ for i, key in enumerate(dense_keys):
+ default_value = dense_defaults.get(key)
+ dense_shape = dense_shapes[i]
+ if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
+ dense_shape[0].value is None):
+ # Variable stride dense shape, the default value should be a
+ # scalar padding value
+ if default_value is None:
+ default_value = ops.convert_to_tensor(
+ "" if dense_types[i] == dtypes.string else 0, dtype=dense_types[i])
+ else:
+ # Reshape to a scalar to ensure user gets an error if they
+ # provide a tensor that's not intended to be a padding value
+ # (0 or 2+ elements).
+ key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, [])
+ else:
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, dense_shape)
+
+ dense_defaults_vec.append(default_value)
+
+ # Finally, convert dense_shapes to TensorShapeProto
+ dense_shapes_as_proto = [shape.as_proto() for shape in dense_shapes]
+
+ return (names, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
+ dense_shapes_as_proto, dense_shapes)
+
+
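For reference, a sketch of a call to this private helper, mirroring how `_ParseExampleDataset` invokes it above (the single-key spec is hypothetical, and the helper is private, so treat this as illustration only):

    (names, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
     dense_shapes_as_proto, dense_shapes) = _process_raw_parameters(
         names=None,
         dense_defaults={"a": [0.0]},       # hypothetical dense default
         sparse_keys=["s"],
         sparse_types=[dtypes.float32],
         dense_keys=["a"],
         dense_types=[dtypes.float32],
         dense_shapes=[(1,)])
    # dense_defaults_vec: [Tensor of shape (1,)]
    # dense_shapes_as_proto: [TensorShapeProto]; dense_shapes: [TensorShape]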
@tf_export("parse_single_example")
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.