aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-18 17:30:54 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-18 18:19:22 +0800
commit46522659b41d7c2fe93fec54feb93a8e4b56505d (patch)
tree431bef19b5d22fa604ae1f693ba9b60413d5550f /tensorflow/contrib/framework
parent7a3ab5d6201b467c783e8d44e0b9180624e0dfbd (diff)
BUG: fix for partial shape
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py33
1 files changed, 27 insertions, 6 deletions
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 4e6eea8884..c8fc1789c7 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -129,10 +130,25 @@ def remove_squeezable_dimensions(predictions, labels, name=None):
return predictions, labels
-def _all_equal(tensor0, tensor1):
- with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
+def _shape_tensor_equal(expected_shape, actual_shape):
+ """Returns whether actual_shape is equal to expected_shape.
+
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
+ Args:
+ expected_shape: Integer list defining the expected shape, or tensor of same.
+ actual_shape: Shape of the tensor to test.
+ Returns:
+ New tensor.
+ """
+ with ops.name_scope('shape_tensor_equal',
+ values=[expected_shape, actual_shape]) as scope:
return math_ops.reduce_all(
- math_ops.equal(tensor0, tensor1, name='equal'), name=scope)
+ math_ops.logical_or(
+ math_ops.equal(expected_shape, -1),
+ math_ops.equal(expected_shape, actual_shape, 'equal'),
+ name='exclude_partial_shape'),
+ name=scope)
def _is_rank(expected_rank, actual_tensor):
@@ -153,6 +169,8 @@ def _is_rank(expected_rank, actual_tensor):
def _is_shape(expected_shape, actual_tensor, actual_shape=None):
"""Returns whether actual_tensor's shape is expected_shape.
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
Args:
expected_shape: Integer list defining the expected shape, or tensor of same.
actual_tensor: Tensor to test.
@@ -164,15 +182,15 @@ def _is_shape(expected_shape, actual_tensor, actual_shape=None):
is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor)
if actual_shape is None:
actual_shape = array_ops.shape(actual_tensor, name='actual')
- shape_equal = _all_equal(
- ops.convert_to_tensor(expected_shape, name='expected'),
- actual_shape)
+ shape_equal = _shape_tensor_equal(expected_shape, actual_shape)
return math_ops.logical_and(is_rank, shape_equal, name=scope)
def _assert_shape_op(expected_shape, actual_tensor):
"""Asserts actual_tensor's shape is expected_shape.
+ Note that unknown dimension in `expected_shape` will be ignored.
+
Args:
expected_shape: List of integers defining the expected shape, or tensor of
same.
@@ -182,6 +200,9 @@ def _assert_shape_op(expected_shape, actual_tensor):
"""
with ops.name_scope('assert_shape', values=[actual_tensor]) as scope:
actual_shape = array_ops.shape(actual_tensor, name='actual')
+ if (isinstance(expected_shape, tensor_shape.TensorShape)
+ and not expected_shape.is_fully_defined()):
+ expected_shape = [d if d else -1 for d in expected_shape.as_list()]
is_shape = _is_shape(expected_shape, actual_tensor, actual_shape)
return control_flow_ops.Assert(
is_shape, [