diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-19 18:27:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 18:30:56 -0700 |
commit | 9f05ca4ec89d9b03f740f881ae50d97d76a1b849 (patch) | |
tree | a585b3bcc287251896f211f2b2ed504f82352188 /tensorflow/python/eager | |
parent | 415455b0ef2d65504ab8c9084a6daa2899521212 (diff) |
Copy Tensor._handle_data from external_capture to placeholder for Variant tensors in Graph mode defun.
This allows inferring the shape of values popped from TensorLists inside defuns.
Remove "Resource" from {Set|Get}ResourceHandleShapeAndType since the same functions are re-usable for variants.
Eager mode fix coming in a future changelist.
PiperOrigin-RevId: 213735462
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 43 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 47 |
3 files changed, 75 insertions, 16 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index f80256fc2a..a2686c68a9 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -147,6 +147,7 @@ cuda_py_test( "//tensorflow/python:clip_ops", "//tensorflow/python:init_ops", "//tensorflow/python:layers", + "//tensorflow/python:list_ops", "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", ], diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index a68c6ab3b4..bcb1881264 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -73,16 +73,36 @@ def _create_substitute_placeholder(value, name=None, dtype=None): with ops.control_dependencies(None): placeholder = graph_placeholder( dtype=dtype or value.dtype, shape=value.shape, name=name) - if placeholder.dtype == dtypes_module.resource: - if isinstance(value, ops.EagerTensor): - handle_data = value._handle_data # pylint: disable=protected-access + _copy_handle_data(value, placeholder) + return placeholder + + +def _copy_handle_data(source_t, target_t): + """Copies HandleData for variant and resource type tensors if available. + + The CppShapeInferenceResult::HandleData proto contains information about the + shapes and types of the element tensors of resource/variant type tensors. + We need to copy this across function boundaries, i.e., when capturing a + placeholder or when returning a function tensor as output. If we don't do this + the element tensors will have unknown shapes, e.g., if a TensorList variant + tensor is captured as a placeholder, elements popped from that list would have + unknown shape. + + Args: + source_t: The tensor to copy HandleData from. + target_t: The tensor to copy HandleData to. + """ + if (target_t.dtype == dtypes_module.resource or + target_t.dtype == dtypes_module.variant): + if isinstance(source_t, ops.EagerTensor): + handle_data = source_t._handle_data # pylint: disable=protected-access else: - handle_data = resource_variable_ops.get_resource_handle_data(value) + handle_data = resource_variable_ops.get_resource_handle_data(source_t) if handle_data is not None and handle_data.is_set: # pylint: disable=protected-access - pywrap_tensorflow.SetResourceHandleShapeAndType( - placeholder.graph._c_graph, placeholder._as_tf_output(), - handle_data.SerializeToString()) + pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph, + target_t._as_tf_output(), + handle_data.SerializeToString()) # pylint: enable=protected-access # Ensure that shapes and dtypes are propagated. shapes, types = zip(*[(pair.shape, pair.dtype) @@ -91,12 +111,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None): shapes = [[d.size for d in s.dim] if not s.unknown_rank else None for s in shapes] pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - placeholder._op._graph._c_graph, # pylint: disable=protected-access - placeholder._as_tf_output(), # pylint: disable=protected-access + target_t._op._graph._c_graph, # pylint: disable=protected-access + target_t._as_tf_output(), # pylint: disable=protected-access shapes, ranks, types) - return placeholder - def _get_device_functions(ctx, graph): """Returns a tuple of device functions representing the device stack.""" @@ -435,6 +453,7 @@ class _EagerDefinedFunction(object): self._num_outputs = len(self.signature.output_arg) self._output_types = [o.type for o in self.signature.output_arg] self._output_shapes = [o.shape for o in outputs] + self._func_graph_outputs = outputs self.grad_func_name = None self.python_grad_func = None self._c_func = c_api_util.ScopedTFFunction(fn) @@ -511,6 +530,8 @@ class _EagerDefinedFunction(object): else: for i, shape in enumerate(self._output_shapes): outputs[i].set_shape(shape) + for i, func_graph_output in enumerate(self._func_graph_outputs): + _copy_handle_data(func_graph_output, outputs[i]) return outputs diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 4a1bde3f5e..e4513cc87c 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -48,6 +48,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops @@ -438,10 +439,17 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) + # We do not return v directly since the tensor conversion function of + # ResourceVariable returns the read value and not the resource itself. + return v._handle compiled = function.defun(f) - compiled() + var_handle = compiled() + self.assertEqual(var_handle.dtype, dtypes.resource) + self.assertEqual(var_handle.shape, tensor_shape.scalar()) + var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) + self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) def testVariableInLoopInFunction(self): @@ -465,10 +473,17 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) + # We do not return v directly since the tensor conversion function of + # ResourceVariable returns the read value and not the resource itself. + return v._handle compiled = function.defun(f) - compiled() + var_handle = compiled() + self.assertEqual(var_handle.dtype, dtypes.resource) + self.assertEqual(var_handle.shape, tensor_shape.scalar()) + var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) + self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): with context.graph_mode(): @@ -477,12 +492,34 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) # Check that shape inference works while creating the defun compiled = function.defun(f) compiled() + def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): + with context.graph_mode(): + tensor_list = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) + tensor_list = list_ops.tensor_list_push_back(tensor_list, + constant_op.constant(1.0)) + tensor_list = list_ops.tensor_list_push_back(tensor_list, + constant_op.constant(2.0)) + + def f(): + tl, value = list_ops.tensor_list_pop_back( + tensor_list, element_dtype=dtypes.float32) + self.assertEqual(value.shape, tensor_shape.scalar()) + return tl + + compiled = function.defun(f) + output_tensor_list = compiled() + _, value = list_ops.tensor_list_pop_back( + output_tensor_list, element_dtype=dtypes.float32) + self.assertEqual(value.shape, tensor_shape.scalar()) + @test_util.run_in_graph_and_eager_modes def testDefunForcesResourceVariables(self): |