aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 16:23:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 16:23:50 -0700
commitab4ae7e0cf029896a8f679a998f23763a8c2103d (patch)
tree12c1de52844c1f7b3d1c13912fdcdfdf91874a0f /tensorflow/contrib/framework
parent02163c55ae4e62495951e24c31e0a6ef96ab4e92 (diff)
parenta5559a9d28bab6abfd65a9fad116ef9c6e13f8c2 (diff)
Merge pull request #21702 from facaiy:ENH/assert_partial_shape
PiperOrigin-RevId: 210626817
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py33
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py12
2 files changed, 38 insertions, 7 deletions
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 4e6eea8884..bdf8aeb2b8 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -129,10 +130,25 @@ def remove_squeezable_dimensions(predictions, labels, name=None):
return predictions, labels
-def _all_equal(tensor0, tensor1):
- with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
+def _shape_tensor_compatible(expected_shape, actual_shape):
+ """Returns whether actual_shape is compatible with expected_shape.
+
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
+ Args:
+ expected_shape: Integer list defining the expected shape, or tensor of same.
+ actual_shape: Shape of the tensor to test.
+ Returns:
+ New tensor.
+ """
+ with ops.name_scope('shape_tensor_equal',
+ values=[expected_shape, actual_shape]) as scope:
return math_ops.reduce_all(
- math_ops.equal(tensor0, tensor1, name='equal'), name=scope)
+ math_ops.logical_or(
+ math_ops.equal(expected_shape, -1),
+ math_ops.equal(expected_shape, actual_shape, 'equal'),
+ name='exclude_partial_shape'),
+ name=scope)
def _is_rank(expected_rank, actual_tensor):
@@ -153,6 +169,8 @@ def _is_rank(expected_rank, actual_tensor):
def _is_shape(expected_shape, actual_tensor, actual_shape=None):
"""Returns whether actual_tensor's shape is expected_shape.
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
Args:
expected_shape: Integer list defining the expected shape, or tensor of same.
actual_tensor: Tensor to test.
@@ -164,15 +182,15 @@ def _is_shape(expected_shape, actual_tensor, actual_shape=None):
is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor)
if actual_shape is None:
actual_shape = array_ops.shape(actual_tensor, name='actual')
- shape_equal = _all_equal(
- ops.convert_to_tensor(expected_shape, name='expected'),
- actual_shape)
+ shape_equal = _shape_tensor_compatible(expected_shape, actual_shape)
return math_ops.logical_and(is_rank, shape_equal, name=scope)
def _assert_shape_op(expected_shape, actual_tensor):
"""Asserts actual_tensor's shape is expected_shape.
+ Note that unknown dimension in `expected_shape` will be ignored.
+
Args:
expected_shape: List of integers defining the expected shape, or tensor of
same.
@@ -182,6 +200,9 @@ def _assert_shape_op(expected_shape, actual_tensor):
"""
with ops.name_scope('assert_shape', values=[actual_tensor]) as scope:
actual_shape = array_ops.shape(actual_tensor, name='actual')
+ if (isinstance(expected_shape, tensor_shape.TensorShape)
+ and not expected_shape.is_fully_defined()):
+ expected_shape = [d if d else -1 for d in expected_shape.as_list()]
is_shape = _is_shape(expected_shape, actual_tensor, actual_shape)
return control_flow_ops.Assert(
is_shape, [
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index 9db2670304..2479fe5b8d 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_lib
@@ -185,6 +185,16 @@ class WithShapeTest(test.TestCase):
shape,
unexpected_shapes)
+ def test_with_shape_2x2_with_partial_expected_shape(self):
+ with self.test_session():
+ value = [[42, 43], [44, 45]]
+ actual_shape = [2, 2]
+ tensor = constant_op.constant(value, shape=actual_shape)
+ partial_expected_shape = tensor_shape.TensorShape([None, 2])
+ # Won't raise any exception here:
+ tensor_with_shape = tensor_util.with_shape(partial_expected_shape, tensor)
+ np.testing.assert_array_equal(value, tensor_with_shape.eval())
+
def test_with_shape_none(self):
with self.test_session():
tensor_no_shape = array_ops.placeholder(dtypes.float32)