aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/grappler_item_builder.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-18 11:38:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-18 11:41:48 -0700
commit2701f31bb58df695809788fd82f1299c0b7f67fd (patch)
tree5c5f40eb131b21112d970b6420fa9fe5b7f3dc4d /tensorflow/core/grappler/grappler_item_builder.cc
parent5b9dcd8f9b30ca60ba4ee59c7dfd660203b08c17 (diff)
Add function inlining support to Grappler.
PiperOrigin-RevId: 156457746
Diffstat (limited to 'tensorflow/core/grappler/grappler_item_builder.cc')
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc108
1 files changed, 108 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 88799ba881..02eecb0ac5 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -19,10 +19,16 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
@@ -51,6 +57,99 @@ void InitializeTensor(DataType type, Tensor* tensor) {
tensor->tensor_data().size());
}
}
+
+// Helper function that returns a bool indicating if there are function
+// call nodes in graph.
+bool HasFunctionInGraph(const Graph& graph) {
+ for (const Node* n : graph.nodes()) {
+ if (graph.flib_def().Find(n->type_string()) != nullptr) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Wrapper around FunctionDefToBodyHelper that creates a FunctionBody
+// for function_def.
+Status CreateFunctionBody(const FunctionLibraryDefinition& function_library,
+ const FunctionDef& function_def,
+ const NodeDef& node_def,
+ FunctionBody** function_body) {
+ std::function<Status(const string&, const OpDef**)> get_function_signature =
+ [&function_library](const string& name, const OpDef** signature) {
+ return function_library.LookUpOpDef(name, signature);
+ };
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+ function_def, AttrSlice(node_def), &function_library,
+ get_function_signature, function_body));
+ return Status::OK();
+}
+
+// Inlines all functions in a Graph. Does not recursively inline, so if graph
+// contains Function A that calls Function B, calling InlineFunctions once will
+// produce a graph with A inlined but not B. Calling InlineFunctions a second
+// time will produce a graph with both A and B inlined.
+Status InlineFunctions(Graph* graph) {
+ const FunctionLibraryDefinition& function_library = graph->flib_def();
+ std::vector<std::pair<Node*, FunctionBody*>> nodes_and_funcs_to_inline;
+ std::unordered_map<string, std::unique_ptr<FunctionBody>>
+ function_name_to_body;
+ std::function<Status(const string&, const OpDef**)> get_function_signature =
+ [&function_library](const string& name, const OpDef** signature) {
+ return function_library.LookUpOpDef(name, signature);
+ };
+
+ for (Node* node : graph->nodes()) {
+ const FunctionDef* function_def =
+ function_library.Find(node->type_string());
+ if (!function_def) {
+ // Not a function node.
+ continue;
+ }
+ FunctionBody* function_body = nullptr;
+ const string key = Canonicalize(node->def().op(), AttrSlice(node->def()));
+ if (function_name_to_body.find(key) == function_name_to_body.end()) {
+ TF_RETURN_IF_ERROR(CreateFunctionBody(function_library, *function_def,
+ node->def(), &function_body));
+ function_name_to_body.emplace(
+ key, std::unique_ptr<FunctionBody>(function_body));
+ }
+ function_body = function_name_to_body[key].get();
+ if (function_body) {
+ nodes_and_funcs_to_inline.emplace_back(node, function_body);
+ }
+ }
+
+ for (const auto& iter : nodes_and_funcs_to_inline) {
+ InlineFunctionBody(function_library, graph, iter.first, iter.second);
+ }
+ return Status::OK();
+}
+
+// Sets *inlined_graph to be graph with all function NodeDefs in graph inlined.
+// Recursively inlines, so if graph contains Function A that calls Function B,
+// calling InlineAllFunctions once will produce a graph with both A and B
+// inlined.
+Status InlineAllFunctions(const GraphDef& graph_def,
+ GraphDef* inlined_graph_def) {
+ *inlined_graph_def = GraphDef::default_instance();
+ // Create a Graph from graph_def. Inlining needs to happen
+ // on a single Graph object in order to guarantee unique
+ // names of nodes created during the inlining process.
+ GraphConstructorOptions graph_ctor_opts;
+ graph_ctor_opts.allow_internal_ops = true;
+ graph_ctor_opts.expect_device_spec = false;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ graph_def.library());
+ Graph inlined_graph(function_library);
+ TF_RETURN_IF_ERROR(
+ ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &inlined_graph));
+ while (HasFunctionInGraph(inlined_graph)) {
+ TF_RETURN_IF_ERROR(InlineFunctions(&inlined_graph));
+ }
+ inlined_graph.ToGraphDef(inlined_graph_def);
+ return Status::OK();
+}
} // namespace
// static
@@ -64,6 +163,15 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
new_item->id = id;
new_item->graph = meta_graph.graph_def();
+ if (cfg.inline_functions) {
+ Status s = InlineAllFunctions(meta_graph.graph_def(), &new_item->graph);
+ if (!s.ok()) {
+ LOG(ERROR) << "Unable to inline functions: " << s.error_message()
+ << ", skipping this input.";
+ return nullptr;
+ }
+ }
+
// Attempt to detect the fetch node(s).
if (meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");