aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/partitioned_function_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/partitioned_function_ops.cc')
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc49
1 files changed, 42 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 876a1704c7..7bb403290d 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
@@ -104,13 +105,6 @@ class PartitionedCallOp : public AsyncOpKernel {
for (auto d : lib->device_mgr()->ListDevices()) {
device_set.AddDevice(d);
}
- Placer placer(graph.get(), &device_set);
- OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
-
- std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
- OP_REQUIRES_OK_ASYNC(
- ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
- done);
// The FunctionLibraryRuntime's library cannot be mutated from within
// an OpKernel, so functions are instantiated in an overlay library.
@@ -124,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel {
new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
overlay_libs_.emplace(lib, overlay_lib);
+ GraphOptimizationPassOptions optimization_options;
+ // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or
+ // make it possible to specify the relevant options via attributes.
+ SessionOptions session_options;
+ session_options.env = ctx->env();
+ optimization_options.session_options = &session_options;
+ optimization_options.graph = &graph;
+ optimization_options.flib_def = overlay_lib;
+ optimization_options.device_set = &device_set;
+ Placer placer(graph.get(), &device_set);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
+ optimization_options),
+ done);
+
+ std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
+ done);
+ optimization_options.graph = nullptr;
+ optimization_options.device_set = nullptr;
+ optimization_options.partition_graphs = &subgraphs;
+ OP_REQUIRES_OK_ASYNC(ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PARTITIONING,
+ optimization_options),
+ done);
+
auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
// TODO(akshayka): Fail gracefully if the set of devices corresponds