author     Saurabh Saxena <srbs@google.com>              2018-09-19 18:27:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-19 18:30:56 -0700
commit     9f05ca4ec89d9b03f740f881ae50d97d76a1b849 (patch)
tree       a585b3bcc287251896f211f2b2ed504f82352188 /tensorflow/python/eager
parent     415455b0ef2d65504ab8c9084a6daa2899521212 (diff)
Copy Tensor._handle_data from external_capture to placeholder for Variant tensors in Graph mode defun.

This allows inferring the shape of values popped from TensorLists inside defuns.
Remove "Resource" from {Set|Get}ResourceHandleShapeAndType since the same functions are reusable for variants.
Eager mode fix coming in a future changelist.

PiperOrigin-RevId: 213735462
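For illustration, a minimal sketch of the behavior this change enables, mirroring the new test added below. It uses the internal list_ops and function.defun APIs as of this revision and assumes graph mode (the default outside of enabled eager execution); per the note above, the eager-mode counterpart is deferred:

# Sketch only: mirrors testDefunShapeInferenceWithCapturedTensorListInGraphMode.
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import list_ops

tensor_list = list_ops.empty_tensor_list(
    element_dtype=dtypes.float32,
    element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
tensor_list = list_ops.tensor_list_push_back(tensor_list,
                                             constant_op.constant(1.0))

def f():
  # `tensor_list` is captured into the function graph as a variant-dtype
  # placeholder. With this change the placeholder inherits the list's
  # HandleData, so `value` has a known (scalar) shape instead of <unknown>.
  tl, value = list_ops.tensor_list_pop_back(
      tensor_list, element_dtype=dtypes.float32)
  return tl

compiled = function.defun(f)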
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--  tensorflow/python/eager/BUILD             |  1
-rw-r--r--  tensorflow/python/eager/function.py       | 43
-rw-r--r--  tensorflow/python/eager/function_test.py  | 47
3 files changed, 75 insertions(+), 16 deletions(-)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index f80256fc2a..a2686c68a9 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -147,6 +147,7 @@ cuda_py_test(
"//tensorflow/python:clip_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
+ "//tensorflow/python:list_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
],
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index a68c6ab3b4..bcb1881264 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -73,16 +73,36 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
   with ops.control_dependencies(None):
     placeholder = graph_placeholder(
         dtype=dtype or value.dtype, shape=value.shape, name=name)
-  if placeholder.dtype == dtypes_module.resource:
-    if isinstance(value, ops.EagerTensor):
-      handle_data = value._handle_data  # pylint: disable=protected-access
+  _copy_handle_data(value, placeholder)
+  return placeholder
+
+
+def _copy_handle_data(source_t, target_t):
+  """Copies HandleData for variant and resource type tensors if available.
+
+  The CppShapeInferenceResult::HandleData proto contains information about the
+  shapes and types of the element tensors of resource/variant type tensors.
+  We need to copy this across function boundaries, i.e., when capturing a
+  placeholder or when returning a function tensor as output. If we don't do this
+  the element tensors will have unknown shapes, e.g., if a TensorList variant
+  tensor is captured as a placeholder, elements popped from that list would have
+  unknown shape.
+
+  Args:
+    source_t: The tensor to copy HandleData from.
+    target_t: The tensor to copy HandleData to.
+  """
+  if (target_t.dtype == dtypes_module.resource or
+      target_t.dtype == dtypes_module.variant):
+    if isinstance(source_t, ops.EagerTensor):
+      handle_data = source_t._handle_data  # pylint: disable=protected-access
     else:
-      handle_data = resource_variable_ops.get_resource_handle_data(value)
+      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
     if handle_data is not None and handle_data.is_set:
       # pylint: disable=protected-access
-      pywrap_tensorflow.SetResourceHandleShapeAndType(
-          placeholder.graph._c_graph, placeholder._as_tf_output(),
-          handle_data.SerializeToString())
+      pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+                                              target_t._as_tf_output(),
+                                              handle_data.SerializeToString())
       # pylint: enable=protected-access
       # Ensure that shapes and dtypes are propagated.
       shapes, types = zip(*[(pair.shape, pair.dtype)
@@ -91,12 +111,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
       shapes = [[d.size for d in s.dim]
                 if not s.unknown_rank else None for s in shapes]
       pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
-          placeholder._op._graph._c_graph,  # pylint: disable=protected-access
-          placeholder._as_tf_output(),  # pylint: disable=protected-access
+          target_t._op._graph._c_graph,  # pylint: disable=protected-access
+          target_t._as_tf_output(),  # pylint: disable=protected-access
           shapes, ranks, types)
-  return placeholder
-
 
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
@@ -435,6 +453,7 @@ class _EagerDefinedFunction(object):
     self._num_outputs = len(self.signature.output_arg)
     self._output_types = [o.type for o in self.signature.output_arg]
     self._output_shapes = [o.shape for o in outputs]
+    self._func_graph_outputs = outputs
     self.grad_func_name = None
     self.python_grad_func = None
     self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -511,6 +530,8 @@ class _EagerDefinedFunction(object):
     else:
       for i, shape in enumerate(self._output_shapes):
         outputs[i].set_shape(shape)
+    for i, func_graph_output in enumerate(self._func_graph_outputs):
+      _copy_handle_data(func_graph_output, outputs[i])
     return outputs
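The _EagerDefinedFunction changes cover the output direction: the function now remembers its graph's output tensors so that, at call time, _copy_handle_data can stamp their HandleData onto the call op's outputs. Continuing the sketch from the commit message above (same names), the effect mirrors the tail of the new test:

# Sketch only: the returned TensorList keeps its element shape information,
# so popping outside the defun also yields a known-shape (scalar) element.
output_tensor_list = compiled()
_, value = list_ops.tensor_list_pop_back(
    output_tensor_list, element_dtype=dtypes.float32)
print(value.shape)  # () rather than <unknown>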
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 4a1bde3f5e..e4513cc87c 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -48,6 +48,7 @@ from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -438,10 +439,17 @@ class FunctionTest(test.TestCase):
     def f():
       x = constant_op.constant([[1, 2], [3, 4]])
       out = math_ops.matmul(v, x)
-      self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+      self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+      # We do not return v directly since the tensor conversion function of
+      # ResourceVariable returns the read value and not the resource itself.
+      return v._handle
 
     compiled = function.defun(f)
-    compiled()
+    var_handle = compiled()
+    self.assertEqual(var_handle.dtype, dtypes.resource)
+    self.assertEqual(var_handle.shape, tensor_shape.scalar())
+    var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+    self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
 
   def testVariableInLoopInFunction(self):
@@ -465,10 +473,17 @@ class FunctionTest(test.TestCase):
       def f():
         x = constant_op.constant([[1, 2], [3, 4]])
         out = math_ops.matmul(v, x)
-        self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+        # We do not return v directly since the tensor conversion function of
+        # ResourceVariable returns the read value and not the resource itself.
+        return v._handle
 
       compiled = function.defun(f)
-      compiled()
+      var_handle = compiled()
+      self.assertEqual(var_handle.dtype, dtypes.resource)
+      self.assertEqual(var_handle.shape, tensor_shape.scalar())
+      var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+      self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
 
   def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
     with context.graph_mode():
@@ -477,12 +492,34 @@ class FunctionTest(test.TestCase):
       def f():
         x = constant_op.constant([[1, 2], [3, 4]])
         out = math_ops.matmul(v, x)
-        self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
 
       # Check that shape inference works while creating the defun
       compiled = function.defun(f)
       compiled()
 
+  def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
+    with context.graph_mode():
+      tensor_list = list_ops.empty_tensor_list(
+          element_dtype=dtypes.float32,
+          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+      tensor_list = list_ops.tensor_list_push_back(tensor_list,
+                                                   constant_op.constant(1.0))
+      tensor_list = list_ops.tensor_list_push_back(tensor_list,
+                                                   constant_op.constant(2.0))
+
+      def f():
+        tl, value = list_ops.tensor_list_pop_back(
+            tensor_list, element_dtype=dtypes.float32)
+        self.assertEqual(value.shape, tensor_shape.scalar())
+        return tl
+
+      compiled = function.defun(f)
+      output_tensor_list = compiled()
+      _, value = list_ops.tensor_list_pop_back(
+          output_tensor_list, element_dtype=dtypes.float32)
+      self.assertEqual(value.shape, tensor_shape.scalar())
+
   @test_util.run_in_graph_and_eager_modes
   def testDefunForcesResourceVariables(self):