diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-19 18:27:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 18:30:56 -0700 |
commit | 9f05ca4ec89d9b03f740f881ae50d97d76a1b849 (patch) | |
tree | a585b3bcc287251896f211f2b2ed504f82352188 | |
parent | 415455b0ef2d65504ab8c9084a6daa2899521212 (diff) |
Copy Tensor._handle_data from external_capture to placeholder for Variant tensors in Graph mode defun.
This allows inferring the shape of values popped from TensorLists inside defuns.
Remove "Resource" from {Set|Get}ResourceHandleShapeAndType since the same functions are re-usable for variants.
Eager mode fix coming in a future changelist.
PiperOrigin-RevId: 213735462
-rw-r--r-- | tensorflow/c/python_api.cc | 7 | ||||
-rw-r--r-- | tensorflow/c/python_api.h | 13 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session.i | 4 | ||||
-rw-r--r-- | tensorflow/python/eager/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 43 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 47 | ||||
-rw-r--r-- | tensorflow/python/framework/function.py | 9 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 4 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 2 |
9 files changed, 94 insertions, 36 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 8486b585c8..247236b760 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { Node* node = &output.oper->node; CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); @@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { return result; } -void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status) { +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status) { tensorflow::CppShapeInferenceResult::HandleData handle_data; if (!handle_data.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 4bcb5bde62..5cce84020b 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); void ExtendSession(TF_Session* session, TF_Status* status); // Returns the serialized CppShapeInferenceResult::HandleData proto for -// `output` if its a resource tensor, or otherwise returns the empty string. -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); +// `output` if its a resource or variant tensor, or otherwise returns the empty +// string. +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // Sets `output` based on `proto`, which should be a serialized -// CppShapeInferenceResult::HandleData proto. +// CppShapeInferenceResult::HandleData proto. `output` should be a resource +// or variant tensor. // NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string // because I couldn't get SWIG to work otherwise. -void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status); +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status); } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 39a2922ac0..ef7527d887 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -463,7 +463,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ } // Override default py3 behavior of attempting to encode into Unicode. -%typemap(out) std::string tensorflow::GetResourceHandleShapeAndType { +%typemap(out) std::string tensorflow::GetHandleShapeAndType { $result = PyBytes_FromStringAndSize($1.data(), $1.size()); } @@ -782,7 +782,7 @@ def TF_Reset(target, containers=None, config=None): %unignore TF_TryEvaluateConstant_wrapper; %noexception TF_TryEvaluateConstant_wrapper; %unignore ExtendSession; -%unignore ResourceHandleShapeAndType; +%unignore HandleShapeAndType; %include "tensorflow/python/client/tf_session_helper.h" diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index f80256fc2a..a2686c68a9 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -147,6 +147,7 @@ cuda_py_test( "//tensorflow/python:clip_ops", "//tensorflow/python:init_ops", "//tensorflow/python:layers", + "//tensorflow/python:list_ops", "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", ], diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index a68c6ab3b4..bcb1881264 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -73,16 +73,36 @@ def _create_substitute_placeholder(value, name=None, dtype=None): with ops.control_dependencies(None): placeholder = graph_placeholder( dtype=dtype or value.dtype, shape=value.shape, name=name) - if placeholder.dtype == dtypes_module.resource: - if isinstance(value, ops.EagerTensor): - handle_data = value._handle_data # pylint: disable=protected-access + _copy_handle_data(value, placeholder) + return placeholder + + +def _copy_handle_data(source_t, target_t): + """Copies HandleData for variant and resource type tensors if available. + + The CppShapeInferenceResult::HandleData proto contains information about the + shapes and types of the element tensors of resource/variant type tensors. + We need to copy this across function boundaries, i.e., when capturing a + placeholder or when returning a function tensor as output. If we don't do this + the element tensors will have unknown shapes, e.g., if a TensorList variant + tensor is captured as a placeholder, elements popped from that list would have + unknown shape. + + Args: + source_t: The tensor to copy HandleData from. + target_t: The tensor to copy HandleData to. + """ + if (target_t.dtype == dtypes_module.resource or + target_t.dtype == dtypes_module.variant): + if isinstance(source_t, ops.EagerTensor): + handle_data = source_t._handle_data # pylint: disable=protected-access else: - handle_data = resource_variable_ops.get_resource_handle_data(value) + handle_data = resource_variable_ops.get_resource_handle_data(source_t) if handle_data is not None and handle_data.is_set: # pylint: disable=protected-access - pywrap_tensorflow.SetResourceHandleShapeAndType( - placeholder.graph._c_graph, placeholder._as_tf_output(), - handle_data.SerializeToString()) + pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph, + target_t._as_tf_output(), + handle_data.SerializeToString()) # pylint: enable=protected-access # Ensure that shapes and dtypes are propagated. shapes, types = zip(*[(pair.shape, pair.dtype) @@ -91,12 +111,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None): shapes = [[d.size for d in s.dim] if not s.unknown_rank else None for s in shapes] pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - placeholder._op._graph._c_graph, # pylint: disable=protected-access - placeholder._as_tf_output(), # pylint: disable=protected-access + target_t._op._graph._c_graph, # pylint: disable=protected-access + target_t._as_tf_output(), # pylint: disable=protected-access shapes, ranks, types) - return placeholder - def _get_device_functions(ctx, graph): """Returns a tuple of device functions representing the device stack.""" @@ -435,6 +453,7 @@ class _EagerDefinedFunction(object): self._num_outputs = len(self.signature.output_arg) self._output_types = [o.type for o in self.signature.output_arg] self._output_shapes = [o.shape for o in outputs] + self._func_graph_outputs = outputs self.grad_func_name = None self.python_grad_func = None self._c_func = c_api_util.ScopedTFFunction(fn) @@ -511,6 +530,8 @@ class _EagerDefinedFunction(object): else: for i, shape in enumerate(self._output_shapes): outputs[i].set_shape(shape) + for i, func_graph_output in enumerate(self._func_graph_outputs): + _copy_handle_data(func_graph_output, outputs[i]) return outputs diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 4a1bde3f5e..e4513cc87c 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -48,6 +48,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops @@ -438,10 +439,17 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) + # We do not return v directly since the tensor conversion function of + # ResourceVariable returns the read value and not the resource itself. + return v._handle compiled = function.defun(f) - compiled() + var_handle = compiled() + self.assertEqual(var_handle.dtype, dtypes.resource) + self.assertEqual(var_handle.shape, tensor_shape.scalar()) + var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) + self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) def testVariableInLoopInFunction(self): @@ -465,10 +473,17 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) + # We do not return v directly since the tensor conversion function of + # ResourceVariable returns the read value and not the resource itself. + return v._handle compiled = function.defun(f) - compiled() + var_handle = compiled() + self.assertEqual(var_handle.dtype, dtypes.resource) + self.assertEqual(var_handle.shape, tensor_shape.scalar()) + var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) + self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): with context.graph_mode(): @@ -477,12 +492,34 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) # Check that shape inference works while creating the defun compiled = function.defun(f) compiled() + def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): + with context.graph_mode(): + tensor_list = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) + tensor_list = list_ops.tensor_list_push_back(tensor_list, + constant_op.constant(1.0)) + tensor_list = list_ops.tensor_list_push_back(tensor_list, + constant_op.constant(2.0)) + + def f(): + tl, value = list_ops.tensor_list_pop_back( + tensor_list, element_dtype=dtypes.float32) + self.assertEqual(value.shape, tensor_shape.scalar()) + return tl + + compiled = function.defun(f) + output_tensor_list = compiled() + _, value = list_ops.tensor_list_pop_back( + output_tensor_list, element_dtype=dtypes.float32) + self.assertEqual(value.shape, tensor_shape.scalar()) + @test_util.run_in_graph_and_eager_modes def testDefunForcesResourceVariables(self): diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index a8aef3a009..68b3170dfe 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -762,13 +762,12 @@ class _FuncGraph(ops.Graph): if handle_data: handle_data = handle_data.SerializeToString() else: - handle_data = c_api.GetResourceHandleShapeAndType( - tensor.graph._c_graph, tensor._as_tf_output()) + handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph, + tensor._as_tf_output()) if handle_data: - c_api.SetResourceHandleShapeAndType(ph.graph._c_graph, - ph._as_tf_output(), - compat.as_bytes(handle_data)) + c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(), + compat.as_bytes(handle_data)) else: ph._handle_data = tensor._handle_data # pylint: enable=protected-access diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 343f52fe8f..8bb177939e 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2532,8 +2532,8 @@ def _set_shape_and_handle_data_for_outputs_c_api(op): output._shape_val = output._c_api_shape() # Set the resource handle data for compatibility with the Python shape # inference code. - serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph, - output._as_tf_output()) + serialized = c_api.GetHandleShapeAndType(op._graph._c_graph, # pylint: disable=protected-access + output._as_tf_output()) if serialized: output._handle_data = ( cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 55c2eb5fa4..9e477ab8af 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -48,7 +48,7 @@ def get_resource_handle_data(graph_op): assert ops._USE_C_SHAPES # pylint: disable=protected-access assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck - handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType( + handle_data = pywrap_tensorflow.GetHandleShapeAndType( graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString( |