diff options
author | 2018-09-19 18:27:52 -0700 | |
---|---|---|
committer | 2018-09-19 18:30:56 -0700 | |
commit | 9f05ca4ec89d9b03f740f881ae50d97d76a1b849 (patch) | |
tree | a585b3bcc287251896f211f2b2ed504f82352188 /tensorflow/python/framework | |
parent | 415455b0ef2d65504ab8c9084a6daa2899521212 (diff) |
Copy Tensor._handle_data from external_capture to placeholder for Variant tensors in Graph mode defun.
This allows inferring the shape of values popped from TensorLists inside defuns.
Remove "Resource" from {Set|Get}ResourceHandleShapeAndType since the same functions are re-usable for variants.
Eager mode fix coming in a future changelist.
PiperOrigin-RevId: 213735462
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/function.py | 9 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 4 |
2 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index a8aef3a009..68b3170dfe 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -762,13 +762,12 @@ class _FuncGraph(ops.Graph): if handle_data: handle_data = handle_data.SerializeToString() else: - handle_data = c_api.GetResourceHandleShapeAndType( - tensor.graph._c_graph, tensor._as_tf_output()) + handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph, + tensor._as_tf_output()) if handle_data: - c_api.SetResourceHandleShapeAndType(ph.graph._c_graph, - ph._as_tf_output(), - compat.as_bytes(handle_data)) + c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(), + compat.as_bytes(handle_data)) else: ph._handle_data = tensor._handle_data # pylint: enable=protected-access diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 343f52fe8f..8bb177939e 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2532,8 +2532,8 @@ def _set_shape_and_handle_data_for_outputs_c_api(op): output._shape_val = output._c_api_shape() # Set the resource handle data for compatibility with the Python shape # inference code. - serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph, - output._as_tf_output()) + serialized = c_api.GetHandleShapeAndType(op._graph._c_graph, # pylint: disable=protected-access + output._as_tf_output()) if serialized: output._handle_data = ( cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData |