diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-04-06 21:00:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-06 21:03:10 -0700 |
commit | 273495dc2c957402f832cae31a438e550db2b7f0 (patch) | |
tree | 98691c91e0af5a5a7464ca0f2645b434160710fb /tensorflow/python/framework/tensor_util.py | |
parent | 7f97f1bf69765be51b9f79f5134eb44736d216eb (diff) |
Improvements to ResourceVariable + Variant code.
* Works in graph + eager modes
* Fixed shape inference
* Updated shape inference + refiner + constant eval code to support static shape tensor of `-1` meaning unknown shape.
* Gather and Scatter for Variants now properly supported.
* Variable copy-on-write for Variants now does a more shallow copy (as Variants are not expected to be updated "in-place" inside a variable; instead Variants will be updated via read-update-write inside a CriticalSection)
PiperOrigin-RevId: 191975898
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 64b0fa6c00..8cf24206ed 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -822,17 +822,32 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name all-or-nothing. Args: - tensor: The rank-1 Tensor to be evaluated. + tensor: The rank-0 or rank-1 Tensor to be evaluated. Returns: A `TensorShape` based on the constant value of the given `tensor`. + + Raises: + ValueError: If the shape is rank-0 and is not statically known to be -1. """ if isinstance(tensor, ops.EagerTensor): return tensor_shape.as_shape( [dim if dim != -1 else None for dim in tensor.numpy()]) + if tensor.get_shape().ndims == 0: + value = constant_value(tensor) + if value is None: + raise ValueError( + "Received a scalar with unknown value as shape; require a statically " + "known scalar with value '-1' to describe an unknown shape.") + if value != -1: + raise ValueError( + "Received a scalar value '%s' as shape; require a statically known " + "scalar with value '-1' to describe an unknown shape." % value) + return tensor_shape.unknown_shape() + shape = tensor.get_shape().with_rank(1) - if tensor.get_shape() == [0]: + if shape == [0]: return tensor_shape.scalar() elif tensor.op.type == "Shape": return tensor.op.inputs[0].get_shape() |