diff options
author | Akshay Agrawal <akshayka@google.com> | 2017-12-21 11:23:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-21 11:28:14 -0800 |
commit | 1938feab97e36275f18a0745804299acfe137dc8 (patch) | |
tree | 36fc978e9203487326e7855b27f0faafe01c3cd9 /tensorflow/c/c_api.cc | |
parent | c7a05f4b18df0a9bd6b594d6f3d67b7489fcb54e (diff) |
This change adds a mechanism to the internal C API for updating an output handle's shapes and types after its source operation has been created.
Context: framework/ops.py was recently updated to use the C API when setting shapes for an op's outputs. This update broke shape inference for graph functions that captured resource handles; this, in turn, made it impossible to create graph functions from Python methods that required fully defined shapes (e.g., like MNIST's `call` method). In particular, the C API computes shapes for ops when they are created and does not update them thereafter; this is problematic because when a resource handle is captured while building a function, we need to update the captured operation's output handle in order to propagate its outputs shapes and dtypes.
PiperOrigin-RevId: 179837104
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 65 |
1 files changed, 52 insertions, 13 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 3aff4f9178..bc19044fa2 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -644,6 +644,56 @@ void RecordMutation(TF_Graph* graph, const TF_Operation& op, } } +namespace { + +// Helper method that creates a shape handle for a shape described by dims. +tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( + tensorflow::shape_inference::InferenceContext* ic, int num_dims, + const int64_t* dims) { + if (num_dims != -1) { + std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + return ic->MakeShape(dim_vec); + } else { + return ic->UnknownShape(); + } +} + +} // namespace + +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + + auto shape_and_type_vec = + std::vector<tensorflow::shape_inference::ShapeAndType>( + num_shapes_and_types); + for (int i = 0; i < num_shapes_and_types; ++i) { + tensorflow::shape_inference::ShapeHandle shape_handle = + ShapeHandleFromDims(ic, ranks[i], shapes[i]); + shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( + shape_handle, static_cast<DataType>(types[i])); + } + + ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -949,7 +999,6 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, Node* node = &output.oper->node; mutex_lock l(graph->mu); - // Set the shape. tensorflow::shape_inference::InferenceContext* ic = graph->refiner.GetContext(node); if (ic == nullptr) { @@ -957,18 +1006,8 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, InvalidArgument("Node ", node->name(), " was not found in the graph"); return; } - - tensorflow::shape_inference::ShapeHandle new_shape; - if (num_dims != -1) { - std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); - } - new_shape = ic->MakeShape(dim_vec); - } else { - new_shape = ic->UnknownShape(); - } + tensorflow::shape_inference::ShapeHandle new_shape = + tensorflow::ShapeHandleFromDims(ic, num_dims, dims); status->status = graph->refiner.SetShape(node, output.index, new_shape); } |