aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/python_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-04-24 15:45:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-24 15:48:41 -0700
commite7db82f821a1c522eed9e0c633df8b3db26ef38d (patch)
treee18a1e5afcbe56831bbce78784e0f36161921aaa /tensorflow/c/python_api.cc
parent184c8306a4a3d41f42f077b4898933500d61ce86 (diff)
Make TF functions work with _USE_C_SHAPES=True.
It turns out regular functions need to manually copy handle data in addition to eager GraphModeFunctions, so I moved the C extensions to python_api.h from eager/c_api.h. This also cleans up function_test.py to assume the C API is enabled. PiperOrigin-RevId: 194158700
Diffstat (limited to 'tensorflow/c/python_api.cc')
-rw-r--r--tensorflow/c/python_api.cc28
1 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 93155998b8..e18fdf6c57 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
-std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,4 +135,30 @@ std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
+void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+                           const void* proto, size_t proto_len,
+                           TF_Status* status) {
+  tensorflow::CppShapeInferenceResult::HandleData handle_data;
+  if (!handle_data.ParseFromArray(proto, proto_len)) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Couldn't deserialize HandleData proto");
+    return;
+  }
+  DCHECK(handle_data.is_set());
+
+  tensorflow::mutex_lock l(graph->mu);
+  tensorflow::shape_inference::InferenceContext* ic =
+      graph->refiner.GetContext(&output.oper->node);
+
+  std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
+  for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
+    tensorflow::shape_inference::ShapeHandle shape;
+    status->status =
+        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
+    if (!status->status.ok()) return;  // Fix: propagate failure (was inverted).
+    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
+  }
+  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
+}
+
} // namespace tensorflow