aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-07-12 12:51:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-12 12:56:39 -0700
commit546322104425cc1cc70afeb7c0cfc1ec36ed0b41 (patch)
treeef89a93d4be694f5db6b829e3b74e796a01c643b /tensorflow
parent86f84b066706db97f7b3fd184249fdbd54abb05e (diff)
Fix bug in SetResourceHandleShapeAndType.
Prior to this change, captured resource variables in TF functions (or any captured resource tensors) would not have shape information. PiperOrigin-RevId: 204347306
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/c/python_api.cc2
-rw-r--r--tensorflow/python/framework/function_test.py16
2 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index e18fdf6c57..8486b585c8 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -155,7 +155,7 @@ void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
tensorflow::shape_inference::ShapeHandle shape;
status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
- if (status->status.ok()) return;
+ if (!status->status.ok()) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 15e41ba91f..1707f929b8 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -537,19 +537,25 @@ class FunctionTest(test.TestCase):
def testResourceVarAsImplicitInput(self):
g = ops.Graph()
with g.as_default(), ops.device("cpu:0"):
+ expected_type = dtypes.float32
+ expected_shape = tensor_shape.TensorShape((4, 4))
v = variable_scope.get_variable(
- "var", (4, 4), dtypes.float32, use_resource=True)
+ "var", expected_shape, expected_type, use_resource=True)
@function.Defun()
def Foo():
- return array_ops.identity(v)
+ captured = array_ops.identity(v)
+ self.assertEqual(expected_type, captured.dtype)
+ self.assertEqual(expected_shape, captured.shape)
+ return captured, array_ops.shape(captured)
- y = v.value()
- z = Foo()
+ expected_val = v.value()
+ actual_val, actual_shape = Foo()
with self.test_session(graph=g):
v.initializer.run()
- self.assertAllEqual(y.eval(), z.eval())
+ self.assertAllEqual(expected_val.eval(), actual_val.eval())
+ self.assertAllEqual(expected_shape, actual_shape.eval())
def testDefineErrors(self):
with ops.Graph().as_default():