author     Eugene Zhulenev <ezhulenev@google.com>           2018-10-01 03:34:35 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-01 03:38:39 -0700
commit     9a169bf3ba840af8ab3caae7ea1c69c682be3ab7
tree       018cc2727fe0f063c9f02f8aae8a4ef406d782da /tensorflow/core/grappler
parent     c1c63c936c4bc51b401b82fbe54ed1945f49a314
Add allowed optimizations to GrapplerItem.
(1) Skip the UnaryOpsComposition rewrite when the optimized graph must have a registered gradient for every node.
PiperOrigin-RevId: 215188461
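
For orientation, a minimal sketch (not part of this commit) of how a Grappler pass could consult the new field before applying a rewrite that introduces ops without a registered gradient. The ExampleNonDifferentiableRewrite class and the elided rewrite body are hypothetical, and the GraphOptimizer interface is assumed to match the TensorFlow version of this change; only GrapplerItem::allowed_optimizations comes from the diff below.

// Hypothetical custom pass illustrating the intended use of
// GrapplerItem::AllowedOptimizations::non_differentiable_rewrites.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"

namespace tensorflow {
namespace grappler {

class ExampleNonDifferentiableRewrite : public GraphOptimizer {
 public:
  string name() const override { return "example_non_differentiable_rewrite"; }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    *optimized_graph = item.graph;
    // The flag defaults to true; the meta optimizer flips it to false for
    // function bodies that are differentiated via SymbolicGradient nodes.
    if (item.allowed_optimizations.non_differentiable_rewrites) {
      // ... rewrite that may add nodes without a registered gradient ...
    }
    return Status::OK();
  }

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimized_graph, double result) override {}
};

}  // namespace grappler
}  // namespace tensorflow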
Diffstat (limited to 'tensorflow/core/grappler')
 tensorflow/core/grappler/grappler_item.cc                   |   1 +
 tensorflow/core/grappler/grappler_item.h                    |   9 +
 tensorflow/core/grappler/op_types.cc                        |   4 +
 tensorflow/core/grappler/op_types.h                         |   1 +
 tensorflow/core/grappler/optimizers/BUILD                   |   2 +
 tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc |   4 +
 tensorflow/core/grappler/optimizers/meta_optimizer.cc       |  19 +
 tensorflow/core/grappler/optimizers/meta_optimizer_test.cc  | 126 +
 8 files changed, 166 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index bbc0fedd22..2c490f3966 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
   restore_op = other.restore_op;
   save_restore_loc_tensor = other.save_restore_loc_tensor;
   queue_runners = other.queue_runners;
+  allowed_optimizations = other.allowed_optimizations;
   graph.Swap(graph_def);
 }
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 939e5fa046..a0748abfe6 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -77,6 +77,15 @@ struct GrapplerItem {
   // Return a set of node names that must be preserved. This includes feed and
   // fetch nodes, keep_ops, init_ops.
   std::unordered_set<string> NodesToPreserve() const;
+
+  // Restrict types of optimizations that are allowed for this GrapplerItem.
+  struct AllowedOptimizations {
+    // Is it allowed to add nodes to the graph that do not have registered
+    // gradient function.
+    bool non_differentiable_rewrites = true;
+  };
+
+  AllowedOptimizations allowed_optimizations;
 };
 
 // Return the transitive fanin of a set of terminal nodes.
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 3521669b63..9f0d9dbf28 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -425,6 +425,10 @@ bool IsSwitch(const NodeDef& node) {
   return op == "Switch" || op == "RefSwitch";
 }
 
+bool IsSymbolicGradient(const NodeDef& node) {
+  return node.op() == "SymbolicGradient";
+}
+
 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
 
 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 25ab6b65ac..7f86a5f295 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -149,6 +149,7 @@ bool IsStridedSliceGrad(const NodeDef& node);
 bool IsSub(const NodeDef& node);
 bool IsSum(const NodeDef& node);
 bool IsSwitch(const NodeDef& node);
+bool IsSymbolicGradient(const NodeDef& node);
 bool IsTanhGrad(const NodeDef& node);
 bool IsTile(const NodeDef& node);
 bool IsTranspose(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 960d1addb3..c708f84948 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -525,6 +525,7 @@ cc_library(
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler/utils:colocation",
@@ -541,6 +542,7 @@ tf_cuda_cc_test(
         ":custom_graph_optimizer_registry",
         ":meta_optimizer",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:tensorflow",
         "//tensorflow/core:test",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3388ee8035..7d5014ee0a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -3249,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   optimized_graph_ = &optimized_item.graph;
   node_map_.reset(new NodeMap(optimized_graph_));
 
+  // Disable restricted graph rewrites.
+  options_.unary_ops_composition &=
+      item.allowed_optimizations.non_differentiable_rewrites;
+
   if (options_.dedup_computations) {
     DedupComputations();
   }
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 406c1b60ce..a5f851fb1a 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/utils/functions.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
@@ -413,6 +414,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   FunctionLibraryDefinition flib(OpRegistry::Global(),
                                  optimized_graph->library());
 
+  // Find functions for which we might need to compute a gradient at runtime.
+  gtl::FlatSet<string> differentiable_functions;
+  for (const NodeDef& node : optimized_graph->node()) {
+    if (IsSymbolicGradient(node)) {
+      const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
+      if (f_attr) differentiable_functions.insert(f_attr->func().name());
+    }
+  }
+
   // Optimize each function only once.
   std::unordered_set<string> optimized_funcs;
   bool optimize_function_library = true;
@@ -428,6 +438,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
 
     // Skip parametrized functions (function type or body is defined only at
     // function call time by caller node attributes).
+    // They should be specialized to their instantiation type parameters by
+    // the function optimizer, before we can optimize function body.
     if (IsParametrized(func)) continue;
 
     VLOG(3) << "Optimize function: function=" << func_name;
@@ -442,6 +454,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
         func, flib, item.graph.versions().producer(), &func_item));
 
+    // If we need to compute the gradient of optimized function at runtime, we
+    // can't perform non-differentiable rewrites.
+    if (differentiable_functions.find(func_name) !=
+        differentiable_functions.end()) {
+      func_item.allowed_optimizations.non_differentiable_rewrites = false;
+    }
+
     // Optimize function body graph.
     GraphDef optimized_func_graph;
     TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index c477c4d4b1..3f3f43382f 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/grappler_test.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -82,6 +83,48 @@ class TestOptimizerWithParams : public TestOptimizer {
 
 REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);
 
+// Record various properties of the GrapplerItems passed for optimization.
+class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
+ public:
+  static void SetAllowedOptimizations(
+      gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+          allowed_optimizations) {
+    allowed_optimizations_ = allowed_optimizations;
+  }
+  static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; }
+
+  GrapplerItemPropertiesAccumulator() {}
+  string name() const override {
+    return "grappler_item_properties_accumulator";
+  }
+
+  Status Init(
+      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    return Status::OK();
+  }
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override {
+    *optimized_graph = item.graph;
+    if (allowed_optimizations_) {
+      allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+    }
+    return Status::OK();
+  }
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override {}
+
+ private:
+  static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+      allowed_optimizations_;
+};
+
+gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+    GrapplerItemPropertiesAccumulator::allowed_optimizations_;
+
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);
+
 class MetaOptimizerTest : public GrapplerTest {};
 
 TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -335,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
   test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
 }
 
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
+  using test::function::NDef;
+  using FDH = FunctionDefHelper;
+
+  // We will record what type of optimizations meta optimizer allows for each
+  // GrapplerItem (main graph and graphs for each function).
+  gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>
+      allowed_optimizations;
+  GrapplerItemPropertiesAccumulator::SetAllowedOptimizations(
+      &allowed_optimizations);
+
+  // Just record properties of optimized Grappler items.
+  RewriterConfig rewriter_config;
+  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+  rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
+  rewriter_config.set_min_graph_nodes(-1);
+
+  MetaOptimizer optimizer(nullptr, rewriter_config);
+
+  // Define simple function library with two identical mul functions.
+  FunctionDef mul_func_1 = FunctionDefHelper::Create(
+      "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
+      {{{"mul"}, "Mul", {"x", "y"}, {}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "mul:z:0"}});
+
+  FunctionDef mul_func_2 = FunctionDefHelper::Create(
+      "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
+      {{{"mul"}, "Mul", {"x", "y"}, {}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "mul:z:0"}});
+
+  // Tensorflow graph:
+  //
+  //   x0 = tf.Placeholder(tf.float);
+  //   x1 = tf.Placeholder(tf.float);
+  //   dy = tf.Placeholder(tf.float);
+  //
+  //   mul_1 = MyMul1(x0, x1);
+  //   mul_2 = MyMul2(x0, x1);
+  //   dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
+  GrapplerItem item;
+  item.id = "main";
+  item.graph = test::function::GDef(
+      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+       NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+       // Calls into function library
+       NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
+       NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
+       // Symbolic gradient of a MyMul2
+       NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
+            {{"f", FDH::FunctionRef("MyMul2", {})},
+             {"Tin", DataTypeSlice{DT_FLOAT}},
+             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
+            kDevice)},
+      // FunctionLib
+      {mul_func_1, mul_func_2});
+  item.fetch = {"mul_1", "mul_2", "dx"};
+
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  // Our custom optimizer must be called for the main graph and for the two
+  // functions.
+  ASSERT_EQ(allowed_optimizations.size(), 3);
+
+  auto allowed_optimizations_main =
+      gtl::FindOrNull(allowed_optimizations, "main");
+  ASSERT_NE(allowed_optimizations_main, nullptr);
+  EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites);
+
+  auto allowed_optimizations_my_mul_1 =
+      gtl::FindOrNull(allowed_optimizations, "MyMul1");
+  ASSERT_NE(allowed_optimizations_my_mul_1, nullptr);
+  EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites);
+
+  auto allowed_optimizations_my_mul_2 =
+      gtl::FindOrNull(allowed_optimizations, "MyMul2");
+  ASSERT_NE(allowed_optimizations_my_mul_2, nullptr);
+  EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow