diff options
5 files changed, 20 insertions, 93 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 28e09d7b79..0362682bd6 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -94,8 +94,9 @@ Status FunctionalizeControlFlowForFunction( } }); const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* g = body->graph; - // Check if the graph has Switch or Merge node before optimizing the graph. + // Check if the graph has Switch or Merge node. bool has_switch_or_merge = false; for (Node* n : body->graph->nodes()) { if (n->type_string() == "Switch" || n->type_string() == "Merge") { @@ -108,58 +109,13 @@ Status FunctionalizeControlFlowForFunction( // in function body. We still need to rewrite those functions and modify // corresponding nodes. - // Call graph optimizer. The most important optimization we need is constant - // folding, which will replace ops like Shape/BroadcastGradientArgs with - // constant shape input. Without this optimization, those ops might become - // dynamic input for then/else body function and XLA will complain that input - // is not compile time constant. We enable function inlining as well, because - // otherwise we won't be able to infer shape for any node depending on - // function call nodes. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_opt_", func_name), - *body->graph, fld); - } - // Optimizer accepts std::unique_ptr<Graph>* as input and might change - // underlying pointer, thus we create a new Graph and copy from body->graph. - std::unique_ptr<Graph> optimized_graph(new Graph(fld)); - CopyGraph(*body->graph, optimized_graph.get()); - OptimizerOptions opts; - opts.set_opt_level(OptimizerOptions::L0); - opts.set_do_function_inlining(true); - opts.set_do_constant_folding(true); - GraphOptimizer optimizer(opts); - auto cf_consider_fn = [](const Node* n) { - // Skip SymbolicGradient op when doing constant folding. - // Enabling SymbolicGradient op in constant folding requires - // flr->device() to be non-null, and here we have not constructed - // proper Device object yet (it will be constructed in XlaCompiler). - return n->type_string() != FunctionLibraryDefinition::kGradientOp; - }; - optimizer.Optimize(flr, flr->env(), - /*device=*/nullptr, &optimized_graph, - /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr, - cf_consider_fn); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_opt_", func_name), - *optimized_graph, fld); - } - // Some inlined functions might have Switch/Merge nodes. - for (Node* n : optimized_graph->nodes()) { - if (n->type_string() == "Switch" || n->type_string() == "Merge") { - has_switch_or_merge = true; - break; - } - } - // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes // might involve node deletion/addition. Avoid modifying nodes while iterating // it. std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>> nodes_to_associated_functions; - for (auto* n : optimized_graph->nodes()) { + for (auto* n : g->nodes()) { auto associated_functions = GetAssociatedFunctions(*n, flr); if (!associated_functions.empty()) { nodes_to_associated_functions.push_back({n, associated_functions}); @@ -215,7 +171,7 @@ Status FunctionalizeControlFlowForFunction( // pointer. That's fine because in that case, associated_functions will // only have one member and the loop will only run once. TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - optimized_graph.get(), n, fld, associated_function, new_name)); + g, n, fld, associated_function, new_name)); } } } @@ -227,21 +183,21 @@ Status FunctionalizeControlFlowForFunction( if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *optimized_graph, fld); + *g, fld); } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *optimized_graph, fld); + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); } } if (*modified) { // Add rewritten FunctionDef into library. FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, - &functionalized_fdef)); + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); if (func_name == new_func_name) { VLOG(2) << "Replacing function " << func_name; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index db137f1a19..e81e61b633 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -466,23 +466,23 @@ Graph* GetConstantGraph( bool ReplaceTensorWithConstant( Graph* graph, Device* partition_device, NodeAndOutput tensor, const Tensor& constant, const gtl::FlatSet<Node*>& control_deps, - int64 max_constant_size_in_bytes, bool disable_memory_output_type_check, + int64 max_constant_size_in_bytes, const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. // 1) Do not replace another constant. // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY // constraint, do not replace it. - // 3) If the size of the constant in bytes is too large (> + // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY + // constraint, do not replace it. + // 4) If the size of the constant in bytes is too large (> // max_constant_in_bytes), do not replace it. This prevents the size of the // Graph from growing too large. - // 4) If the constant op created does not have a kernel implementation + // 5) If the constant op created does not have a kernel implementation // for the device, do not use it. // TODO(keveman): Consider adding a new constant op that has a kernel // implementation for all types, but with HostMemory constraint on it's // output. - // 5) If the constant op for the device has different output memory type - // from the original op output memory type, do not replace it. if (tensor.first->IsConstant()) { return false; } @@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant( return false; } bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32; - if (memory_type == HOST_MEMORY && !is_int32) { + if ((memory_type == HOST_MEMORY && !is_int32) || + (memory_type == DEVICE_MEMORY && is_int32)) { return false; } } @@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant( if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { return false; } - if (!disable_memory_output_type_check) { - if (partition_device && device_type != DEVICE_CPU) { - MemoryType original_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second, - &original_output_memory_type) - .ok()) { - return false; - } - MemoryType const_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, constant_node, 0, - &const_output_memory_type) - .ok()) { - return false; - } - if (original_output_memory_type != const_output_memory_type) { - return false; - } - } - } for (auto edge : edges_to_remove) { graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); @@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, constant_control_deps[tensors_to_replace[c].first]; if (ReplaceTensorWithConstant( graph, partition_device, tensors_to_replace[c], outputs[c], - control_deps, opts.max_constant_size_in_bytes, - opts.disable_memory_output_type_check, generate_new_name)) { + control_deps, opts.max_constant_size_in_bytes, generate_new_name)) { ++num_nodes_replaced; } } diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h index 4c71b7bd27..a9a84f761b 100644 --- a/tensorflow/core/common_runtime/constant_folding.h +++ b/tensorflow/core/common_runtime/constant_folding.h @@ -45,10 +45,6 @@ struct ConstantFoldingOptions { // optimization. int64 max_constant_size_in_bytes = 10 * 1024 * 1024; - // If disable_memory_output_type_check is true, we will disable output memory - // type check for constant node replacement. - bool disable_memory_output_type_check = false; - // A generator for the name suffix of constant folded nodes. A // default id generator that monotonically increases is used if nullptr is // passed. diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 91194bc86f..37a979a8f1 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -39,8 +39,7 @@ void GraphOptimizer::Optimize( const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, const std::function<bool(const Node*)>& cse_consider_fn, - const std::function<bool(const Node*)>& cf_consider_fn, - bool cf_disable_memory_output_type_check) { + const std::function<bool(const Node*)>& cf_consider_fn) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -65,8 +64,6 @@ void GraphOptimizer::Optimize( ConstantFoldingOptions cf_opts; cf_opts.shape_map = shape_map; cf_opts.consider = cf_consider_fn; - cf_opts.disable_memory_output_type_check = - cf_disable_memory_output_type_check; if (opts_.max_folded_constant_in_bytes() > 0) { cf_opts.max_constant_size_in_bytes = opts_.max_folded_constant_in_bytes(); diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 8954e9612d..789cc56942 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -47,16 +47,13 @@ class GraphOptimizer { // returns true will be considered for CSE. // If cf_consider_fn is not null then only nodes for which cf_consider_fn // returns true will be considered for CF. - // If cf_disable_memory_output_type_check is true, CF will discard output - // memory type check for constant node replacement. void Optimize( FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, const std::function<bool(const Node*)>& cse_consider_fn = nullptr, - const std::function<bool(const Node*)>& cf_consider_fn = nullptr, - bool cf_disable_memory_output_type_check = false); + const std::function<bool(const Node*)>& cf_consider_fn = nullptr); const OptimizerOptions& options() { return opts_; } |