diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-07-12 12:51:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 12:56:39 -0700 |
commit | 546322104425cc1cc70afeb7c0cfc1ec36ed0b41 (patch) | |
tree | ef89a93d4be694f5db6b829e3b74e796a01c643b /tensorflow | |
parent | 86f84b066706db97f7b3fd184249fdbd54abb05e (diff) |
Fix bug in SetResourceHandleShapeAndType.
Prior to this change, captured resource variables in TF functions (or
any captured resource tensors) would not have shape information.
PiperOrigin-RevId: 204347306
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/c/python_api.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 16 |
2 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index e18fdf6c57..8486b585c8 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -155,7 +155,7 @@ void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, tensorflow::shape_inference::ShapeHandle shape; status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); - if (status->status.ok()) return; + if (!status->status.ok()) return; shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 15e41ba91f..1707f929b8 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -537,19 +537,25 @@ class FunctionTest(test.TestCase): def testResourceVarAsImplicitInput(self): g = ops.Graph() with g.as_default(), ops.device("cpu:0"): + expected_type = dtypes.float32 + expected_shape = tensor_shape.TensorShape((4, 4)) v = variable_scope.get_variable( - "var", (4, 4), dtypes.float32, use_resource=True) + "var", expected_shape, expected_type, use_resource=True) @function.Defun() def Foo(): - return array_ops.identity(v) + captured = array_ops.identity(v) + self.assertEqual(expected_type, captured.dtype) + self.assertEqual(expected_shape, captured.shape) + return captured, array_ops.shape(captured) - y = v.value() - z = Foo() + expected_val = v.value() + actual_val, actual_shape = Foo() with self.test_session(graph=g): v.initializer.run() - self.assertAllEqual(y.eval(), z.eval()) + self.assertAllEqual(expected_val.eval(), actual_val.eval()) + self.assertAllEqual(expected_shape, actual_shape.eval()) def testDefineErrors(self): with ops.Graph().as_default(): |