about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
author: Max Galkin <maxgalkin@google.com> 2017-10-10 13:06:16 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-10-10 13:10:20 -0700
commit: 0ffb522f02129c5d23a8b20ef56d0fefd7be91fe (patch)
tree: dab0ed9d93b00f4c3fd149b8d13dad7dbb39805b /tensorflow/core/grappler
parent: e74adb670920dd6f41306a4a40784a535ea7b878 (diff)
Add a flag to erase "_noinline" attribute to allow total inlining in Grappler.
PiperOrigin-RevId: 171722354
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- tensorflow/core/grappler/grappler_item_builder.cc 26
-rw-r--r-- tensorflow/core/grappler/grappler_item_builder.h 20
2 files changed, 26 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index cb7d7f7330..d23facf81a 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -74,7 +74,7 @@ void InitializeTensor(DataType type, Tensor* tensor) {
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
// order to get the correct session options and environment, and performing the
// correct optimizations.
-Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
+Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
const ItemConfig& cfg) {
if (!cfg.apply_optimizations && !cfg.inline_functions) {
return Status::OK();
@@ -83,8 +83,16 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
// Create a session option for a single GPU device.
SessionOptions options;
- // Inline all functions.
- GraphDef inlined_graph_def(graph_def);
+ // Make a local copy of graph def, because we need to change some things.
+ GraphDef graph_def(graph_def_arg);
+
+ if (cfg.inline_functions && cfg.erase_noinline_attributes) {
+ // TF optimizer doesn't inline functions with "_noinline" attribute,
+ // so let's go over the function library and erase it.
+ for (auto& func : *graph_def.mutable_library()->mutable_function()) {
+ func.mutable_attr()->erase("_noinline");
+ }
+ }
// Instantiate all variables for function library runtime creation.
std::vector<Device*> devices;
@@ -92,7 +100,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
options, "/job:localhost/replica:0/task:0", &devices));
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices));
FunctionLibraryDefinition function_library(OpRegistry::Global(),
- inlined_graph_def.library());
+ graph_def.library());
Env* env = Env::Default();
// Optimizer options: L1 and inlining. L1 is default.
@@ -108,7 +116,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
// Create the function library runtime.
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
- inlined_graph_def.versions().producer(),
+ graph_def.versions().producer(),
&function_library, *optimizer_opts));
FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
@@ -118,11 +126,11 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
graph_ctor_opts.expect_device_spec = false;
std::unique_ptr<Graph> graphptr(new Graph(function_library));
// Populate default attrs to the NodeDefs in the GraphDef.
- TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def,
- *graphptr->op_registry(), 0));
+ TF_RETURN_IF_ERROR(
+ AddDefaultAttrsToGraphDef(&graph_def, *graphptr->op_registry(), 0));
- TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_ctor_opts, inlined_graph_def,
- graphptr.get()));
+ TF_RETURN_IF_ERROR(
+ ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
// Optimize the graph.
GraphOptimizer optimizer(*optimizer_opts);
diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h
index 4ce5055e7a..9a7f52228b 100644
--- a/tensorflow/core/grappler/grappler_item_builder.h
+++ b/tensorflow/core/grappler/grappler_item_builder.h
@@ -27,24 +27,22 @@ class MetaGraphDef;
namespace grappler {
struct ItemConfig {
- ItemConfig()
- : ignore_user_placement(true),
- ignore_colocation(true),
- placeholder_unknown_output_shape_dim(-1),
- apply_optimizations(false),
- inline_functions(false) {}
+ ItemConfig() {}
// If true, ignore all user specified node placement.
- bool ignore_user_placement;
+ bool ignore_user_placement = true;
// If true, ignore all user specified colocation attributes.
- bool ignore_colocation;
+ bool ignore_colocation = true;
// Dimension to use if a placeholder node has an _output_shapes attribute with
// a dimension of -1.
- int placeholder_unknown_output_shape_dim;
+ int placeholder_unknown_output_shape_dim = -1;
// If true, does L1 optimizations.
- bool apply_optimizations;
+ bool apply_optimizations = false;
// If true, does inlining.
- bool inline_functions;
+ bool inline_functions = false;
+ // If true, erases all "_noinline" attributes from user-defined functions.
+ // Has no effect if "inline_functions" is disabled.
+ bool erase_noinline_attributes = false;
// If non-empty, override the directory of asset paths.
string assets_directory_override;
};