From 4f8768319cfa56c25973cc66d920146ad454bd97 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Apr 2018 18:17:02 -0700 Subject: Optimize Graph function library. PiperOrigin-RevId: 193605910 --- tensorflow/core/grappler/optimizers/BUILD | 4 + .../core/grappler/optimizers/function_optimizer.cc | 126 ++++++-- .../core/grappler/optimizers/function_optimizer.h | 6 +- .../grappler/optimizers/function_optimizer_test.cc | 32 +- .../core/grappler/optimizers/meta_optimizer.cc | 326 +++++++++++++-------- .../core/grappler/optimizers/meta_optimizer.h | 33 ++- .../grappler/optimizers/meta_optimizer_test.cc | 172 ++++++++++- tensorflow/core/grappler/utils/functions.cc | 12 +- tensorflow/core/grappler/utils/functions.h | 40 ++- tensorflow/core/grappler/utils/functions_test.cc | 8 +- 10 files changed, 563 insertions(+), 196 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index a371186fe6..3ab8d8f584 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -518,11 +518,13 @@ cc_library( ":loop_optimizer", ":memory_optimizer", ":model_pruner", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/utils:colocation", + "//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:topological_sort", ], ) @@ -539,9 +541,11 @@ tf_cuda_cc_test( "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:grappler_test", ], ) diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index d008a9719f..950933b933 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -75,12 +76,10 @@ string UniqueSpecializedFunctionName(const FunctionDef& func, class FunctionOptimizerContext { public: - explicit FunctionOptimizerContext(const GrapplerItem& item, - RewriterConfig::Toggle opt_level) - : opt_level_(opt_level), - function_library_(FunctionLibraryDefinition(OpRegistry::Global(), - item.graph.library())) { - InitializeInlinedFunctions(item); + explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level, + const GrapplerItem& item) + : function_library_(OpRegistry::Global(), item.graph.library()) { + InitializeInlinedFunctions(opt_level, item); } const FunctionLibraryDefinition& function_library() const { @@ -101,8 +100,9 @@ class FunctionOptimizerContext { } private: - void InitializeInlinedFunctions(const GrapplerItem& item) { - bool aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; + void InitializeInlinedFunctions(RewriterConfig::Toggle opt_level, + const GrapplerItem& item) { + bool aggressive = opt_level == RewriterConfig::AGGRESSIVE; for (const FunctionDef& func : item.graph.library().function()) { // Can't create IdentityN nodes with no input or output: skip these @@ -120,7 +120,6 @@ class FunctionOptimizerContext { } } - RewriterConfig::Toggle opt_level_; FunctionLibraryDefinition function_library_; // Functions that can be inlined into optimized graph. std::unordered_map inlined_functions_; @@ -128,9 +127,93 @@ class FunctionOptimizerContext { TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext); }; +// Return trimmed FunctionDefLibrary with functions that are reachable from +// the optimized graph. +FunctionDefLibrary TrimFunctionLibrary(const FunctionLibraryDefinition& flib, + const GraphDef& optimized_graph) { + // Functions that are reachable from the optimized graph. + std::unordered_set keep_funcs; + + std::vector func_queue; + func_queue.reserve(flib.num_functions()); + + // Add registered and not already processed functions to the queue by name. + const auto add_to_func_queue = [&](const string& func_name) { + const FunctionDef* func = flib.Find(func_name); + if (func && keep_funcs.find(func_name) == keep_funcs.end()) { + func_queue.push_back(func); + } + }; + + // Find all the functions that are reachable from the given node. + const auto add_node_to_func_queue = [&](const NodeDef& node) { + // Node itself can be a call to the function. + add_to_func_queue(node.op()); + + // Or node can have an attribute referencing a function. + for (const auto& attr : node.attr()) { + const auto& attr_value = attr.second; + + // 1. AttrValue.func + if (attr_value.has_func()) { + add_to_func_queue(attr_value.func().name()); + } + + // 2. AttrValue.ListValue.func + if (attr_value.has_list()) { + for (const auto& func : attr_value.list().func()) { + add_to_func_queue(func.name()); + } + } + } + }; + + // Add all functions that are directly called from the optimized graph. + const auto& graph_nodes = optimized_graph.node(); + std::for_each(graph_nodes.begin(), graph_nodes.end(), add_node_to_func_queue); + + // Process all reachable functions. + while (!func_queue.empty()) { + const FunctionDef* func = func_queue.back(); + func_queue.pop_back(); + + const string& func_name = func->signature().name(); + keep_funcs.insert(func_name); + + // Find all the functions that called from the function body. + const auto& func_body = func->node_def(); + std::for_each(func_body.begin(), func_body.end(), add_node_to_func_queue); + + // Check if the function has a registered gradient. + const string grad_func_name = flib.FindGradient(func_name); + if (!grad_func_name.empty()) add_to_func_queue(grad_func_name); + } + + FunctionDefLibrary lib; + for (const string& func_name : keep_funcs) { + const FunctionDef* func = CHECK_NOTNULL(flib.Find(func_name)); + *lib.add_function() = *func; + + const string grad_func_name = flib.FindGradient(func_name); + if (!grad_func_name.empty()) { + GradientDef* gd = lib.add_gradient(); + gd->set_function_name(func_name); + gd->set_gradient_func(grad_func_name); + } + } + + VLOG(3) << "Trimmed function library: " << keep_funcs.size() << " functions (" + << static_cast(keep_funcs.size() - flib.num_functions()) << ")"; + + return lib; +} + Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, FunctionOptimizerContext* ctx, GraphDef* optimized_graph) { + VLOG(2) << "Specialize function instantiation: " + << SummarizeNodeDef(func_node); + const std::unordered_map func_attr( func_node.attr().begin(), func_node.attr().end()); @@ -141,20 +224,20 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); // TODO(ezhulenev): Push down const inputs and known input shapes. - FunctionDef specialized; - TF_RETURN_IF_ERROR(MakeSpecializedFunctionDef(item, flib, &specialized)); + FunctionDef specialized_func; + TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func)); // Find a name for specialized function. const string specialized_func_name = UniqueSpecializedFunctionName(func, func_node, flib); - specialized.mutable_signature()->set_name(specialized_func_name); - auto* specialized_attr = specialized.mutable_attr(); + specialized_func.mutable_signature()->set_name(specialized_func_name); + auto* specialized_attr = specialized_func.mutable_attr(); (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true); // Add specialized function to the library. TF_RETURN_IF_ERROR( - ctx->mutable_function_library().AddFunctionDef(specialized)); + ctx->mutable_function_library().AddFunctionDef(specialized_func)); // Add a function call node for the specialized function. NodeDef* specialized_func_node = optimized_graph->add_node(); @@ -226,6 +309,8 @@ Status HookInlinedFunctionOutputs( Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, const FunctionOptimizerContext& ctx, GraphDef* optimized_graph) { + VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node); + const std::unordered_map func_attr( func_node.attr().begin(), func_node.attr().end()); @@ -359,6 +444,8 @@ class SymbolicGradientEnv { Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env, GraphDef* inlined_graph) { + VLOG(2) << "Inline symbolic gradient: " << SummarizeNodeDef(node); + GraphDef graph_def; // Create a node to anchor the gradient inputs @@ -454,13 +541,16 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env, Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { + VLOG(2) << "Optimize function library: id=" << item.id; + // Nothing to do here. if (item.graph.library().function_size() == 0) { + VLOG(3) << "Skip Grappler item with empty function library"; *optimized_graph = item.graph; return Status::OK(); } - FunctionOptimizerContext ctx(item, opt_level_); + FunctionOptimizerContext ctx(opt_level_, item); SymbolicGradientEnv env(item.graph.versions().producer(), item.graph.library()); @@ -506,9 +596,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph->add_node() = node; } - // TODO(bsteiner): trim the library to remove unused function definitions *optimized_graph->mutable_versions() = item.graph.versions(); - *optimized_graph->mutable_library() = ctx.function_library().ToProto(); + *optimized_graph->mutable_library() = + options_.enable_trim_function_library + ? TrimFunctionLibrary(ctx.function_library(), *optimized_graph) + : ctx.function_library().ToProto(); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.h b/tensorflow/core/grappler/optimizers/function_optimizer.h index c555fadf83..e307b4e533 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.h +++ b/tensorflow/core/grappler/optimizers/function_optimizer.h @@ -26,8 +26,9 @@ namespace grappler { // operations to make the overall graph more efficient. class FunctionOptimizer : public GraphOptimizer { public: - FunctionOptimizer(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {} - ~FunctionOptimizer() override {} + explicit FunctionOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + ~FunctionOptimizer() override = default; string name() const override { return "function_optimizer"; }; @@ -44,6 +45,7 @@ class FunctionOptimizer : public GraphOptimizer { bool enable_function_inlining = true; bool enable_function_specialization = true; bool enable_symbolic_gradient_inlining = true; + bool enable_trim_function_library = true; }; RewriterConfig::Toggle opt_level_; diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index fb006d4868..6147e8a27c 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -31,20 +31,8 @@ constexpr char kDevice[] = "/device:CPU:0"; class FunctionOptimizerTest : public GrapplerTest { protected: - void DisableAll(FunctionOptimizer* optimizer) { - optimizer->options_.enable_function_inlining = false; + void DisableFunctionSpecialization(FunctionOptimizer* optimizer) { optimizer->options_.enable_function_specialization = false; - optimizer->options_.enable_symbolic_gradient_inlining = false; - } - - void EnableOnlyFunctionInlining(FunctionOptimizer* optimizer) { - DisableAll(optimizer); - optimizer->options_.enable_function_inlining = true; - } - - void EnableOnlyFunctionSpecialization(FunctionOptimizer* optimizer) { - DisableAll(optimizer); - optimizer->options_.enable_function_specialization = true; } }; @@ -352,7 +340,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithoutInput) { using test::function::NDef; FunctionOptimizer optimizer(RewriterConfig::DEFAULT); - EnableOnlyFunctionInlining(&optimizer); + DisableFunctionSpecialization(&optimizer); // do not specialize noinline func const Tensor kTwo = test::AsScalar(2); FunctionDef func = FunctionDefHelper::Define( @@ -626,14 +614,13 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_XTimesTwo) { using test::function::NDef; FunctionOptimizer optimizer(RewriterConfig::DEFAULT); - EnableOnlyFunctionSpecialization(&optimizer); - // Mark XTimesTwo as noinline + // Mark XTimesTwo as noinline. FunctionDef x_times_two = test::function::XTimesTwo(); (*x_times_two.mutable_attr())["_noinline"].set_b(true); std::vector function_library = {x_times_two}; - // Build a graph to compute y = XTimesTwo(x) + // Build a graph to compute y = XTimesTwo(x). GrapplerItem item; item.graph = test::function::GDef( {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), @@ -644,12 +631,13 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_XTimesTwo) { GraphDef output; TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); - // Make sure that specialized function was added to the library - EXPECT_EQ(2, output.library().function_size()); + // Make sure that specialized function was added to the library and original + // function was removed. + EXPECT_EQ(1, output.library().function_size()); EXPECT_EQ("XTimesTwo_specialized_for_y", - output.library().function(1).signature().name()); + output.library().function(0).signature().name()); - // And 'y' node is calling specialized function + // And 'y' node is calling specialized function. int count = 0; for (const NodeDef& node : output.node()) { if (node.name() == "y" && count++) { @@ -658,7 +646,7 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_XTimesTwo) { } EXPECT_EQ(1, count); - // And that graph evaluation yields the same result + // And that graph evaluation yields the same result. Tensor pi = test::AsScalar(3.14f); item.fetch = {"z"}; item.feed.emplace_back("x", pi); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 558b8a77e8..22799311bc 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils/colocation.h" +#include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status.h" @@ -36,6 +38,9 @@ namespace tensorflow { namespace grappler { namespace { + +constexpr int kDefaultNumberOfIterations = 1; + int64 NumEdges(const GraphDef& graph) { int64 num_edges = 0; for (const auto& node : graph.node()) { @@ -50,144 +55,138 @@ string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) { NumEdges(after), " edges (", NumEdges(after) - NumEdges(before), ")"); } + +int NumIterations(const RewriterConfig& cfg) { + return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS + ? kDefaultNumberOfIterations + : cfg.meta_optimizer_iterations(); +} + +// Check if optimizer is allowed to run only once. +int IsRunOnceOptimizer(const string& name) { return name == "layout"; } + } // namespace -std::unique_ptr MetaOptimizer::NewOptimizer( - const string& optimizer) { - std::unique_ptr graph_optimizer; - if (optimizer == "pruning") { - graph_optimizer.reset(new ModelPruner()); - } - if (optimizer == "function") { - graph_optimizer.reset(new FunctionOptimizer(cfg_.function_optimization())); +std::unique_ptr MetaOptimizer::MakeNewOptimizer( + const string& optimizer) const { +#define MK_OPT(NAME, VALUE) \ + if (optimizer == NAME) return std::unique_ptr(VALUE) + + MK_OPT("pruning", new ModelPruner()); + MK_OPT("function", new FunctionOptimizer(cfg_.function_optimization())); + MK_OPT("constfold", new ConstantFolding(cpu_device_)); + MK_OPT("layout", new LayoutOptimizer()); + MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL)); + MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization())); + MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas())); + MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization())); + MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization())); + MK_OPT("debug_stripper", new DebugStripper()); + + return std::unique_ptr(); +#undef MK_OPT +} + +Status MetaOptimizer::InitializeOptimizers( + std::vector>* optimizers) const { + if (!cfg_.disable_model_pruning()) { + optimizers->emplace_back(new ModelPruner()); } - if (optimizer == "constfold") { - graph_optimizer.reset(new ConstantFolding(cpu_device_)); + if (cfg_.function_optimization() != RewriterConfig::OFF) { + optimizers->emplace_back( + new FunctionOptimizer(cfg_.function_optimization())); } - if (optimizer == "layout") { - graph_optimizer.reset(new LayoutOptimizer()); + if (cfg_.debug_stripper() == RewriterConfig::ON) { + optimizers->emplace_back(new DebugStripper()); } - if (optimizer == "memory") { - graph_optimizer.reset(new MemoryOptimizer(RewriterConfig::MANUAL)); + if (cfg_.constant_folding() != RewriterConfig::OFF) { + optimizers->emplace_back( + new ConstantFolding(cfg_.constant_folding(), cpu_device_)); } - if (optimizer == "arithmetic") { - graph_optimizer.reset( + if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) { + optimizers->emplace_back( new ArithmeticOptimizer(cfg_.arithmetic_optimization())); } - if (optimizer == "autoparallel") { - graph_optimizer.reset( - new AutoParallel(cfg_.auto_parallel().num_replicas())); - } - if (optimizer == "loop") { - graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization())); + if (cfg_.loop_optimization() != RewriterConfig::OFF) { + optimizers->emplace_back(new LoopOptimizer(cfg_.loop_optimization())); } - if (optimizer == "dependency") { - graph_optimizer.reset( + if (cfg_.dependency_optimization() != RewriterConfig::OFF) { + optimizers->emplace_back( new DependencyOptimizer(cfg_.dependency_optimization())); } - if (optimizer == "debug_stripper") { - graph_optimizer.reset(new DebugStripper()); + if (cfg_.layout_optimizer() != RewriterConfig::OFF) { + optimizers->emplace_back(new LayoutOptimizer()); + } + if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) { + if (cfg_.memory_optimizer_target_node_name_scope().empty()) { + optimizers->emplace_back( + // Use the default target node name prefix "gradients/" + new MemoryOptimizer(cfg_.memory_optimization())); + } else { + optimizers->emplace_back( + new MemoryOptimizer(cfg_.memory_optimization(), + cfg_.memory_optimizer_target_node_name_scope())); + } + } + if (cfg_.auto_parallel().enable()) { + optimizers->emplace_back( + new AutoParallel(cfg_.auto_parallel().num_replicas())); } - return graph_optimizer; + return Status::OK(); } -Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) { - std::vector> optimizers; - if (cfg_.optimizers().empty()) { - if (!cfg_.disable_model_pruning()) { - optimizers.push_back(std::unique_ptr(new ModelPruner())); - } - if (cfg_.function_optimization() != RewriterConfig::OFF) { - optimizers.push_back(std::unique_ptr( - new FunctionOptimizer(cfg_.function_optimization()))); - } - if (cfg_.debug_stripper() == RewriterConfig::ON) { - optimizers.push_back( - std::unique_ptr(new DebugStripper())); - } - if (cfg_.constant_folding() != RewriterConfig::OFF) { - optimizers.push_back(std::unique_ptr( - new ConstantFolding(cfg_.constant_folding(), cpu_device_))); - } - if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) { - optimizers.push_back(std::unique_ptr( - new ArithmeticOptimizer(cfg_.arithmetic_optimization()))); +Status MetaOptimizer::InitializeOptimizersByName( + std::vector>* optimizers) const { + for (const string& optimizer_name : cfg_.optimizers()) { + auto optimizer = MakeNewOptimizer(optimizer_name); + if (optimizer) { + VLOG(2) << "Registered default graph optimizer: " << optimizer_name; + optimizers->push_back(std::move(optimizer)); + continue; } - if (cfg_.loop_optimization() != RewriterConfig::OFF) { - optimizers.push_back(std::unique_ptr( - new LoopOptimizer(cfg_.loop_optimization()))); - } - if (cfg_.dependency_optimization() != RewriterConfig::OFF) { - optimizers.push_back(std::unique_ptr( - new DependencyOptimizer(cfg_.dependency_optimization()))); - } - if (cfg_.layout_optimizer() != RewriterConfig::OFF) { - optimizers.push_back( - std::unique_ptr(new LayoutOptimizer())); - } - if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) { - if (cfg_.memory_optimizer_target_node_name_scope().empty()) { - optimizers.push_back(std::unique_ptr( - // Use the default target node name prefix "gradients/" - new MemoryOptimizer(cfg_.memory_optimization()))); - } else { - optimizers.push_back( - std::unique_ptr(new MemoryOptimizer( - cfg_.memory_optimization(), - cfg_.memory_optimizer_target_node_name_scope()))); - } - } - if (cfg_.auto_parallel().enable()) { - optimizers.push_back(std::unique_ptr( - new AutoParallel(cfg_.auto_parallel().num_replicas()))); - } - } else { - const std::set available_optimizers = { - "pruning", "function", "constfold", "layout", - "memory", "autoparallel", "arithmetic", "loop", - "dependency", "debug_stripper"}; - std::vector custom_optimizer_names; - for (const auto& optimizer_name : cfg_.optimizers()) { - if (available_optimizers.find(optimizer_name) != - available_optimizers.end()) { - optimizers.push_back(NewOptimizer(optimizer_name)); - } else { - custom_optimizer_names.push_back(optimizer_name); - } - } - // Now run the custom optimizers. - for (const auto& optimizer_name : custom_optimizer_names) { - std::unique_ptr opt = - CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name); - if (opt == nullptr) continue; - TF_RETURN_IF_ERROR(opt->Init()); - optimizers.push_back(std::move(opt)); + + auto custom_optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name); + + if (custom_optimizer) { + VLOG(2) << "Registered custom graph optimizer: " << optimizer_name; + TF_RETURN_IF_ERROR(custom_optimizer->Init()); + optimizers->push_back(std::move(custom_optimizer)); + } else { + VLOG(2) << "Can't register an optimizer by name: " << optimizer_name; } } + return Status::OK(); +} + +Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id; + + std::vector> optimizers; + bool register_by_name = !cfg_.optimizers().empty(); + TF_RETURN_IF_ERROR(register_by_name ? InitializeOptimizersByName(&optimizers) + : InitializeOptimizers(&optimizers)); if (optimizers.empty()) { *optimized_graph = item.graph; return Status::OK(); } - // Some optimizers should be run only once. - const std::set run_once_optimizers = {"layout"}; - bool already_optimized = false; - const int num_iterations = - cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS - ? 1 - : cfg_.meta_optimizer_iterations(); + // Invariant: optimized_graph contains the most recently optimized version of + // the graph. GrapplerItem optimized_item = item; optimized_graph->Swap(&optimized_item.graph); - for (int iteration = 0; iteration < num_iterations; ++iteration) { - VLOG(1) << "Starting optimization iteration " << iteration + 1; + + GraphOptimizationResult optimization_result(item.id); + + for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) { + VLOG(4) << "Starting optimization iteration " << iteration + 1; + for (const auto& optimizer : optimizers) { - // Invariant: optimized_graph contains the most recently optimized - // version of the graph. - if (iteration > 0 && run_once_optimizers.count(optimizer->name())) { - continue; - } + // Some optimizers can run only once. + if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue; + uint64 start_us = Env::Default()->NowMicros(); // This swaps the current optimized_graph into optimized item and // resets optimized_graph to an empty graph. @@ -195,45 +194,114 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph = GraphDef(); Status status = optimizer->Optimize(cluster, optimized_item, optimized_graph); - uint64 end_us = Env::Default()->NowMicros(); - float duration_ms = (end_us - start_us) / 1000.0f; + string result; if (!status.ok()) { - VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": " - << status.ToString(); optimized_graph->Swap(&optimized_item.graph); result = status.ToString(); } else { - already_optimized = true; + optimization_result.is_optimized = true; + float duration_ms = (end_us - start_us) / 1000.0f; result = strings::StrCat( - optimizer->name(), ": ", PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph), ", time = ", duration_ms, "ms."); } - result_.emplace_back(optimizer->name(), result); - VLOG(1) << result; + VLOG(4) << optimizer->name() << ": " << result; + + OptimizerResult optimizer_result{optimizer->name(), result}; + optimization_result.results.push_back(optimizer_result); } } - if (already_optimized) { + // Record graph optimization result. + optimization_results_.push_back(optimization_result); + + if (optimization_result.is_optimized) { TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph)); ReassignColocation(optimized_graph); - // Make sure that the optimizers preserved the graph version and library. - DCHECK_GE(optimized_graph->library().function_size(), - item.graph.library().function_size()); - DCHECK_GE(optimized_graph->library().gradient_size(), - item.graph.library().gradient_size()); + // Make sure that the optimizers preserved the graph version. DCHECK_EQ(optimized_graph->versions().producer(), item.graph.versions().producer()); } + + return Status::OK(); +} + +Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + optimization_results_.clear(); + + // 1. Optimize main graph + TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph)); + + // 2. Optimize function library + FunctionLibraryDefinition flib(OpRegistry::Global(), + optimized_graph->library()); + + // Optimize each function only once. + std::unordered_set optimized_funcs; + bool optimize_function_library = true; + + while (optimize_function_library) { + optimize_function_library = false; + + for (const FunctionDef& func : optimized_graph->library().function()) { + const string& func_name = func.signature().name(); + + // Skip already optimized functions. + if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue; + + // Skip parametrized functions (function type or body is defined only at + // function call time by caller node attributes). + if (IsParametrized(func)) continue; + + VLOG(3) << "Optimize function: function=" << func_name; + + // Function optimization might specialize nested function calls, so we + // have to reset the flag and do at least one more pass over the library. + optimize_function_library = true; + optimized_funcs.insert(func_name); + + // Make a GrapplerItem from a FunctionDef. + GrapplerFunctionItem func_item; + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item)); + + // Optimize function body graph. + GraphDef optimized_func_graph; + TF_RETURN_IF_ERROR( + OptimizeGraph(cluster, func_item, &optimized_func_graph)); + + // Function body optimization might have created new specialized + // functions, add them to the library. + TF_RETURN_IF_ERROR(flib.AddLibrary(optimized_func_graph.library())); + + // Convert optimized graph back to FunctionDef. + FunctionDef optimized_func; + func_item.SwapFunctionBody(std::move(optimized_func_graph)); + TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func)); + + // Replace optimized function with a new FunctionDef. + TF_RETURN_IF_ERROR(flib.RemoveFunction(func_name)); + TF_RETURN_IF_ERROR(flib.AddFunctionDef(optimized_func)); + } + + // If optimized at least one function, update the graph library. + if (optimize_function_library) { + *optimized_graph->mutable_library() = flib.ToProto(); + } + } + return Status::OK(); } void MetaOptimizer::PrintResult() { - for (const auto& result : result_) { - LOG(INFO) << "Return status of optimizer " << result.first << ": " - << result.second; + for (const GraphOptimizationResult& graph_result : optimization_results_) { + LOG(INFO) << "Optimization results for grappler item: " << graph_result.id; + for (const OptimizerResult& result : graph_result.results) { + LOG(INFO) << "Return status of optimizer " << result.optimizer_name + << ": " << result.result; + } } } diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 382cfe51d4..7cf9a40c2d 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -30,7 +30,7 @@ class MetaOptimizer : public GraphOptimizer { public: MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg) : cpu_device_(cpu_device), cfg_(cfg) {} - ~MetaOptimizer() override {} + ~MetaOptimizer() override = default; string name() const override { return "meta_optimizer"; }; @@ -43,10 +43,37 @@ class MetaOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: - std::unique_ptr NewOptimizer(const string& optimizer); + std::unique_ptr MakeNewOptimizer( + const string& optimizer) const; + + // Initialize active optimizers from RewriterConfig toggles. + Status InitializeOptimizers( + std::vector>* optimizers) const; + // Initialize active optimizers from RewriterConfig optimizer names. + Status InitializeOptimizersByName( + std::vector>* optimizers) const; + + // Run optimization pass over a single GrapplerItem. Meta optimizer might run + // multiple such passes: 1) for the main graph 2) for the function library + Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph); + DeviceBase* const cpu_device_; // may be NULL RewriterConfig cfg_; - std::vector> result_; + + struct OptimizerResult { + string optimizer_name; + string result; + }; + + struct GraphOptimizationResult { + explicit GraphOptimizationResult(const string& id) : id(id) {} + string id; + bool is_optimized = false; + std::vector results; + }; + + std::vector optimization_results_; }; bool MetaOptimizerEnabled(const RewriterConfig& cfg); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index d9a386b9be..8793ad9633 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -16,11 +16,14 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -28,6 +31,8 @@ namespace tensorflow { namespace grappler { namespace { +constexpr char kDevice[] = "/device:CPU:0"; + class TestOptimizer : public CustomGraphOptimizer { public: static void SetOptimized(const bool flag_value) { optimized_ = flag_value; } @@ -56,7 +61,9 @@ bool TestOptimizer::optimized_; REGISTER_GRAPH_OPTIMIZER(TestOptimizer); -TEST(MetaOptimizerTest, RunsCustomOptimizer) { +class MetaOptimizerTest : public GrapplerTest {}; + +TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; CHECK(fake_input.NextItem(&item)); @@ -72,7 +79,7 @@ TEST(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } -TEST(MetaOptimizerTest, RunOptimizersTwice) { +TEST_F(MetaOptimizerTest, RunOptimizersTwice) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; CHECK(fake_input.NextItem(&item)); @@ -86,6 +93,167 @@ TEST(MetaOptimizerTest, RunOptimizersTwice) { TF_EXPECT_OK(status); } +TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { + using test::function::NDef; + + // Enable ony function optimization. + RewriterConfig rewriter_config; + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + rewriter_config.set_function_optimization(RewriterConfig::ON); + rewriter_config.add_optimizers("function"); + + MetaOptimizer optimizer(nullptr, rewriter_config); + + // Define function library: + // + // MyMul(x, y) = x * y + // *MySquare(x) = MyMul(x, x) + // *MyQuadratic(x) = MySquare(MySquare(x)) + // + // * - marked as noinline + + FunctionDef mul_func = FunctionDefHelper::Create( + "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "mul:z:0"}}); + + FunctionDef square_func = FunctionDefHelper::Create( + "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "my_mul:z:0"}}); + (*square_func.mutable_attr())["_noinline"].set_b(true); + + FunctionDef quadratic_func = FunctionDefHelper::Create( + "MyQuadratic", {"x:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"square"}, "MySquare", {"x"}, {{"T", "$T"}}}, + {{"quadratic"}, "MySquare", {"square:z"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "quadratic:z:0"}}); + (*quadratic_func.mutable_attr())["_noinline"].set_b(true); + + // Tensorflow graph: + // + // a = tf.Placeholder(tf.float); + // b = tf.Placeholder(tf.int32); + // + // square = MySquare(a); // a^2 + // quadratic = MyQuadratic(b); // b^4 + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice), + // Calls into function library + NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice), + NDef("quadratic", "MyQuadratic", {"b"}, {{"T", DT_INT32}}, kDevice), + // Forward outputs + NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice), + NDef("out_q", "Identity", {"quadratic:0"}, {{"T", DT_INT32}}, kDevice)}, + // FunctionLib + {mul_func, square_func, quadratic_func}); + + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + FunctionLibraryDefinition optimized_flib(OpRegistry::Global(), + output.library()); + + // Specialized and optimized functions should be added to the graph. + EXPECT_EQ(6, optimized_flib.num_functions()); + + // MyQuadratic should be specialized once: + // 0. 'quadratic' node in the main graph + const string optimized_0 = "MyQuadratic_specialized_for_quadratic"; + + // MySquare should be specialized and optimized for 3 instantiations: + // 1. 'square' node in the main graph + // 2. 'square' node in the MyQuadratic specialization + // 3. 'quadratic' node in the MyQuadratic specialization + + const string optimized_1 = "MySquare_specialized_for_square"; + const string optimized_2 = "MySquare_specialized_for_square_1"; + const string optimized_3 = "MySquare_specialized_for_quadratic"; + + const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0); + const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1); + const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2); + const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3); + + ASSERT_NE(optimized_func_0, nullptr); + ASSERT_NE(optimized_func_1, nullptr); + ASSERT_NE(optimized_func_2, nullptr); + ASSERT_NE(optimized_func_3, nullptr); + + // Graph should call optimized function. + int count = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "square" && count++) { + EXPECT_EQ("MySquare_specialized_for_square", node.op()); + } else if (node.name() == "quadratic" && count++) { + EXPECT_EQ("MyQuadratic_specialized_for_quadratic", node.op()); + } + } + EXPECT_EQ(2, count); + + // Specialized MySquare should call specialized functions. + count = 0; + for (const NodeDef& node : optimized_func_0->node_def()) { + if (node.name() == "square" && count++) { + EXPECT_EQ(optimized_2, node.op()); + } else if (node.name() == "quadratic" && count++) { + EXPECT_EQ(optimized_3, node.op()); + } + } + EXPECT_EQ(2, count); + + const std::vector optimized_funcs = { + optimized_func_1, optimized_func_1, optimized_func_3}; + + // MyMul should be inlined into all optimized versions of MySquare. + for (const FunctionDef* optimized_func : optimized_funcs) { + count = 0; + for (const NodeDef& node : optimized_func->node_def()) { + if (node.name() == "my_mul/inlined_inputs" && count++) { + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x:0", node.input(0)); + EXPECT_EQ("x:0", node.input(1)); + } else if (node.name() == "my_mul/x" && count++) { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("my_mul/inlined_inputs:output:0", node.input(0)); + } else if (node.name() == "my_mul/y" && count++) { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("my_mul/inlined_inputs:output:1", node.input(0)); + } else if (node.name() == "my_mul/mul" && count++) { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("my_mul/x:output:0", node.input(0)); + EXPECT_EQ("my_mul/y:output:0", node.input(1)); + } else if (node.name() == "my_mul" && count++) { + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("my_mul/mul:z:0", node.input(0)); + } + EXPECT_TRUE(node.device().empty()); + } + EXPECT_EQ(5, count); + } + + item.fetch = {"out_s", "out_q"}; + item.feed.emplace_back("a", test::AsScalar(2.0f)); + item.feed.emplace_back("b", test::AsScalar(4)); + auto tensors_expected = EvaluateFetchNodes(item); + + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + + test::ExpectTensorEqual(tensors_expected[0], tensors[0]); + test::ExpectTensorEqual(tensors_expected[1], tensors[1]); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index 638fe1999a..790809bc67 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -545,6 +545,12 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, return Status::OK(); } +Status MakeGrapplerFunctionItem(const FunctionDef& func, + const FunctionLibraryDefinition& flib, + GrapplerFunctionItem* item) { + return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, item); +} + // Register GrapplerFunctionItem input arg expansion and function body outputs // in the GrapplerFunctionConnectivity. Status RegisterGrapplerFunctionConnectivity( @@ -560,9 +566,9 @@ Status RegisterGrapplerFunctionConnectivity( return Status::OK(); } -Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item, - const FunctionLibraryDefinition& flib, - FunctionDef* func) { +Status MakeFunctionDef(const GrapplerFunctionItem& item, + const FunctionLibraryDefinition& flib, + FunctionDef* func) { func->mutable_signature()->set_name(item.id); func->mutable_signature()->set_is_stateful(item.is_stateful()); diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index ab369bcad7..5e8b6c6960 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -38,7 +38,8 @@ using AttrValueMap = std::unordered_map; // function body in place of function inputs and a resolved input data type. struct InputArgExpansion { // TODO(ezhulenev): Add support for functions with tensor sequence inputs of - // different data types + // different data types. + // TODO(ezhulenev): Support type parametrized inputs? string input_name; // name of the function input argument DataType data_type; // input data type bool is_ref; // if true, inputs are required to be refs @@ -53,7 +54,8 @@ struct InputArgExpansion { // tensors of a function body nodes and a resolved output data type struct OutputArgExpansion { // TODO(ezhulenev): Add support for functions with tensor sequence outputs of - // different data types + // different data types. + // TODO(ezhulenev): Support type parametrized outputs? string output_name; // name of the function output argument DataType data_type; // output data type bool is_ref; // if true, outputs are refs @@ -186,13 +188,6 @@ bool HasParametrizedBody(const FunctionDef& func); // Check if function has parametrized type or body. bool IsParametrized(const FunctionDef& func); -// Make a GrapplerFunctionItem from the function definition and attributes. -// Return error if the given function def cannot be converted. -Status MakeGrapplerFunctionItem( - const FunctionDef& func, - const std::unordered_map& func_instantiation_attr, - const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item); - // Register GrapplerFunctionItem input arg expansion and function body outputs // in the GrapplerFunctionConnectivity. Use function library definition to // lookup function body nodes output names and ranges. @@ -200,11 +195,28 @@ Status RegisterGrapplerFunctionConnectivity( const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, GrapplerFunctionConnectivity* connectivity); -// Make a specialized FunctionDef from the GrapplerFunctionItem. Use function -// library definition to lookup function body nodes output names and ranges. -Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item, - const FunctionLibraryDefinition& flib, - FunctionDef* func); +// Make a GrapplerFunctionItem from the function definition and function +// instantiation attributes (caller node attributes). Returns error if the given +// function def cannot be converted (e.g. not all attributes are defined). +Status MakeGrapplerFunctionItem( + const FunctionDef& func, + const std::unordered_map& func_instantiation_attr, + const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item); + +// Make a GrapplerFunction item from the function definition. Function must be +// fully defined (no type or body parametrization). +// TODO(ezhulenev): Support parametrized functions without fully defined +// instantiation attributes? Do we ever want to optimize parametrized function +// without specializing it to it's instantiation attributes (at least types)? +Status MakeGrapplerFunctionItem(const FunctionDef& func, + const FunctionLibraryDefinition& flib, + GrapplerFunctionItem* item); + +// Make a FunctionDef from the GrapplerFunctionItem. Use function library +// definition to lookup function body nodes output names and ranges. +Status MakeFunctionDef(const GrapplerFunctionItem& item, + const FunctionLibraryDefinition& flib, + FunctionDef* func); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index 54d235a8a4..6dfd49b943 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -524,7 +524,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { EXPECT_EQ("two", cast.input(0)); } -TEST_F(FunctionsTest, MakeSpecializedFunctionDef) { +TEST_F(FunctionsTest, MakeFunctionDef) { const Tensor kTwo = test::AsScalar(2); FunctionDef func = FunctionDefHelper::Define( // Name @@ -550,7 +550,7 @@ TEST_F(FunctionsTest, MakeSpecializedFunctionDef) { TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); FunctionDef specialized; - TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized)); + TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized)); // Input and output types are resolved based on instantiation attributes. EXPECT_EQ("x", specialized.signature().input_arg(0).name()); @@ -573,7 +573,7 @@ TEST_F(FunctionsTest, MakeSpecializedFunctionDef) { EXPECT_EQ(2, count); } -TEST_F(FunctionsTest, SwapFunctionBodyAndMakeSpecializedFunctionDef) { +TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) { using test::function::NDef; FunctionDef mul_func = FunctionDefHelper::Create( @@ -606,7 +606,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeSpecializedFunctionDef) { // Replace function body with identity function item.SwapFunctionBody(std::move(id_func_body)); FunctionDef specialized; - TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized)); + TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized)); // Check that graph body was updated. int count = 0; -- cgit v1.2.3