aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/ops/batching.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/ops/batching.py')
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py27
1 files changed, 20 insertions, 7 deletions
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 9f059942a6..367c159dc5 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -272,9 +272,9 @@ def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
padding_value = 0
def batch_init_fn(_):
- return array_ops.fill(
- array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0),
- constant_op.constant(padding_value, dtype=dataset.output_types))
+ batch_shape = array_ops.concat(
+ [np.array([0], dtype=np.int32), padded_shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
def batch_reduce_fn(state, value):
return array_ops.concat([state, [value]], 0)
@@ -647,15 +647,17 @@ def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
```python
- shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
+ shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])]
result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
- print(result.output_shapes) # ==> "((16, 256), <unknown>)"
+ print(result.output_shapes) # ==> "((16, 256), (<unknown>, 2))"
```
If dataset shapes and expected_shape, are fully defined, assert they match.
Otherwise, add assert op that will validate the shapes when tensors are
evaluated, and set shapes on tensors, respectively.
+ Note that unknown dimension in `expected_shapes` will be ignored.
+
Args:
expected_shapes: A nested structure of `tf.TensorShape` objects.
@@ -664,20 +666,31 @@ def assert_element_shape(expected_shapes):
`tf.data.Dataset.apply`
"""
+ def _merge_output_shapes(original_shapes, expected_shapes):
+ flat_original_shapes = nest.flatten(original_shapes)
+ flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes)
+ flat_merged_output_shapes = [
+ original_shape.merge_with(new_shape)
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes)]
+ return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes)
+
def _check_shape(*elements):
flatten_tensors = nest.flatten(elements)
flatten_shapes = nest.flatten(expected_shapes)
checked_tensors = [
- with_shape(shape, tensor)
+ with_shape(shape, tensor) if shape else tensor # Ignore unknown shape
for shape, tensor in zip(flatten_shapes, flatten_tensors)
]
return nest.pack_sequence_as(elements, checked_tensors)
def _apply_fn(dataset):
+ output_shapes = _merge_output_shapes(dataset.output_shapes,
+ expected_shapes)
return _RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
- output_shapes=expected_shapes,
+ output_shapes=output_shapes,
output_classes=dataset.output_classes)
return _apply_fn