diff options
author | 2017-05-18 11:38:18 -0700 | |
---|---|---|
committer | 2017-05-18 11:41:48 -0700 | |
commit | 2701f31bb58df695809788fd82f1299c0b7f67fd (patch) | |
tree | 5c5f40eb131b21112d970b6420fa9fe5b7f3dc4d /tensorflow/core/grappler/grappler_item_builder.cc | |
parent | 5b9dcd8f9b30ca60ba4ee59c7dfd660203b08c17 (diff) |
Add function inlining support to Grappler.
PiperOrigin-RevId: 156457746
Diffstat (limited to 'tensorflow/core/grappler/grappler_item_builder.cc')
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.cc | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 88799ba881..02eecb0ac5 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -19,10 +19,16 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variable.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/grappler/inputs/utils.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" @@ -51,6 +57,99 @@ void InitializeTensor(DataType type, Tensor* tensor) { tensor->tensor_data().size()); } } + +// Helper function that returns a bool indicating if there are function +// call nodes in graph. +bool HasFunctionInGraph(const Graph& graph) { + for (const Node* n : graph.nodes()) { + if (graph.flib_def().Find(n->type_string()) != nullptr) { + return true; + } + } + return false; +} + +// Wrapper around FunctionDefToBodyHelper that creates a FunctionBody +// for function_def. +Status CreateFunctionBody(const FunctionLibraryDefinition& function_library, + const FunctionDef& function_def, + const NodeDef& node_def, + FunctionBody** function_body) { + std::function<Status(const string&, const OpDef**)> get_function_signature = + [&function_library](const string& name, const OpDef** signature) { + return function_library.LookUpOpDef(name, signature); + }; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + function_def, AttrSlice(node_def), &function_library, + get_function_signature, function_body)); + return Status::OK(); +} + +// Inlines all functions in a Graph. Does not recursively inline, so if graph +// contains Function A that calls Function B, calling InlineFunctions once will +// produce a graph with A inlined but not B. Calling InlineFunctions a second +// time will produce a graph with both A and B inlined. +Status InlineFunctions(Graph* graph) { + const FunctionLibraryDefinition& function_library = graph->flib_def(); + std::vector<std::pair<Node*, FunctionBody*>> nodes_and_funcs_to_inline; + std::unordered_map<string, std::unique_ptr<FunctionBody>> + function_name_to_body; + std::function<Status(const string&, const OpDef**)> get_function_signature = + [&function_library](const string& name, const OpDef** signature) { + return function_library.LookUpOpDef(name, signature); + }; + + for (Node* node : graph->nodes()) { + const FunctionDef* function_def = + function_library.Find(node->type_string()); + if (!function_def) { + // Not a function node. + continue; + } + FunctionBody* function_body = nullptr; + const string key = Canonicalize(node->def().op(), AttrSlice(node->def())); + if (function_name_to_body.find(key) == function_name_to_body.end()) { + TF_RETURN_IF_ERROR(CreateFunctionBody(function_library, *function_def, + node->def(), &function_body)); + function_name_to_body.emplace( + key, std::unique_ptr<FunctionBody>(function_body)); + } + function_body = function_name_to_body[key].get(); + if (function_body) { + nodes_and_funcs_to_inline.emplace_back(node, function_body); + } + } + + for (const auto& iter : nodes_and_funcs_to_inline) { + InlineFunctionBody(function_library, graph, iter.first, iter.second); + } + return Status::OK(); +} + +// Sets *inlined_graph to be graph with all function NodeDefs in graph inlined. +// Recursively inlines, so if graph contains Function A that calls Function B, +// calling InlineAllFunctions once will produce a graph with both A and B +// inlined. +Status InlineAllFunctions(const GraphDef& graph_def, + GraphDef* inlined_graph_def) { + *inlined_graph_def = GraphDef::default_instance(); + // Create a Graph from graph_def. Inlining needs to happen + // on a single Graph object in order to guarantee unique + // names of nodes created during the inlining process. + GraphConstructorOptions graph_ctor_opts; + graph_ctor_opts.allow_internal_ops = true; + graph_ctor_opts.expect_device_spec = false; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + graph_def.library()); + Graph inlined_graph(function_library); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &inlined_graph)); + while (HasFunctionInGraph(inlined_graph)) { + TF_RETURN_IF_ERROR(InlineFunctions(&inlined_graph)); + } + inlined_graph.ToGraphDef(inlined_graph_def); + return Status::OK(); +} } // namespace // static @@ -64,6 +163,15 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( new_item->id = id; new_item->graph = meta_graph.graph_def(); + if (cfg.inline_functions) { + Status s = InlineAllFunctions(meta_graph.graph_def(), &new_item->graph); + if (!s.ok()) { + LOG(ERROR) << "Unable to inline functions: " << s.error_message() + << ", skipping this input."; + return nullptr; + } + } + // Attempt to detect the fetch node(s). if (meta_graph.collection_def().count("train_op") > 0) { const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); |