diff options
-rw-r--r-- | tensorflow/contrib/data/python/ops/batching.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index a9535d9b83..0d942f33e6 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -647,15 +647,17 @@ def assert_element_shape(expected_shapes): """Assert the shape of this `Dataset`. ```python - shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)] + shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])] result = dataset.apply(tf.contrib.data.assert_element_shape(shapes)) - print(result.output_shapes) # ==> "((16, 256), <unknown>)" + print(result.output_shapes) # ==> "((16, 256), (<unknown>, 2))" ``` If dataset shapes and expected_shape, are fully defined, assert they match. Otherwise, add assert op that will validate the shapes when tensors are evaluated, and set shapes on tensors, respectively. + Note that unknown dimension in `expected_shapes` will be ignored. + Args: expected_shapes: A nested structure of `tf.TensorShape` objects. |