diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-04-24 15:45:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-24 15:48:41 -0700 |
commit | e7db82f821a1c522eed9e0c633df8b3db26ef38d (patch) | |
tree | e18a1e5afcbe56831bbce78784e0f36161921aaa /tensorflow/c/python_api.cc | |
parent | 184c8306a4a3d41f42f077b4898933500d61ce86 (diff) |
Make TF functions work with _USE_C_SHAPES=True.
It turns out regular functions need to manually copy handle data in
addition to eager GraphModeFunctions, so I moved the C extensions to
python_api.h from eager/c_api.h.
This also cleans up function_test.py to assume the C API is enabled.
PiperOrigin-RevId: 194158700
Diffstat (limited to 'tensorflow/c/python_api.cc')
-rw-r--r-- | tensorflow/c/python_api.cc | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 93155998b8..e18fdf6c57 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } -std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { +std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { Node* node = &output.oper->node; CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); @@ -135,4 +135,30 @@ std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { return result; } +void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::CppShapeInferenceResult::HandleData handle_data; + if (!handle_data.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Couldn't deserialize HandleData proto"); + return; + } + DCHECK(handle_data.is_set()); + + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&output.oper->node); + + std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types; + for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { + tensorflow::shape_inference::ShapeHandle shape; + status->status = + ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); + if (status->status.ok()) return; + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + } + ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); +} + } // namespace tensorflow |