aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/python_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-23 16:28:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-25 04:27:22 -0700
commit202e4f3b3699e8e40e478402462f76ae853fecbf (patch)
tree9ce7c041c1d78ff4704d7c0752a7ea074b2d1a40 /tensorflow/c/python_api.cc
parent97249979d9a76ae05d590f9cbe199c0b47712b4f (diff)
Make _USE_C_API = True and _USE_C_SHAPES = False work with handle data.
This change makes _set_shapes_for_outputs_c_api fetch and set Tensor._handle_data. This is necessary for running the Python shape inference code on resource tensors. PiperOrigin-RevId: 190293303
Diffstat (limited to 'tensorflow/c/python_api.cc')
-rw-r--r--tensorflow/c/python_api.cc26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index cd604538f1..93155998b8 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/python_api.h"
#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
namespace tensorflow {
@@ -109,4 +110,29 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
+std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+ Node* node = &output.oper->node;
+ CppShapeInferenceResult::HandleData handle_data;
+ handle_data.set_is_set(true);
+ {
+ mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ CHECK(ic != nullptr);
+ CHECK_LT(output.index, ic->num_outputs());
+ const auto* shapes_and_types =
+ ic->output_handle_shapes_and_types(output.index);
+ if (shapes_and_types == nullptr) return "";
+
+ for (const auto& p : *shapes_and_types) {
+ auto* out_shape_and_type = handle_data.add_shape_and_type();
+ ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
+ out_shape_and_type->set_dtype(p.dtype);
+ }
+ }
+ string result;
+ handle_data.SerializeToString(&result);
+ return result;
+}
+
} // namespace tensorflow