diff options
author | 2018-08-18 17:30:54 +0800 | |
---|---|---|
committer | 2018-08-18 18:19:22 +0800 | |
commit | 46522659b41d7c2fe93fec54feb93a8e4b56505d (patch) | |
tree | 431bef19b5d22fa604ae1f693ba9b60413d5550f /tensorflow/contrib/framework | |
parent | 7a3ab5d6201b467c783e8d44e0b9180624e0dfbd (diff) |
BUG: fix for partial shape
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/python/framework/tensor_util.py | 33 |
1 files changed, 27 insertions, 6 deletions
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index 4e6eea8884..c8fc1789c7 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -129,10 +130,25 @@ def remove_squeezable_dimensions(predictions, labels, name=None): return predictions, labels -def _all_equal(tensor0, tensor1): - with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope: +def _shape_tensor_equal(expected_shape, actual_shape): + """Returns whether actual_shape is equal to expected_shape. + + Note that -1 in `expected_shape` is recognized as unknown dimension. + + Args: + expected_shape: Integer list defining the expected shape, or tensor of same. + actual_shape: Shape of the tensor to test. + Returns: + New tensor. + """ + with ops.name_scope('shape_tensor_equal', + values=[expected_shape, actual_shape]) as scope: return math_ops.reduce_all( - math_ops.equal(tensor0, tensor1, name='equal'), name=scope) + math_ops.logical_or( + math_ops.equal(expected_shape, -1), + math_ops.equal(expected_shape, actual_shape, 'equal'), + name='exclude_partial_shape'), + name=scope) def _is_rank(expected_rank, actual_tensor): @@ -153,6 +169,8 @@ def _is_rank(expected_rank, actual_tensor): def _is_shape(expected_shape, actual_tensor, actual_shape=None): """Returns whether actual_tensor's shape is expected_shape. + Note that -1 in `expected_shape` is recognized as unknown dimension. + Args: expected_shape: Integer list defining the expected shape, or tensor of same. actual_tensor: Tensor to test. @@ -164,15 +182,15 @@ def _is_shape(expected_shape, actual_tensor, actual_shape=None): is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor) if actual_shape is None: actual_shape = array_ops.shape(actual_tensor, name='actual') - shape_equal = _all_equal( - ops.convert_to_tensor(expected_shape, name='expected'), - actual_shape) + shape_equal = _shape_tensor_equal(expected_shape, actual_shape) return math_ops.logical_and(is_rank, shape_equal, name=scope) def _assert_shape_op(expected_shape, actual_tensor): """Asserts actual_tensor's shape is expected_shape. + Note that unknown dimension in `expected_shape` will be ignored. + Args: expected_shape: List of integers defining the expected shape, or tensor of same. @@ -182,6 +200,9 @@ def _assert_shape_op(expected_shape, actual_tensor): """ with ops.name_scope('assert_shape', values=[actual_tensor]) as scope: actual_shape = array_ops.shape(actual_tensor, name='actual') + if (isinstance(expected_shape, tensor_shape.TensorShape) + and not expected_shape.is_fully_defined()): + expected_shape = [d if d else -1 for d in expected_shape.as_list()] is_shape = _is_shape(expected_shape, actual_tensor, actual_shape) return control_flow_ops.Assert( is_shape, [ |