From 4d35ecb853f5e9c7d41df3629a31dcbc7b4032ca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Mar 2017 10:08:42 -0800 Subject: Copy handle dtype and shape to the placeholders created inside Defun context for captured external tensors. Also fix a minor bug in ResourceVariable that dtype from deserialization contains integer enum value rather than dtypes.DType instance. Change: 151723784 --- tensorflow/python/framework/function.py | 4 ++++ tensorflow/python/framework/function_test.py | 17 +++++++++++++++++ .../python/kernel_tests/resource_variable_ops_test.py | 7 +++++++ tensorflow/python/ops/resource_variable_ops.py | 3 ++- 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 169d53c45d..3c663c3a9b 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -342,6 +342,10 @@ class _FuncGraph(ops.Graph): # Substitute with a placeholder. self.extra_inputs.append(x) ph = array_ops.placeholder(x.dtype, shape=x.get_shape()) + # pylint: disable=protected-access + ph._handle_shape = x._handle_shape + ph._handle_dtype = x._handle_dtype + # pylint: enable=protected-access inputs[i] = ph self._captured[x] = ph self.extra_args.append(ph) diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index d72ab5ae77..96bf7bde29 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -338,6 +338,23 @@ class FunctionTest(test.TestCase): variables.global_variables_initializer().run() self.assertAllEqual(z.eval(), 101.) + def testResourceVarAsImplicitInput(self): + g = ops.Graph() + with g.as_default(): + v = variable_scope.get_variable( + "var", (4, 4), dtypes.float32, use_resource=True) + + @function.Defun() + def Foo(): + return array_ops.identity(v) + + y = v.value() + z = Foo() + + with self.test_session(graph=g): + v.initializer.run() + self.assertAllEqual(y.eval(), z.eval()) + def testDefineErrors(self): with ops.Graph().as_default(): with self.assertRaisesRegexp(ValueError, "can not return None"): diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index e69a420dba..0b81dcb8af 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -189,5 +189,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)}) + def testDtypeAfterFromProto(self): + v = resource_variable_ops.ResourceVariable(2.0) + w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto()) + self.assertIsInstance(w.dtype, dtypes.DType) + self.assertEqual(v.dtype, w.dtype) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 43053fea11..e84fa21868 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -232,7 +233,7 @@ class ResourceVariable(object): else: self._save_slice_info = None self._caching_device = None - self._dtype = self._handle.op.get_attr("dtype") + self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) @property def dtype(self): -- cgit v1.2.3