diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-03-23 16:28:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-25 04:27:22 -0700 |
commit | 202e4f3b3699e8e40e478402462f76ae853fecbf (patch) | |
tree | 9ce7c041c1d78ff4704d7c0752a7ea074b2d1a40 /tensorflow/c/python_api.cc | |
parent | 97249979d9a76ae05d590f9cbe199c0b47712b4f (diff) |
Make _USE_C_API = True and _USE_C_SHAPES = False work with handle data.
This change makes _set_shapes_for_outputs_c_api fetch and set
Tensor._handle_data. This is necessary for running the
Python shape inference code on resource tensors.
PiperOrigin-RevId: 190293303
Diffstat (limited to 'tensorflow/c/python_api.cc')
-rw-r--r-- | tensorflow/c/python_api.cc | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index cd604538f1..93155998b8 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/python_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/python/framework/cpp_shape_inference.pb.h" namespace tensorflow { @@ -109,4 +110,29 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } +std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { + Node* node = &output.oper->node; + CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) return ""; + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + } + } + string result; + handle_data.SerializeToString(&result); + return result; +} + } // namespace tensorflow |