aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/tensor_util.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-04-06 21:00:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 21:03:10 -0700
commit273495dc2c957402f832cae31a438e550db2b7f0 (patch)
tree98691c91e0af5a5a7464ca0f2645b434160710fb /tensorflow/python/framework/tensor_util.py
parent7f97f1bf69765be51b9f79f5134eb44736d216eb (diff)
Improvements to ResourceVariable + Variant code.
* Works in graph + eager modes * Fixed shape inference * Updated shape inference + refiner + constant eval code to support static shape tensor of `-1` meaning unknown shape. * Gather and Scatter for Variants now properly supported. * Variable copy-on-write for Variants now does a more shallow copy (as Variants are not expected to be updated "in-place" inside a variable; instead Variants will be updated via read-update-write inside a CriticalSection) PiperOrigin-RevId: 191975898
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r--tensorflow/python/framework/tensor_util.py19
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 64b0fa6c00..8cf24206ed 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -822,17 +822,32 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
all-or-nothing.
Args:
- tensor: The rank-1 Tensor to be evaluated.
+ tensor: The rank-0 or rank-1 Tensor to be evaluated.
Returns:
A `TensorShape` based on the constant value of the given `tensor`.
+
+ Raises:
+ ValueError: If the shape is rank-0 and is not statically known to be -1.
"""
if isinstance(tensor, ops.EagerTensor):
return tensor_shape.as_shape(
[dim if dim != -1 else None for dim in tensor.numpy()])
+ if tensor.get_shape().ndims == 0:
+ value = constant_value(tensor)
+ if value is None:
+ raise ValueError(
+ "Received a scalar with unknown value as shape; require a statically "
+ "known scalar with value '-1' to describe an unknown shape.")
+ if value != -1:
+ raise ValueError(
+ "Received a scalar value '%s' as shape; require a statically known "
+ "scalar with value '-1' to describe an unknown shape." % value)
+ return tensor_shape.unknown_shape()
+
shape = tensor.get_shape().with_rank(1)
- if tensor.get_shape() == [0]:
+ if shape == [0]:
return tensor_shape.scalar()
elif tensor.op.type == "Shape":
return tensor.op.inputs[0].get_shape()