aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-19 18:27:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 18:30:56 -0700
commit9f05ca4ec89d9b03f740f881ae50d97d76a1b849 (patch)
treea585b3bcc287251896f211f2b2ed504f82352188 /tensorflow/python/framework
parent415455b0ef2d65504ab8c9084a6daa2899521212 (diff)
Copy Tensor._handle_data from external_capture to placeholder for Variant tensors in Graph mode defun.
This allows inferring the shape of values popped from TensorLists inside defuns. Remove "Resource" from {Set|Get}ResourceHandleShapeAndType since the same functions are re-usable for variants. Eager mode fix coming in a future changelist. PiperOrigin-RevId: 213735462
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/function.py9
-rw-r--r--tensorflow/python/framework/ops.py4
2 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index a8aef3a009..68b3170dfe 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -762,13 +762,12 @@ class _FuncGraph(ops.Graph):
if handle_data:
handle_data = handle_data.SerializeToString()
else:
- handle_data = c_api.GetResourceHandleShapeAndType(
- tensor.graph._c_graph, tensor._as_tf_output())
+ handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
+ tensor._as_tf_output())
if handle_data:
- c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
- ph._as_tf_output(),
- compat.as_bytes(handle_data))
+ c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
+ compat.as_bytes(handle_data))
else:
ph._handle_data = tensor._handle_data
# pylint: enable=protected-access
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 343f52fe8f..8bb177939e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2532,8 +2532,8 @@ def _set_shape_and_handle_data_for_outputs_c_api(op):
output._shape_val = output._c_api_shape()
# Set the resource handle data for compatibility with the Python shape
# inference code.
- serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph,
- output._as_tf_output())
+ serialized = c_api.GetHandleShapeAndType(op._graph._c_graph, # pylint: disable=protected-access
+ output._as_tf_output())
if serialized:
output._handle_data = (
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData