diff options
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 4 |
2 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 62d596da91..2b9c62ad6f 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -642,6 +642,8 @@ class PartitionedVariableTest(test.TestCase): iterated_partitions = list(partitioned_variable) self.assertEqual(2, num_partitions) self.assertEqual([v0, v1], iterated_partitions) + self.assertEqual([2], partitioned_variable.get_shape()) + self.assertEqual([2], partitioned_variable.shape) self.assertEqual([2], concatenated.get_shape()) self.assertEqual([2], concatenated.shape) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 9a09cdaa52..d3b8da6d2a 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -1404,6 +1404,10 @@ class PartitionedVariable(object): def dtype(self): return self._dtype + @property + def shape(self): + return self.get_shape() + def get_shape(self): return self._shape |