aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 10:32:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 10:37:18 -0700
commiteff5053cca2aebb5e296d20884ee18e8bfc49461 (patch)
tree0c4873d421dc66290e7a274a6bf5e38aab7bdd86
parentd2fa11acfd6b8a2a2663dc70cd899433bdde23e1 (diff)
Adding shape as alias for get_shape() for PartitionedVariable.
PiperOrigin-RevId: 203971287
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py2
-rw-r--r--tensorflow/python/ops/variables.py4
2 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 62d596da91..2b9c62ad6f 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -642,6 +642,8 @@ class PartitionedVariableTest(test.TestCase):
iterated_partitions = list(partitioned_variable)
self.assertEqual(2, num_partitions)
self.assertEqual([v0, v1], iterated_partitions)
+ self.assertEqual([2], partitioned_variable.get_shape())
+ self.assertEqual([2], partitioned_variable.shape)
self.assertEqual([2], concatenated.get_shape())
self.assertEqual([2], concatenated.shape)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 9a09cdaa52..d3b8da6d2a 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1404,6 +1404,10 @@ class PartitionedVariable(object):
def dtype(self):
return self._dtype
+ @property
+ def shape(self):
+ return self.get_shape()
+
def get_shape(self):
return self._shape