aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/grappler_item_builder.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2017-08-17 17:20:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-17 17:24:41 -0700
commit19a55725af8102d72d4e081c5139f0e4bd5a4bb7 (patch)
tree971673c250a44e0c4cfa4ab634a7c4c96f8ebd33 /tensorflow/core/grappler/grappler_item_builder.cc
parent8c0853db731cf80cfeec9dfb4edab95961aaa585 (diff)
Allowing functions to run across devices. This change expands the ProcessFunctionLibraryRuntime library to Instantiate and Run functions on different devices. When a FunctionLibraryRuntime encounters a function with a target that is another device, it delegates Instantiate() and Run() calls to the ProcessFunctionLibraryRuntime.
This change also moves the table_ containing all function instantiations to the PFLR instead of the FunctionLibraryRuntime. PiperOrigin-RevId: 165651194
Diffstat (limited to 'tensorflow/core/grappler/grappler_item_builder.cc')
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc11
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 6136651410..b740e8a999 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -104,9 +104,11 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
optimizer_opts->set_do_function_inlining(cfg.inline_functions);
// Create the function library runtime.
- std::unique_ptr<FunctionLibraryRuntime> flib(NewFunctionLibraryRuntime(
- dvc_mgr.get(), env, devices[0], inlined_graph_def.versions().producer(),
- &function_library, *optimizer_opts));
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+ new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
+ inlined_graph_def.versions().producer(),
+ &function_library, *optimizer_opts));
+ FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
// Create the GraphOptimizer to optimize the graph def.
GraphConstructorOptions graph_ctor_opts;
@@ -122,8 +124,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
// Optimize the graph.
GraphOptimizer optimizer(*optimizer_opts);
- optimizer.Optimize(flib.get(), env, devices[0], &graphptr,
- /*shape_map=*/nullptr);
+ optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr);
graphptr->ToGraphDef(output_graph_def);
return Status::OK();