diff options
author | 2018-07-10 10:32:55 -0700 | |
---|---|---|
committer | 2018-07-10 10:37:18 -0700 | |
commit | eff5053cca2aebb5e296d20884ee18e8bfc49461 (patch) | |
tree | 0c4873d421dc66290e7a274a6bf5e38aab7bdd86 | |
parent | d2fa11acfd6b8a2a2663dc70cd899433bdde23e1 (diff) |
Adding shape as alias for get_shape() for PartitionedVariable.
PiperOrigin-RevId: 203971287
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 4 |
2 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 62d596da91..2b9c62ad6f 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -642,6 +642,8 @@ class PartitionedVariableTest(test.TestCase): iterated_partitions = list(partitioned_variable) self.assertEqual(2, num_partitions) self.assertEqual([v0, v1], iterated_partitions) + self.assertEqual([2], partitioned_variable.get_shape()) + self.assertEqual([2], partitioned_variable.shape) self.assertEqual([2], concatenated.get_shape()) self.assertEqual([2], concatenated.shape) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 9a09cdaa52..d3b8da6d2a 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -1404,6 +1404,10 @@ class PartitionedVariable(object): def dtype(self): return self._dtype + @property + def shape(self): + return self.get_shape() + def get_shape(self): return self._shape |