about summary refs log tree commit diff homepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2017-12-21 11:23:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-21 11:28:14 -0800
commit1938feab97e36275f18a0745804299acfe137dc8 (patch)
tree36fc978e9203487326e7855b27f0faafe01c3cd9 /tensorflow/c/c_api.cc
parentc7a05f4b18df0a9bd6b594d6f3d67b7489fcb54e (diff)
This change adds a mechanism to the internal C API for updating an output handle's shapes and types after its source operation has been created.
Context: framework/ops.py was recently updated to use the C API when setting shapes for an op's outputs. This update broke shape inference for graph functions that captured resource handles; this, in turn, made it impossible to create graph functions from Python methods that required fully defined shapes (e.g., like MNIST's `call` method). In particular, the C API computes shapes for ops when they are created and does not update them thereafter; this is problematic because when a resource handle is captured while building a function, we need to update the captured operation's output handle in order to propagate its outputs' shapes and dtypes. PiperOrigin-RevId: 179837104
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc65
1 file changed, 52 insertions(+), 13 deletions(-)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 3aff4f9178..bc19044fa2 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -644,6 +644,56 @@ void RecordMutation(TF_Graph* graph, const TF_Operation& op,
}
}
+namespace {
+
+// Helper method that creates a shape handle for a shape described by dims.
+tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
+ tensorflow::shape_inference::InferenceContext* ic, int num_dims,
+ const int64_t* dims) {
+ if (num_dims != -1) {
+ std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
+ dim_vec.reserve(num_dims);
+ for (int i = 0; i < num_dims; ++i) {
+ dim_vec.push_back(ic->MakeDim(dims[i]));
+ }
+ return ic->MakeShape(dim_vec);
+ } else {
+ return ic->UnknownShape();
+ }
+}
+
+} // namespace
+
+void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
+ int num_shapes_and_types,
+ const int64_t** shapes,
+ const int* ranks,
+ const TF_DataType* types,
+ TF_Status* status) {
+ Node* node = &output.oper->node;
+
+ mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ if (ic == nullptr) {
+ status->status =
+ InvalidArgument("Node ", node->name(), " was not found in the graph");
+ return;
+ }
+
+ auto shape_and_type_vec =
+ std::vector<tensorflow::shape_inference::ShapeAndType>(
+ num_shapes_and_types);
+ for (int i = 0; i < num_shapes_and_types; ++i) {
+ tensorflow::shape_inference::ShapeHandle shape_handle =
+ ShapeHandleFromDims(ic, ranks[i], shapes[i]);
+ shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
+ shape_handle, static_cast<DataType>(types[i]));
+ }
+
+ ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
+}
+
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
@@ -949,7 +999,6 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
Node* node = &output.oper->node;
mutex_lock l(graph->mu);
- // Set the shape.
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(node);
if (ic == nullptr) {
@@ -957,18 +1006,8 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
InvalidArgument("Node ", node->name(), " was not found in the graph");
return;
}
-
- tensorflow::shape_inference::ShapeHandle new_shape;
- if (num_dims != -1) {
- std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
- dim_vec.reserve(num_dims);
- for (int i = 0; i < num_dims; ++i) {
- dim_vec.push_back(ic->MakeDim(dims[i]));
- }
- new_shape = ic->MakeShape(dim_vec);
- } else {
- new_shape = ic->UnknownShape();
- }
+ tensorflow::shape_inference::ShapeHandle new_shape =
+ tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
status->status = graph->refiner.SetShape(node, output.index, new_shape);
}