aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-30 10:08:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-30 11:27:44 -0700
commit4d35ecb853f5e9c7d41df3629a31dcbc7b4032ca (patch)
treec6b90a8eb0625407b20768c5363e7d8e5ef45721
parentebfb29cc93b9ed6a651db36a84accf50574aac6a (diff)
Copy handle dtype and shape to the placeholders created inside Defun context
for captured external tensors. Also fix a minor bug in ResourceVariable that dtype from deserialization contains integer enum value rather than dtypes.DType instance. Change: 151723784
-rw-r--r--tensorflow/python/framework/function.py4
-rw-r--r--tensorflow/python/framework/function_test.py17
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py7
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py3
4 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 169d53c45d..3c663c3a9b 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -342,6 +342,10 @@ class _FuncGraph(ops.Graph):
# Substitute with a placeholder.
self.extra_inputs.append(x)
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
+ # pylint: disable=protected-access
+ ph._handle_shape = x._handle_shape
+ ph._handle_dtype = x._handle_dtype
+ # pylint: enable=protected-access
inputs[i] = ph
self._captured[x] = ph
self.extra_args.append(ph)
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index d72ab5ae77..96bf7bde29 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -338,6 +338,23 @@ class FunctionTest(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllEqual(z.eval(), 101.)
+ def testResourceVarAsImplicitInput(self):
+ g = ops.Graph()
+ with g.as_default():
+ v = variable_scope.get_variable(
+ "var", (4, 4), dtypes.float32, use_resource=True)
+
+ @function.Defun()
+ def Foo():
+ return array_ops.identity(v)
+
+ y = v.value()
+ z = Foo()
+
+ with self.test_session(graph=g):
+ v.initializer.run()
+ self.assertAllEqual(y.eval(), z.eval())
+
def testDefineErrors(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "can not return None"):
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index e69a420dba..0b81dcb8af 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -189,5 +189,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
feed_dict={placeholder: np.zeros(shape=[2, 2],
dtype=np.float32)})
+ def testDtypeAfterFromProto(self):
+ v = resource_variable_ops.ResourceVariable(2.0)
+ w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
+ self.assertIsInstance(w.dtype, dtypes.DType)
+ self.assertEqual(v.dtype, w.dtype)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 43053fea11..e84fa21868 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -232,7 +233,7 @@ class ResourceVariable(object):
else:
self._save_slice_info = None
self._caching_device = None
- self._dtype = self._handle.op.get_attr("dtype")
+ self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
@property
def dtype(self):