diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 16:23:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 16:23:50 -0700 |
commit | ab4ae7e0cf029896a8f679a998f23763a8c2103d (patch) | |
tree | 12c1de52844c1f7b3d1c13912fdcdfdf91874a0f /tensorflow/contrib/framework | |
parent | 02163c55ae4e62495951e24c31e0a6ef96ab4e92 (diff) | |
parent | a5559a9d28bab6abfd65a9fad116ef9c6e13f8c2 (diff) |
Merge pull request #21702 from facaiy:ENH/assert_partial_shape
PiperOrigin-RevId: 210626817
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/python/framework/tensor_util.py | 33 | ||||
-rw-r--r-- | tensorflow/contrib/framework/python/framework/tensor_util_test.py | 12 |
2 files changed, 38 insertions, 7 deletions
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index 4e6eea8884..bdf8aeb2b8 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -129,10 +130,25 @@ def remove_squeezable_dimensions(predictions, labels, name=None): return predictions, labels -def _all_equal(tensor0, tensor1): - with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope: +def _shape_tensor_compatible(expected_shape, actual_shape): + """Returns whether actual_shape is compatible with expected_shape. + + Note that -1 in `expected_shape` is recognized as unknown dimension. + + Args: + expected_shape: Integer list defining the expected shape, or tensor of same. + actual_shape: Shape of the tensor to test. + Returns: + New tensor. + """ + with ops.name_scope('shape_tensor_equal', + values=[expected_shape, actual_shape]) as scope: return math_ops.reduce_all( - math_ops.equal(tensor0, tensor1, name='equal'), name=scope) + math_ops.logical_or( + math_ops.equal(expected_shape, -1), + math_ops.equal(expected_shape, actual_shape, 'equal'), + name='exclude_partial_shape'), + name=scope) def _is_rank(expected_rank, actual_tensor): @@ -153,6 +169,8 @@ def _is_rank(expected_rank, actual_tensor): def _is_shape(expected_shape, actual_tensor, actual_shape=None): """Returns whether actual_tensor's shape is expected_shape. + Note that -1 in `expected_shape` is recognized as unknown dimension. + Args: expected_shape: Integer list defining the expected shape, or tensor of same. actual_tensor: Tensor to test. @@ -164,15 +182,15 @@ def _is_shape(expected_shape, actual_tensor, actual_shape=None): is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor) if actual_shape is None: actual_shape = array_ops.shape(actual_tensor, name='actual') - shape_equal = _all_equal( - ops.convert_to_tensor(expected_shape, name='expected'), - actual_shape) + shape_equal = _shape_tensor_compatible(expected_shape, actual_shape) return math_ops.logical_and(is_rank, shape_equal, name=scope) def _assert_shape_op(expected_shape, actual_tensor): """Asserts actual_tensor's shape is expected_shape. + Note that unknown dimension in `expected_shape` will be ignored. + Args: expected_shape: List of integers defining the expected shape, or tensor of same. @@ -182,6 +200,9 @@ def _assert_shape_op(expected_shape, actual_tensor): """ with ops.name_scope('assert_shape', values=[actual_tensor]) as scope: actual_shape = array_ops.shape(actual_tensor, name='actual') + if (isinstance(expected_shape, tensor_shape.TensorShape) + and not expected_shape.is_fully_defined()): + expected_shape = [d if d else -1 for d in expected_shape.as_list()] is_shape = _is_shape(expected_shape, actual_tensor, actual_shape) return control_flow_ops.Assert( is_shape, [ diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 9db2670304..2479fe5b8d 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -29,7 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables as variables_lib @@ -185,6 +185,16 @@ class WithShapeTest(test.TestCase): shape, unexpected_shapes) + def test_with_shape_2x2_with_partial_expected_shape(self): + with self.test_session(): + value = [[42, 43], [44, 45]] + actual_shape = [2, 2] + tensor = constant_op.constant(value, shape=actual_shape) + partial_expected_shape = tensor_shape.TensorShape([None, 2]) + # Won't raise any exception here: + tensor_with_shape = tensor_util.with_shape(partial_expected_shape, tensor) + np.testing.assert_array_equal(value, tensor_with_shape.eval()) + def test_with_shape_none(self): with self.test_session(): tensor_no_shape = array_ops.placeholder(dtypes.float32) |