author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-19 18:17:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-19 18:19:45 -0700
commit    4f8768319cfa56c25973cc66d920146ad454bd97 (patch)
tree      16d1306f0efe34ca84e72a47e03570e809af2cb9
parent    b7cca088e90b4c2a28c1038980aa09240584e382 (diff)
Optimize Graph function library.

Run Grappler optimization passes not only over the main graph, but also over
every fully instantiated function in the graph function library, and trim
functions that are no longer reachable from the optimized graph.

PiperOrigin-RevId: 193605910
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD                       |   4
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer.cc       | 126
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer.h        |   6
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer_test.cc  |  32
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc           | 326
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.h            |  33
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer_test.cc      | 172
-rw-r--r--  tensorflow/core/grappler/utils/functions.cc                     |  12
-rw-r--r--  tensorflow/core/grappler/utils/functions.h                      |  40
-rw-r--r--  tensorflow/core/grappler/utils/functions_test.cc                |   8
10 files changed, 563 insertions(+), 196 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index a371186fe6..3ab8d8f584 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -518,11 +518,13 @@ cc_library(
":loop_optimizer",
":memory_optimizer",
":model_pruner",
+ "//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
+ "//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -539,9 +541,11 @@ tf_cuda_cc_test(
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ "//tensorflow/core/grappler/utils:grappler_test",
],
)
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index d008a9719f..950933b933 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -75,12 +76,10 @@ string UniqueSpecializedFunctionName(const FunctionDef& func,
class FunctionOptimizerContext {
public:
- explicit FunctionOptimizerContext(const GrapplerItem& item,
- RewriterConfig::Toggle opt_level)
- : opt_level_(opt_level),
- function_library_(FunctionLibraryDefinition(OpRegistry::Global(),
- item.graph.library())) {
- InitializeInlinedFunctions(item);
+ explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level,
+ const GrapplerItem& item)
+ : function_library_(OpRegistry::Global(), item.graph.library()) {
+ InitializeInlinedFunctions(opt_level, item);
}
const FunctionLibraryDefinition& function_library() const {
@@ -101,8 +100,9 @@ class FunctionOptimizerContext {
}
private:
- void InitializeInlinedFunctions(const GrapplerItem& item) {
- bool aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
+ void InitializeInlinedFunctions(RewriterConfig::Toggle opt_level,
+ const GrapplerItem& item) {
+ bool aggressive = opt_level == RewriterConfig::AGGRESSIVE;
for (const FunctionDef& func : item.graph.library().function()) {
// Can't create IdentityN nodes with no input or output: skip these
@@ -120,7 +120,6 @@ class FunctionOptimizerContext {
}
}
- RewriterConfig::Toggle opt_level_;
FunctionLibraryDefinition function_library_;
// Functions that can be inlined into optimized graph.
std::unordered_map<string, const FunctionDef*> inlined_functions_;
@@ -128,9 +127,93 @@ class FunctionOptimizerContext {
TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
};
+// Returns a trimmed FunctionDefLibrary with only the functions that are
+// reachable from the optimized graph.
+FunctionDefLibrary TrimFunctionLibrary(const FunctionLibraryDefinition& flib,
+ const GraphDef& optimized_graph) {
+ // Functions that are reachable from the optimized graph.
+ std::unordered_set<string> keep_funcs;
+
+ std::vector<const FunctionDef*> func_queue;
+ func_queue.reserve(flib.num_functions());
+
+ // Add a registered, not-yet-processed function to the queue by name.
+ const auto add_to_func_queue = [&](const string& func_name) {
+ const FunctionDef* func = flib.Find(func_name);
+ if (func && keep_funcs.find(func_name) == keep_funcs.end()) {
+ func_queue.push_back(func);
+ }
+ };
+
+ // Find all the functions that are reachable from the given node.
+ const auto add_node_to_func_queue = [&](const NodeDef& node) {
+ // Node itself can be a call to the function.
+ add_to_func_queue(node.op());
+
+ // Or node can have an attribute referencing a function.
+ for (const auto& attr : node.attr()) {
+ const auto& attr_value = attr.second;
+
+ // 1. AttrValue.func
+ if (attr_value.has_func()) {
+ add_to_func_queue(attr_value.func().name());
+ }
+
+ // 2. AttrValue.ListValue.func
+ if (attr_value.has_list()) {
+ for (const auto& func : attr_value.list().func()) {
+ add_to_func_queue(func.name());
+ }
+ }
+ }
+ };
+
+ // Add all functions that are directly called from the optimized graph.
+ const auto& graph_nodes = optimized_graph.node();
+ std::for_each(graph_nodes.begin(), graph_nodes.end(), add_node_to_func_queue);
+
+ // Process all reachable functions.
+ while (!func_queue.empty()) {
+ const FunctionDef* func = func_queue.back();
+ func_queue.pop_back();
+
+ const string& func_name = func->signature().name();
+ keep_funcs.insert(func_name);
+
+ // Find all the functions that are called from the function body.
+ const auto& func_body = func->node_def();
+ std::for_each(func_body.begin(), func_body.end(), add_node_to_func_queue);
+
+ // Check if the function has a registered gradient.
+ const string grad_func_name = flib.FindGradient(func_name);
+ if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
+ }
+
+ FunctionDefLibrary lib;
+ for (const string& func_name : keep_funcs) {
+ const FunctionDef* func = CHECK_NOTNULL(flib.Find(func_name));
+ *lib.add_function() = *func;
+
+ const string grad_func_name = flib.FindGradient(func_name);
+ if (!grad_func_name.empty()) {
+ GradientDef* gd = lib.add_gradient();
+ gd->set_function_name(func_name);
+ gd->set_gradient_func(grad_func_name);
+ }
+ }
+
+ VLOG(3) << "Trimmed function library: " << keep_funcs.size() << " functions ("
+ << static_cast<int>(keep_funcs.size()) - static_cast<int>(flib.num_functions()) << ")";
+
+ return lib;
+}
+
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
+ VLOG(2) << "Specialize function instantiation: "
+ << SummarizeNodeDef(func_node);
+
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
@@ -141,20 +224,20 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
// TODO(ezhulenev): Push down const inputs and known input shapes.
- FunctionDef specialized;
- TF_RETURN_IF_ERROR(MakeSpecializedFunctionDef(item, flib, &specialized));
+ FunctionDef specialized_func;
+ TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
// Find a name for specialized function.
const string specialized_func_name =
UniqueSpecializedFunctionName(func, func_node, flib);
- specialized.mutable_signature()->set_name(specialized_func_name);
- auto* specialized_attr = specialized.mutable_attr();
+ specialized_func.mutable_signature()->set_name(specialized_func_name);
+ auto* specialized_attr = specialized_func.mutable_attr();
(*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
// Add specialized function to the library.
TF_RETURN_IF_ERROR(
- ctx->mutable_function_library().AddFunctionDef(specialized));
+ ctx->mutable_function_library().AddFunctionDef(specialized_func));
// Add a function call node for the specialized function.
NodeDef* specialized_func_node = optimized_graph->add_node();
@@ -226,6 +309,8 @@ Status HookInlinedFunctionOutputs(
Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
const FunctionOptimizerContext& ctx,
GraphDef* optimized_graph) {
+ VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node);
+
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
@@ -359,6 +444,8 @@ class SymbolicGradientEnv {
Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
GraphDef* inlined_graph) {
+ VLOG(2) << "Inline symbolic gradient: " << SummarizeNodeDef(node);
+
GraphDef graph_def;
// Create a node to anchor the gradient inputs
@@ -454,13 +541,16 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ VLOG(2) << "Optimize function library: id=" << item.id;
+
// Nothing to do here.
if (item.graph.library().function_size() == 0) {
+ VLOG(3) << "Skip Grappler item with empty function library";
*optimized_graph = item.graph;
return Status::OK();
}
- FunctionOptimizerContext ctx(item, opt_level_);
+ FunctionOptimizerContext ctx(opt_level_, item);
SymbolicGradientEnv env(item.graph.versions().producer(),
item.graph.library());
@@ -506,9 +596,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph->add_node() = node;
}
- // TODO(bsteiner): trim the library to remove unused function definitions
*optimized_graph->mutable_versions() = item.graph.versions();
- *optimized_graph->mutable_library() = ctx.function_library().ToProto();
+ *optimized_graph->mutable_library() =
+ options_.enable_trim_function_library
+ ? TrimFunctionLibrary(ctx.function_library(), *optimized_graph)
+ : ctx.function_library().ToProto();
return Status::OK();
}
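
Note: TrimFunctionLibrary above is a worklist reachability traversal over the function call graph, seeded with every function referenced by nodes of the optimized graph (directly by op name, or through func-valued attributes) and closed under body calls and registered gradients. A minimal standalone sketch of the same idea, using plain STL containers instead of the TensorFlow protos (all names here are illustrative, not TensorFlow API):

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Call graph: function name -> names of the functions its body calls.
using CallGraph = std::unordered_map<std::string, std::vector<std::string>>;

// Returns the set of library functions reachable from the main graph's calls.
std::unordered_set<std::string> ReachableFunctions(
    const CallGraph& call_graph,
    const std::vector<std::string>& graph_calls) {
  std::unordered_set<std::string> keep;
  std::vector<std::string> queue(graph_calls.begin(), graph_calls.end());
  while (!queue.empty()) {
    const std::string name = queue.back();
    queue.pop_back();
    const auto it = call_graph.find(name);
    // Skip names with no registered function and functions already kept.
    if (it == call_graph.end() || !keep.insert(name).second) continue;
    for (const std::string& callee : it->second) queue.push_back(callee);
  }
  return keep;
}

The real pass additionally enqueues the gradient function registered for any kept function, and copies the surviving FunctionDefs and GradientDefs into the trimmed library.
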
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.h b/tensorflow/core/grappler/optimizers/function_optimizer.h
index c555fadf83..e307b4e533 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.h
@@ -26,8 +26,9 @@ namespace grappler {
// operations to make the overall graph more efficient.
class FunctionOptimizer : public GraphOptimizer {
public:
- FunctionOptimizer(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {}
- ~FunctionOptimizer() override {}
+ explicit FunctionOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
+ ~FunctionOptimizer() override = default;
string name() const override { return "function_optimizer"; };
@@ -44,6 +45,7 @@ class FunctionOptimizer : public GraphOptimizer {
bool enable_function_inlining = true;
bool enable_function_specialization = true;
bool enable_symbolic_gradient_inlining = true;
+ bool enable_trim_function_library = true;
};
RewriterConfig::Toggle opt_level_;
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index fb006d4868..6147e8a27c 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -31,20 +31,8 @@ constexpr char kDevice[] = "/device:CPU:0";
class FunctionOptimizerTest : public GrapplerTest {
protected:
- void DisableAll(FunctionOptimizer* optimizer) {
- optimizer->options_.enable_function_inlining = false;
+ void DisableFunctionSpecialization(FunctionOptimizer* optimizer) {
optimizer->options_.enable_function_specialization = false;
- optimizer->options_.enable_symbolic_gradient_inlining = false;
- }
-
- void EnableOnlyFunctionInlining(FunctionOptimizer* optimizer) {
- DisableAll(optimizer);
- optimizer->options_.enable_function_inlining = true;
- }
-
- void EnableOnlyFunctionSpecialization(FunctionOptimizer* optimizer) {
- DisableAll(optimizer);
- optimizer->options_.enable_function_specialization = true;
}
};
@@ -352,7 +340,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithoutInput) {
using test::function::NDef;
FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
- EnableOnlyFunctionInlining(&optimizer);
+ DisableFunctionSpecialization(&optimizer); // do not specialize noinline func
const Tensor kTwo = test::AsScalar<int64>(2);
FunctionDef func = FunctionDefHelper::Define(
@@ -626,14 +614,13 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_XTimesTwo) {
using test::function::NDef;
FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
- EnableOnlyFunctionSpecialization(&optimizer);
- // Mark XTimesTwo as noinline
+ // Mark XTimesTwo as noinline.
FunctionDef x_times_two = test::function::XTimesTwo();
(*x_times_two.mutable_attr())["_noinline"].set_b(true);
std::vector<FunctionDef> function_library = {x_times_two};
- // Build a graph to compute y = XTimesTwo(x)
+ // Build a graph to compute y = XTimesTwo(x).
GrapplerItem item;
item.graph = test::function::GDef(
{NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
@@ -644,12 +631,13 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_XTimesTwo) {
GraphDef output;
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
- // Make sure that specialized function was added to the library
- EXPECT_EQ(2, output.library().function_size());
+ // Make sure that specialized function was added to the library and original
+ // function was removed.
+ EXPECT_EQ(1, output.library().function_size());
EXPECT_EQ("XTimesTwo_specialized_for_y",
- output.library().function(1).signature().name());
+ output.library().function(0).signature().name());
- // And 'y' node is calling specialized function
+ // And 'y' node is calling specialized function.
int count = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "y" && count++) {
@@ -658,7 +646,7 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_XTimesTwo) {
}
EXPECT_EQ(1, count);
- // And that graph evaluation yields the same result
+ // And that graph evaluation yields the same result.
Tensor pi = test::AsScalar<float>(3.14f);
item.fetch = {"z"};
item.feed.emplace_back("x", pi);
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 558b8a77e8..22799311bc 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
@@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/colocation.h"
+#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
@@ -36,6 +38,9 @@ namespace tensorflow {
namespace grappler {
namespace {
+
+constexpr int kDefaultNumberOfIterations = 1;
+
int64 NumEdges(const GraphDef& graph) {
int64 num_edges = 0;
for (const auto& node : graph.node()) {
@@ -50,144 +55,138 @@ string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
NumEdges(after), " edges (",
NumEdges(after) - NumEdges(before), ")");
}
+
+int NumIterations(const RewriterConfig& cfg) {
+ return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
+ ? kDefaultNumberOfIterations
+ : cfg.meta_optimizer_iterations();
+}
+
+// Check if optimizer is allowed to run only once.
+bool IsRunOnceOptimizer(const string& name) { return name == "layout"; }
+
} // namespace
-std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
- const string& optimizer) {
- std::unique_ptr<GraphOptimizer> graph_optimizer;
- if (optimizer == "pruning") {
- graph_optimizer.reset(new ModelPruner());
- }
- if (optimizer == "function") {
- graph_optimizer.reset(new FunctionOptimizer(cfg_.function_optimization()));
+std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
+ const string& optimizer) const {
+#define MK_OPT(NAME, VALUE) \
+ if (optimizer == NAME) return std::unique_ptr<GraphOptimizer>(VALUE)
+
+ MK_OPT("pruning", new ModelPruner());
+ MK_OPT("function", new FunctionOptimizer(cfg_.function_optimization()));
+ MK_OPT("constfold", new ConstantFolding(cpu_device_));
+ MK_OPT("layout", new LayoutOptimizer());
+ MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
+ MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
+ MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
+ MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization()));
+ MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization()));
+ MK_OPT("debug_stripper", new DebugStripper());
+
+ return std::unique_ptr<GraphOptimizer>();
+#undef MK_OPT
+}
+
+Status MetaOptimizer::InitializeOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ if (!cfg_.disable_model_pruning()) {
+ optimizers->emplace_back(new ModelPruner());
}
- if (optimizer == "constfold") {
- graph_optimizer.reset(new ConstantFolding(cpu_device_));
+ if (cfg_.function_optimization() != RewriterConfig::OFF) {
+ optimizers->emplace_back(
+ new FunctionOptimizer(cfg_.function_optimization()));
}
- if (optimizer == "layout") {
- graph_optimizer.reset(new LayoutOptimizer());
+ if (cfg_.debug_stripper() == RewriterConfig::ON) {
+ optimizers->emplace_back(new DebugStripper());
}
- if (optimizer == "memory") {
- graph_optimizer.reset(new MemoryOptimizer(RewriterConfig::MANUAL));
+ if (cfg_.constant_folding() != RewriterConfig::OFF) {
+ optimizers->emplace_back(
+ new ConstantFolding(cfg_.constant_folding(), cpu_device_));
}
- if (optimizer == "arithmetic") {
- graph_optimizer.reset(
+ if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
+ optimizers->emplace_back(
new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
}
- if (optimizer == "autoparallel") {
- graph_optimizer.reset(
- new AutoParallel(cfg_.auto_parallel().num_replicas()));
- }
- if (optimizer == "loop") {
- graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization()));
+ if (cfg_.loop_optimization() != RewriterConfig::OFF) {
+ optimizers->emplace_back(new LoopOptimizer(cfg_.loop_optimization()));
}
- if (optimizer == "dependency") {
- graph_optimizer.reset(
+ if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
+ optimizers->emplace_back(
new DependencyOptimizer(cfg_.dependency_optimization()));
}
- if (optimizer == "debug_stripper") {
- graph_optimizer.reset(new DebugStripper());
+ if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
+ optimizers->emplace_back(new LayoutOptimizer());
+ }
+ if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
+ if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
+ optimizers->emplace_back(
+ // Use the default target node name prefix "gradients/"
+ new MemoryOptimizer(cfg_.memory_optimization()));
+ } else {
+ optimizers->emplace_back(
+ new MemoryOptimizer(cfg_.memory_optimization(),
+ cfg_.memory_optimizer_target_node_name_scope()));
+ }
+ }
+ if (cfg_.auto_parallel().enable()) {
+ optimizers->emplace_back(
+ new AutoParallel(cfg_.auto_parallel().num_replicas()));
}
- return graph_optimizer;
+ return Status::OK();
}
-Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
- std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
- if (cfg_.optimizers().empty()) {
- if (!cfg_.disable_model_pruning()) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
- }
- if (cfg_.function_optimization() != RewriterConfig::OFF) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(
- new FunctionOptimizer(cfg_.function_optimization())));
- }
- if (cfg_.debug_stripper() == RewriterConfig::ON) {
- optimizers.push_back(
- std::unique_ptr<GraphOptimizer>(new DebugStripper()));
- }
- if (cfg_.constant_folding() != RewriterConfig::OFF) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(
- new ConstantFolding(cfg_.constant_folding(), cpu_device_)));
- }
- if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(
- new ArithmeticOptimizer(cfg_.arithmetic_optimization())));
+Status MetaOptimizer::InitializeOptimizersByName(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ for (const string& optimizer_name : cfg_.optimizers()) {
+ auto optimizer = MakeNewOptimizer(optimizer_name);
+ if (optimizer) {
+ VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
+ optimizers->push_back(std::move(optimizer));
+ continue;
}
- if (cfg_.loop_optimization() != RewriterConfig::OFF) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(
- new LoopOptimizer(cfg_.loop_optimization())));
- }
- if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(
- new DependencyOptimizer(cfg_.dependency_optimization())));
- }
- if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
- optimizers.push_back(
- std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
- }
- if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
- if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(
- // Use the default target node name prefix "gradients/"
- new MemoryOptimizer(cfg_.memory_optimization())));
- } else {
- optimizers.push_back(
- std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
- cfg_.memory_optimization(),
- cfg_.memory_optimizer_target_node_name_scope())));
- }
- }
- if (cfg_.auto_parallel().enable()) {
- optimizers.push_back(std::unique_ptr<GraphOptimizer>(
- new AutoParallel(cfg_.auto_parallel().num_replicas())));
- }
- } else {
- const std::set<string> available_optimizers = {
- "pruning", "function", "constfold", "layout",
- "memory", "autoparallel", "arithmetic", "loop",
- "dependency", "debug_stripper"};
- std::vector<string> custom_optimizer_names;
- for (const auto& optimizer_name : cfg_.optimizers()) {
- if (available_optimizers.find(optimizer_name) !=
- available_optimizers.end()) {
- optimizers.push_back(NewOptimizer(optimizer_name));
- } else {
- custom_optimizer_names.push_back(optimizer_name);
- }
- }
- // Now run the custom optimizers.
- for (const auto& optimizer_name : custom_optimizer_names) {
- std::unique_ptr<CustomGraphOptimizer> opt =
- CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
- if (opt == nullptr) continue;
- TF_RETURN_IF_ERROR(opt->Init());
- optimizers.push_back(std::move(opt));
+
+ auto custom_optimizer =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
+
+ if (custom_optimizer) {
+ VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
+ TF_RETURN_IF_ERROR(custom_optimizer->Init());
+ optimizers->push_back(std::move(custom_optimizer));
+ } else {
+ VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
+ return Status::OK();
+}
+
+Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id;
+
+ std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
+ bool register_by_name = !cfg_.optimizers().empty();
+ TF_RETURN_IF_ERROR(register_by_name ? InitializeOptimizersByName(&optimizers)
+ : InitializeOptimizers(&optimizers));
if (optimizers.empty()) {
*optimized_graph = item.graph;
return Status::OK();
}
- // Some optimizers should be run only once.
- const std::set<string> run_once_optimizers = {"layout"};
- bool already_optimized = false;
- const int num_iterations =
- cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
- ? 1
- : cfg_.meta_optimizer_iterations();
+ // Invariant: optimized_graph contains the most recently optimized version of
+ // the graph.
GrapplerItem optimized_item = item;
optimized_graph->Swap(&optimized_item.graph);
- for (int iteration = 0; iteration < num_iterations; ++iteration) {
- VLOG(1) << "Starting optimization iteration " << iteration + 1;
+
+ GraphOptimizationResult optimization_result(item.id);
+
+ for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
+ VLOG(4) << "Starting optimization iteration " << iteration + 1;
+
for (const auto& optimizer : optimizers) {
- // Invariant: optimized_graph contains the most recently optimized
- // version of the graph.
- if (iteration > 0 && run_once_optimizers.count(optimizer->name())) {
- continue;
- }
+ // Some optimizers can run only once.
+ if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
+
uint64 start_us = Env::Default()->NowMicros();
// This swaps the current optimized_graph into optimized item and
// resets optimized_graph to an empty graph.
@@ -195,45 +194,114 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph = GraphDef();
Status status =
optimizer->Optimize(cluster, optimized_item, optimized_graph);
-
uint64 end_us = Env::Default()->NowMicros();
- float duration_ms = (end_us - start_us) / 1000.0f;
+
string result;
if (!status.ok()) {
- VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
- << status.ToString();
optimized_graph->Swap(&optimized_item.graph);
result = status.ToString();
} else {
- already_optimized = true;
+ optimization_result.is_optimized = true;
+ float duration_ms = (end_us - start_us) / 1000.0f;
result = strings::StrCat(
- optimizer->name(), ": ",
PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph),
", time = ", duration_ms, "ms.");
}
- result_.emplace_back(optimizer->name(), result);
- VLOG(1) << result;
+ VLOG(4) << optimizer->name() << ": " << result;
+
+ OptimizerResult optimizer_result{optimizer->name(), result};
+ optimization_result.results.push_back(optimizer_result);
}
}
- if (already_optimized) {
+ // Record graph optimization result.
+ optimization_results_.push_back(optimization_result);
+
+ if (optimization_result.is_optimized) {
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
ReassignColocation(optimized_graph);
- // Make sure that the optimizers preserved the graph version and library.
- DCHECK_GE(optimized_graph->library().function_size(),
- item.graph.library().function_size());
- DCHECK_GE(optimized_graph->library().gradient_size(),
- item.graph.library().gradient_size());
+ // Make sure that the optimizers preserved the graph version.
DCHECK_EQ(optimized_graph->versions().producer(),
item.graph.versions().producer());
}
+
+ return Status::OK();
+}
+
+Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ optimization_results_.clear();
+
+ // 1. Optimize main graph
+ TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+
+ // 2. Optimize function library
+ FunctionLibraryDefinition flib(OpRegistry::Global(),
+ optimized_graph->library());
+
+ // Optimize each function only once.
+ std::unordered_set<string> optimized_funcs;
+ bool optimize_function_library = true;
+
+ while (optimize_function_library) {
+ optimize_function_library = false;
+
+ for (const FunctionDef& func : optimized_graph->library().function()) {
+ const string& func_name = func.signature().name();
+
+ // Skip already optimized functions.
+ if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue;
+
+ // Skip parametrized functions (function type or body is defined only at
+ // function call time by caller node attributes).
+ if (IsParametrized(func)) continue;
+
+ VLOG(3) << "Optimize function: function=" << func_name;
+
+ // Function optimization might specialize nested function calls, so we
+ // have to reset the flag and do at least one more pass over the library.
+ optimize_function_library = true;
+ optimized_funcs.insert(func_name);
+
+ // Make a GrapplerItem from a FunctionDef.
+ GrapplerFunctionItem func_item;
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item));
+
+ // Optimize function body graph.
+ GraphDef optimized_func_graph;
+ TF_RETURN_IF_ERROR(
+ OptimizeGraph(cluster, func_item, &optimized_func_graph));
+
+ // Function body optimization might have created new specialized
+ // functions, add them to the library.
+ TF_RETURN_IF_ERROR(flib.AddLibrary(optimized_func_graph.library()));
+
+ // Convert optimized graph back to FunctionDef.
+ FunctionDef optimized_func;
+ func_item.SwapFunctionBody(std::move(optimized_func_graph));
+ TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
+
+ // Replace optimized function with a new FunctionDef.
+ TF_RETURN_IF_ERROR(flib.RemoveFunction(func_name));
+ TF_RETURN_IF_ERROR(flib.AddFunctionDef(optimized_func));
+ }
+
+ // If optimized at least one function, update the graph library.
+ if (optimize_function_library) {
+ *optimized_graph->mutable_library() = flib.ToProto();
+ }
+ }
+
return Status::OK();
}
void MetaOptimizer::PrintResult() {
- for (const auto& result : result_) {
- LOG(INFO) << "Return status of optimizer " << result.first << ": "
- << result.second;
+ for (const GraphOptimizationResult& graph_result : optimization_results_) {
+ LOG(INFO) << "Optimization results for grappler item: " << graph_result.id;
+ for (const OptimizerResult& result : graph_result.results) {
+ LOG(INFO) << "Return status of optimizer " << result.optimizer_name
+ << ": " << result.result;
+ }
}
}
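
The function-library pass in MetaOptimizer::Optimize above is a fixed-point iteration: optimizing one function can specialize its nested function calls, which adds new library entries that themselves need a pass. A sketch of just that control flow, with hypothetical Function/Library types standing in for the FunctionDef machinery:

#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

struct Function { std::string name; };
struct Library { std::vector<Function> functions; };

// Applies `optimize_one` to every function exactly once, re-scanning the
// library as long as an optimization step may have added new functions.
void OptimizeLibrary(
    Library* lib,
    const std::function<void(const Function&, Library*)>& optimize_one) {
  std::unordered_set<std::string> optimized;
  bool keep_going = true;
  while (keep_going) {
    keep_going = false;
    // Iterate over a snapshot: optimize_one may grow lib->functions.
    const std::vector<Function> snapshot = lib->functions;
    for (const Function& func : snapshot) {
      if (!optimized.insert(func.name).second) continue;  // already done
      keep_going = true;  // did new work; take one more pass over the library
      optimize_one(func, lib);
    }
  }
}

Because every function is optimized at most once, the loop terminates as soon as a full pass finds no unprocessed functions.
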
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 382cfe51d4..7cf9a40c2d 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -30,7 +30,7 @@ class MetaOptimizer : public GraphOptimizer {
public:
MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
: cpu_device_(cpu_device), cfg_(cfg) {}
- ~MetaOptimizer() override {}
+ ~MetaOptimizer() override = default;
string name() const override { return "meta_optimizer"; };
@@ -43,10 +43,37 @@ class MetaOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
- std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
+ std::unique_ptr<GraphOptimizer> MakeNewOptimizer(
+ const string& optimizer) const;
+
+ // Initialize active optimizers from RewriterConfig toggles.
+ Status InitializeOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Initialize active optimizers from RewriterConfig optimizer names.
+ Status InitializeOptimizersByName(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+
+ // Run an optimization pass over a single GrapplerItem. The meta optimizer
+ // might run multiple such passes: 1) the main graph, 2) each library function.
+ Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph);
+
DeviceBase* const cpu_device_; // may be NULL
RewriterConfig cfg_;
- std::vector<std::pair<string, string>> result_;
+
+ struct OptimizerResult {
+ string optimizer_name;
+ string result;
+ };
+
+ struct GraphOptimizationResult {
+ explicit GraphOptimizationResult(const string& id) : id(id) {}
+ string id;
+ bool is_optimized = false;
+ std::vector<OptimizerResult> results;
+ };
+
+ std::vector<GraphOptimizationResult> optimization_results_;
};
bool MetaOptimizerEnabled(const RewriterConfig& cfg);
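
For reference, a minimal driver for the class as declared here, mirroring the new test below (a sketch; both the CPU device and the cluster may be null, and `item` is an existing GrapplerItem):

RewriterConfig config;
config.set_meta_optimizer_iterations(RewriterConfig::TWO);
MetaOptimizer optimizer(/*cpu_device=*/nullptr, config);

GraphDef optimized_graph;
TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
optimizer.PrintResult();  // one result block per optimized GrapplerItem
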
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index d9a386b9be..8793ad9633 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -16,11 +16,14 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,6 +31,8 @@ namespace tensorflow {
namespace grappler {
namespace {
+constexpr char kDevice[] = "/device:CPU:0";
+
class TestOptimizer : public CustomGraphOptimizer {
public:
static void SetOptimized(const bool flag_value) { optimized_ = flag_value; }
@@ -56,7 +61,9 @@ bool TestOptimizer::optimized_;
REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
-TEST(MetaOptimizerTest, RunsCustomOptimizer) {
+class MetaOptimizerTest : public GrapplerTest {};
+
+TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
@@ -72,7 +79,7 @@ TEST(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
-TEST(MetaOptimizerTest, RunOptimizersTwice) {
+TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
@@ -86,6 +93,167 @@ TEST(MetaOptimizerTest, RunOptimizersTwice) {
TF_EXPECT_OK(status);
}
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
+ using test::function::NDef;
+
+ // Enable only function optimization.
+ RewriterConfig rewriter_config;
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_function_optimization(RewriterConfig::ON);
+ rewriter_config.add_optimizers("function");
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+
+ // Define function library:
+ //
+ // MyMul(x, y) = x * y
+ // *MySquare(x) = MyMul(x, x)
+ // *MyQuadratic(x) = MySquare(MySquare(x))
+ //
+ // * - marked as noinline
+
+ FunctionDef mul_func = FunctionDefHelper::Create(
+ "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ FunctionDef square_func = FunctionDefHelper::Create(
+ "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "my_mul:z:0"}});
+ (*square_func.mutable_attr())["_noinline"].set_b(true);
+
+ FunctionDef quadratic_func = FunctionDefHelper::Create(
+ "MyQuadratic", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"square"}, "MySquare", {"x"}, {{"T", "$T"}}},
+ {{"quadratic"}, "MySquare", {"square:z"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "quadratic:z:0"}});
+ (*quadratic_func.mutable_attr())["_noinline"].set_b(true);
+
+ // Tensorflow graph:
+ //
+ // a = tf.Placeholder(tf.float);
+ // b = tf.Placeholder(tf.int32);
+ //
+ // square = MySquare(a); // a^2
+ // quadratic = MyQuadratic(b); // b^4
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+ // Calls into function library
+ NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("quadratic", "MyQuadratic", {"b"}, {{"T", DT_INT32}}, kDevice),
+ // Forward outputs
+ NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("out_q", "Identity", {"quadratic:0"}, {{"T", DT_INT32}}, kDevice)},
+ // FunctionLib
+ {mul_func, square_func, quadratic_func});
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
+ output.library());
+
+ // Specialized and optimized functions should be added to the graph.
+ EXPECT_EQ(6, optimized_flib.num_functions());
+
+ // MyQuadratic should be specialized once:
+ // 0. 'quadratic' node in the main graph
+ const string optimized_0 = "MyQuadratic_specialized_for_quadratic";
+
+ // MySquare should be specialized and optimized for 3 instantiations:
+ // 1. 'square' node in the main graph
+ // 2. 'square' node in the MyQuadratic specialization
+ // 3. 'quadratic' node in the MyQuadratic specialization
+
+ const string optimized_1 = "MySquare_specialized_for_square";
+ const string optimized_2 = "MySquare_specialized_for_square_1";
+ const string optimized_3 = "MySquare_specialized_for_quadratic";
+
+ const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
+ const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
+ const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);
+ const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3);
+
+ ASSERT_NE(optimized_func_0, nullptr);
+ ASSERT_NE(optimized_func_1, nullptr);
+ ASSERT_NE(optimized_func_2, nullptr);
+ ASSERT_NE(optimized_func_3, nullptr);
+
+ // Graph should call optimized functions.
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "square" && count++) {
+ EXPECT_EQ("MySquare_specialized_for_square", node.op());
+ } else if (node.name() == "quadratic" && count++) {
+ EXPECT_EQ("MyQuadratic_specialized_for_quadratic", node.op());
+ }
+ }
+ EXPECT_EQ(2, count);
+
+ // Specialized MyQuadratic should call specialized versions of MySquare.
+ count = 0;
+ for (const NodeDef& node : optimized_func_0->node_def()) {
+ if (node.name() == "square" && count++) {
+ EXPECT_EQ(optimized_2, node.op());
+ } else if (node.name() == "quadratic" && count++) {
+ EXPECT_EQ(optimized_3, node.op());
+ }
+ }
+ EXPECT_EQ(2, count);
+
+ const std::vector<const FunctionDef*> optimized_funcs = {
+ optimized_func_1, optimized_func_2, optimized_func_3};
+
+ // MyMul should be inlined into all optimized versions of MySquare.
+ for (const FunctionDef* optimized_func : optimized_funcs) {
+ count = 0;
+ for (const NodeDef& node : optimized_func->node_def()) {
+ if (node.name() == "my_mul/inlined_inputs" && count++) {
+ EXPECT_EQ("IdentityN", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x:0", node.input(0));
+ EXPECT_EQ("x:0", node.input(1));
+ } else if (node.name() == "my_mul/x" && count++) {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("my_mul/inlined_inputs:output:0", node.input(0));
+ } else if (node.name() == "my_mul/y" && count++) {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("my_mul/inlined_inputs:output:1", node.input(0));
+ } else if (node.name() == "my_mul/mul" && count++) {
+ EXPECT_EQ("Mul", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("my_mul/x:output:0", node.input(0));
+ EXPECT_EQ("my_mul/y:output:0", node.input(1));
+ } else if (node.name() == "my_mul" && count++) {
+ EXPECT_EQ("IdentityN", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("my_mul/mul:z:0", node.input(0));
+ }
+ EXPECT_TRUE(node.device().empty());
+ }
+ EXPECT_EQ(5, count);
+ }
+
+ item.fetch = {"out_s", "out_q"};
+ item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
+ item.feed.emplace_back("b", test::AsScalar<int>(4));
+ auto tensors_expected = EvaluateFetchNodes(item);
+
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+ test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 638fe1999a..790809bc67 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -545,6 +545,12 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
return Status::OK();
}
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+ const FunctionLibraryDefinition& flib,
+ GrapplerFunctionItem* item) {
+ return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, item);
+}
+
// Register GrapplerFunctionItem input arg expansion and function body outputs
// in the GrapplerFunctionConnectivity.
Status RegisterGrapplerFunctionConnectivity(
@@ -560,9 +566,9 @@ Status RegisterGrapplerFunctionConnectivity(
return Status::OK();
}
-Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item,
- const FunctionLibraryDefinition& flib,
- FunctionDef* func) {
+Status MakeFunctionDef(const GrapplerFunctionItem& item,
+ const FunctionLibraryDefinition& flib,
+ FunctionDef* func) {
func->mutable_signature()->set_name(item.id);
func->mutable_signature()->set_is_stateful(item.is_stateful());
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index ab369bcad7..5e8b6c6960 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -38,7 +38,8 @@ using AttrValueMap = std::unordered_map<string, AttrValue>;
// function body in place of function inputs and a resolved input data type.
struct InputArgExpansion {
// TODO(ezhulenev): Add support for functions with tensor sequence inputs of
- // different data types
+ // different data types.
+ // TODO(ezhulenev): Support type parametrized inputs?
string input_name; // name of the function input argument
DataType data_type; // input data type
bool is_ref; // if true, inputs are required to be refs
@@ -53,7 +54,8 @@ struct InputArgExpansion {
// tensors of a function body nodes and a resolved output data type
struct OutputArgExpansion {
// TODO(ezhulenev): Add support for functions with tensor sequence outputs of
- // different data types
+ // different data types.
+ // TODO(ezhulenev): Support type parametrized outputs?
string output_name; // name of the function output argument
DataType data_type; // output data type
bool is_ref; // if true, outputs are refs
@@ -186,13 +188,6 @@ bool HasParametrizedBody(const FunctionDef& func);
// Check if function has parametrized type or body.
bool IsParametrized(const FunctionDef& func);
-// Make a GrapplerFunctionItem from the function definition and attributes.
-// Return error if the given function def cannot be converted.
-Status MakeGrapplerFunctionItem(
- const FunctionDef& func,
- const std::unordered_map<string, AttrValue>& func_instantiation_attr,
- const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
-
// Register GrapplerFunctionItem input arg expansion and function body outputs
// in the GrapplerFunctionConnectivity. Use function library definition to
// lookup function body nodes output names and ranges.
@@ -200,11 +195,28 @@ Status RegisterGrapplerFunctionConnectivity(
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
GrapplerFunctionConnectivity* connectivity);
-// Make a specialized FunctionDef from the GrapplerFunctionItem. Use function
-// library definition to lookup function body nodes output names and ranges.
-Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item,
- const FunctionLibraryDefinition& flib,
- FunctionDef* func);
+// Make a GrapplerFunctionItem from the function definition and function
+// instantiation attributes (caller node attributes). Returns error if the given
+// function def cannot be converted (e.g. not all attributes are defined).
+Status MakeGrapplerFunctionItem(
+ const FunctionDef& func,
+ const std::unordered_map<string, AttrValue>& func_instantiation_attr,
+ const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
+
+// Make a GrapplerFunctionItem from the function definition. Function must be
+// fully defined (no type or body parametrization).
+// TODO(ezhulenev): Support parametrized functions without fully defined
+// instantiation attributes? Do we ever want to optimize a parametrized function
+// without specializing it to its instantiation attributes (at least types)?
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+ const FunctionLibraryDefinition& flib,
+ GrapplerFunctionItem* item);
+
+// Make a FunctionDef from the GrapplerFunctionItem. Use function library
+// definition to lookup function body nodes output names and ranges.
+Status MakeFunctionDef(const GrapplerFunctionItem& item,
+ const FunctionLibraryDefinition& flib,
+ FunctionDef* func);
} // end namespace grappler
} // end namespace tensorflow
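
Together these helpers support the round-trip the meta optimizer relies on: convert a fully instantiated FunctionDef into a GrapplerFunctionItem, rewrite its body graph, and convert it back. A hedged sketch using only the signatures declared above (`RewriteGraph` is a hypothetical stand-in for an arbitrary body-graph pass):

// Optimize one library function. `func` must be fully instantiated, i.e.
// IsParametrized(func) must be false.
Status OptimizeOneFunction(const FunctionDef& func,
                           const FunctionLibraryDefinition& flib,
                           FunctionDef* optimized_func) {
  GrapplerFunctionItem item;
  TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &item));

  GraphDef rewritten = RewriteGraph(item.graph);  // hypothetical pass
  item.SwapFunctionBody(std::move(rewritten));

  return MakeFunctionDef(item, flib, optimized_func);
}
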
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index 54d235a8a4..6dfd49b943 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -524,7 +524,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
EXPECT_EQ("two", cast.input(0));
}
-TEST_F(FunctionsTest, MakeSpecializedFunctionDef) {
+TEST_F(FunctionsTest, MakeFunctionDef) {
const Tensor kTwo = test::AsScalar<int64>(2);
FunctionDef func = FunctionDefHelper::Define(
// Name
@@ -550,7 +550,7 @@ TEST_F(FunctionsTest, MakeSpecializedFunctionDef) {
TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
FunctionDef specialized;
- TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized));
+ TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized));
// Input and output types are resolved based on instantiation attributes.
EXPECT_EQ("x", specialized.signature().input_arg(0).name());
@@ -573,7 +573,7 @@ TEST_F(FunctionsTest, MakeSpecializedFunctionDef) {
EXPECT_EQ(2, count);
}
-TEST_F(FunctionsTest, SwapFunctionBodyAndMakeSpecializedFunctionDef) {
+TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
using test::function::NDef;
FunctionDef mul_func = FunctionDefHelper::Create(
@@ -606,7 +606,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeSpecializedFunctionDef) {
// Replace function body with identity function
item.SwapFunctionBody(std::move(id_func_body));
FunctionDef specialized;
- TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized));
+ TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized));
// Check that graph body was updated.
int count = 0;