diff options
-rw-r--r-- | tensorflow/python/eager/graph_callable.py | 3 | ||||
-rw-r--r-- | tensorflow/python/eager/graph_callable_test.py | 10 |
2 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 64d1659993..e3aacbd140 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -28,6 +28,7 @@ from tensorflow.python.eager import tape from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -54,7 +55,7 @@ class _VariableFromResource(resource_variable_ops.ResourceVariable): def __init__(self, resource, dtype, name, shape): self._handle = resource - self._graph_shape = shape + self._graph_shape = tensor_shape.as_shape(shape) self._handle_device = resource.device self._handle_name = name self._cached_value = None diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py index 4ad8f1f36e..104e019391 100644 --- a/tensorflow/python/eager/graph_callable_test.py +++ b/tensorflow/python/eager/graph_callable_test.py @@ -22,6 +22,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -209,6 +210,15 @@ class GraphCallableTest(test.TestCase): ret = my_op(inputs) self.assertEqual(ret[1].numpy(), 11.) + def testVariableShapeIsTensorShape(self): + @graph_callable.graph_callable([]) + def my_function(): + v = variable_scope.get_variable( + "v", initializer=init_ops.zeros_initializer(), shape=()) + self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape) + + my_function() + if __name__ == "__main__": test.main() |