aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/eager/graph_callable.py3
-rw-r--r--tensorflow/python/eager/graph_callable_test.py10
2 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 64d1659993..e3aacbd140 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -54,7 +55,7 @@ class _VariableFromResource(resource_variable_ops.ResourceVariable):
def __init__(self, resource, dtype, name, shape):
self._handle = resource
- self._graph_shape = shape
+ self._graph_shape = tensor_shape.as_shape(shape)
self._handle_device = resource.device
self._handle_name = name
self._cached_value = None
diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py
index 4ad8f1f36e..104e019391 100644
--- a/tensorflow/python/eager/graph_callable_test.py
+++ b/tensorflow/python/eager/graph_callable_test.py
@@ -22,6 +22,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
@@ -209,6 +210,15 @@ class GraphCallableTest(test.TestCase):
ret = my_op(inputs)
self.assertEqual(ret[1].numpy(), 11.)
+ def testVariableShapeIsTensorShape(self):
+ @graph_callable.graph_callable([])
+ def my_function():
+ v = variable_scope.get_variable(
+ "v", initializer=init_ops.zeros_initializer(), shape=())
+ self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape)
+
+ my_function()
+
if __name__ == "__main__":
test.main()