diff options
Diffstat (limited to 'tensorflow/core/kernels/partitioned_function_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/partitioned_function_ops.cc | 49 |
1 files changed, 42 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 876a1704c7..7bb403290d 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" @@ -104,13 +105,6 @@ class PartitionedCallOp : public AsyncOpKernel { for (auto d : lib->device_mgr()->ListDevices()) { device_set.AddDevice(d); } - Placer placer(graph.get(), &device_set); - OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); - - std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; - OP_REQUIRES_OK_ASYNC( - ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), - done); // The FunctionLibraryRuntime's library cannot be mutated from within // an OpKernel, so functions are instantiated in an overlay library. @@ -124,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel { new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition()); overlay_libs_.emplace(lib, overlay_lib); + GraphOptimizationPassOptions optimization_options; + // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or + // make it possible to specify the relevant options via attributes. + SessionOptions session_options; + session_options.env = ctx->env(); + optimization_options.session_options = &session_options; + optimization_options.graph = &graph; + optimization_options.flib_def = overlay_lib; + optimization_options.device_set = &device_set; + Placer placer(graph.get(), &device_set); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::PRE_PLACEMENT, optimization_options), + done); + OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PLACEMENT, optimization_options), + done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, + optimization_options), + done); + + std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; + OP_REQUIRES_OK_ASYNC( + ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), + done); + optimization_options.graph = nullptr; + optimization_options.device_set = nullptr; + optimization_options.partition_graphs = &subgraphs; + OP_REQUIRES_OK_ASYNC(ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PARTITIONING, + optimization_options), + done); + auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>(); for (const auto& pair : subgraphs) { // TODO(akshayka): Fail gracefully if the set of devices corresponds |