aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/parsing_ops.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-02-07 16:45:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-07 17:09:35 -0800
commit780bc6b4d98665125c43685b20eeba6ad2804c0c (patch)
tree4acac8d596888cae078e520e65d836ff1a2c28d3 /tensorflow/python/ops/parsing_ops.py
parente6bfaf47374b44bb688023904eac98576baf4cd4 (diff)
Add support for variable major dimension in dense features in example parser c++ op.
Full python support (including more comprehensive documentation) coming soon. Change: 146852707
Diffstat (limited to 'tensorflow/python/ops/parsing_ops.py')
-rw-r--r--tensorflow/python/ops/parsing_ops.py48
1 files changed, 37 insertions, 11 deletions
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 079837bce3..77c7cd397a 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -476,8 +476,13 @@ def _parse_example_raw(serialized,
The keys of the dict must match the dense_keys of the feature.
dense_shapes: A list of tuples with the same length as `dense_keys`.
The shape of the data for each dense feature referenced by `dense_keys`.
- Required for any input tensors identified by `dense_keys` whose shapes are
- anything other than `[]` or `[1]`.
+ Required for any input tensors identified by `dense_keys`. Must be
+ either fully defined, or may contain an unknown first dimension.
+ An unknown first dimension means the feature is treated as having
+ a variable number of blocks, and the output shape along this dimension
+ is considered unknown at graph build time. Padding is applied for
+ minibatch elements smaller than the maximum number of blocks for the
+ given feature along this dimension.
name: A name for this operation (optional).
Returns:
@@ -516,21 +521,42 @@ def _parse_example_raw(serialized,
"Dense and sparse keys must not intersect; intersection: %s" %
set(dense_keys).intersection(set(sparse_keys)))
+ # Convert dense_shapes to TensorShape object.
+ dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
+
dense_defaults_vec = []
for i, key in enumerate(dense_keys):
default_value = dense_defaults.get(key)
- if default_value is None:
- default_value = constant_op.constant([], dtype=dense_types[i])
- elif not isinstance(default_value, ops.Tensor):
- key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
- default_value = ops.convert_to_tensor(
- default_value, dtype=dense_types[i], name=key_name)
- default_value = array_ops.reshape(default_value, dense_shapes[i])
+ dense_shape = dense_shapes[i]
+ if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
+ dense_shape[0].value is None):
+ # Variable stride dense shape, the default value should be a
+ # scalar padding value
+ if default_value is None:
+ default_value = ops.convert_to_tensor(
+ "" if dense_types[i] == dtypes.string else 0,
+ dtype=dense_types[i])
+ else:
+ # Reshape to a scalar to ensure user gets an error if they
+ # provide a tensor that's not intended to be a padding value
+ # (0 or 2+ elements).
+ key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, [])
+ else:
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, dense_shape)
dense_defaults_vec.append(default_value)
- dense_shapes = [tensor_shape.as_shape(shape).as_proto()
- for shape in dense_shapes]
+ # Finally, convert dense_shapes to TensorShapeProto
+ dense_shapes = [shape.as_proto() for shape in dense_shapes]
# pylint: disable=protected-access
outputs = gen_parsing_ops._parse_example(