aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-19 18:27:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 18:30:56 -0700
commit9f05ca4ec89d9b03f740f881ae50d97d76a1b849 (patch)
treea585b3bcc287251896f211f2b2ed504f82352188 /tensorflow/c
parent415455b0ef2d65504ab8c9084a6daa2899521212 (diff)
Copy Tensor._handle_data from external_capture to placeholder for Variant tensors in Graph mode defun.
This allows inferring the shape of values popped from TensorLists inside defuns. Remove "Resource" from {Set|Get}ResourceHandleShapeAndType since the same functions are re-usable for variants. Eager mode fix coming in a future changelist. PiperOrigin-RevId: 213735462
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/python_api.cc7
-rw-r--r--tensorflow/c/python_api.h13
2 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 8486b585c8..247236b760 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status) {
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status) {
tensorflow::CppShapeInferenceResult::HandleData handle_data;
if (!handle_data.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 4bcb5bde62..5cce84020b 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
-// `output` if its a resource tensor, or otherwise returns the empty string.
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+// `output` if its a resource or variant tensor, or otherwise returns the empty
+// string.
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// Sets `output` based on `proto`, which should be a serialized
-// CppShapeInferenceResult::HandleData proto.
+// CppShapeInferenceResult::HandleData proto. `output` should be a resource
+// or variant tensor.
// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
// because I couldn't get SWIG to work otherwise.
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status);
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_