aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py2
-rw-r--r--tensorflow/python/ops/variables.py4
2 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 62d596da91..2b9c62ad6f 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -642,6 +642,8 @@ class PartitionedVariableTest(test.TestCase):
iterated_partitions = list(partitioned_variable)
self.assertEqual(2, num_partitions)
self.assertEqual([v0, v1], iterated_partitions)
+ self.assertEqual([2], partitioned_variable.get_shape())
+ self.assertEqual([2], partitioned_variable.shape)
self.assertEqual([2], concatenated.get_shape())
self.assertEqual([2], concatenated.shape)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 9a09cdaa52..d3b8da6d2a 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1404,6 +1404,10 @@ class PartitionedVariable(object):
def dtype(self):
return self._dtype
+ @property
+ def shape(self):
+ return self.get_shape()
+
def get_shape(self):
return self._shape