diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-07 16:45:06 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-07 17:09:35 -0800 |
commit | 780bc6b4d98665125c43685b20eeba6ad2804c0c (patch) | |
tree | 4acac8d596888cae078e520e65d836ff1a2c28d3 /tensorflow/python/ops/parsing_ops.py | |
parent | e6bfaf47374b44bb688023904eac98576baf4cd4 (diff) |
Add support for variable major dimension in dense features in example parser C++ op.
Full Python support (including more comprehensive documentation) coming soon.
Change: 146852707
Diffstat (limited to 'tensorflow/python/ops/parsing_ops.py')
-rw-r--r-- | tensorflow/python/ops/parsing_ops.py | 48 |
1 file changed, 37 insertions, 11 deletions
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index 079837bce3..77c7cd397a 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -476,8 +476,13 @@ def _parse_example_raw(serialized, The keys of the dict must match the dense_keys of the feature. dense_shapes: A list of tuples with the same length as `dense_keys`. The shape of the data for each dense feature referenced by `dense_keys`. - Required for any input tensors identified by `dense_keys` whose shapes are - anything other than `[]` or `[1]`. + Required for any input tensors identified by `dense_keys`. Must be + either fully defined, or may contain an unknown first dimension. + An unknown first dimension means the feature is treated as having + a variable number of blocks, and the output shape along this dimension + is considered unknown at graph build time. Padding is applied for + minibatch elements smaller than the maximum number of blocks for the + given feature along this dimension. name: A name for this operation (optional). Returns: @@ -516,21 +521,42 @@ def _parse_example_raw(serialized, "Dense and sparse keys must not intersect; intersection: %s" % set(dense_keys).intersection(set(sparse_keys))) + # Convert dense_shapes to TensorShape object. 
+ dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes] + dense_defaults_vec = [] for i, key in enumerate(dense_keys): default_value = dense_defaults.get(key) - if default_value is None: - default_value = constant_op.constant([], dtype=dense_types[i]) - elif not isinstance(default_value, ops.Tensor): - key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) - default_value = ops.convert_to_tensor( - default_value, dtype=dense_types[i], name=key_name) - default_value = array_ops.reshape(default_value, dense_shapes[i]) + dense_shape = dense_shapes[i] + if (dense_shape.ndims is not None and dense_shape.ndims > 0 and + dense_shape[0].value is None): + # Variable stride dense shape, the default value should be a + # scalar padding value + if default_value is None: + default_value = ops.convert_to_tensor( + "" if dense_types[i] == dtypes.string else 0, + dtype=dense_types[i]) + else: + # Reshape to a scalar to ensure user gets an error if they + # provide a tensor that's not intended to be a padding value + # (0 or 2+ elements). + key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) + default_value = ops.convert_to_tensor( + default_value, dtype=dense_types[i], name=key_name) + default_value = array_ops.reshape(default_value, []) + else: + if default_value is None: + default_value = constant_op.constant([], dtype=dense_types[i]) + elif not isinstance(default_value, ops.Tensor): + key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) + default_value = ops.convert_to_tensor( + default_value, dtype=dense_types[i], name=key_name) + default_value = array_ops.reshape(default_value, dense_shape) dense_defaults_vec.append(default_value) - dense_shapes = [tensor_shape.as_shape(shape).as_proto() - for shape in dense_shapes] + # Finally, convert dense_shapes to TensorShapeProto + dense_shapes = [shape.as_proto() for shape in dense_shapes] # pylint: disable=protected-access outputs = gen_parsing_ops._parse_example( |