diff options
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 38 | ||||
-rw-r--r-- | tensorflow/core/framework/function.cc | 24 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 7 |
4 files changed, 62 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 472865ca43..e0e5f4a215 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -551,7 +551,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( item->func_graph = fbody; item->overlay_lib = options.overlay_lib; item->instantiation_counter = 1; - item->executor_type = options.executor_type; + item->executor_type = ExecutorType(options, attrs); items_.emplace(next_handle_, std::unique_ptr<Item>(item)); next_handle_++; } diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 7bab9be9a6..716167132b 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -584,7 +584,28 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) { "Internal: This is a dummy."); } - // Test that non-existent exector types trigger an error. + // Test that a non-default executor factory can be invoked via an attr. + { + FunctionLibraryRuntime::InstantiateOptions options; + HasError(InstantiateAndRun(flr0_, "XTimesTwo", + {{"T", DT_FLOAT}, {"_executor", "DUMMY"}}, + options, {x}, {&y}), + "Internal: This is a dummy."); + } + + // Test that a non-default executor factory specified via an + // `InstantiateOptions` supersedes the attr when both are present. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = "DUMMY"; + HasError( + InstantiateAndRun(flr0_, "XTimesTwo", + {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}}, + options, {x}, {&y}), + "Internal: This is a dummy."); + } + + // Test that non-existent executor types trigger an error. { FunctionLibraryRuntime::InstantiateOptions options; options.executor_type = "UNKNOWN_EXECUTOR"; @@ -593,6 +614,15 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) { "Not found: No executor factory registered for the given executor " "type: UNKNOWN_EXECUTOR"); } + { + FunctionLibraryRuntime::InstantiateOptions options; + HasError( + InstantiateAndRun(flr0_, "XTimesTwo", + {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}}, + options, {x}, {&y}), + "Not found: No executor factory registered for the given executor " + "type: UNKNOWN_EXECUTOR"); + } } TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { @@ -869,7 +899,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const<float>( - s.WithOpName("x4/x2/scale/_12__cf__10") + s.WithOpName("x4/x2/scale/_12__cf__13") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -1076,13 +1106,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_6__cf__15") + s.WithOpName("scale/_6__cf__18") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_5__cf__14") + s.WithOpName("Func/_1/sy/_5__cf__17") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 20f957190b..aa2f274752 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -796,12 +796,28 @@ uint64 FunctionDefHash(const FunctionDef& fdef) { return h; } +static constexpr const char* const kExecutorAttr = "_executor"; + +/* static */ +string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options, + AttrSlice attrs) { + if (!options.executor_type.empty()) { + return options.executor_type; + } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) { + return executor_attr->s(); + } else { + return string(); + } +} + string Canonicalize(const string& funcname, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options) { std::vector<string> entries; entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1)); for (auto p : attrs) { - entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + if (p.first != kExecutorAttr) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } } if (!options.target.empty()) { entries.push_back( @@ -815,9 +831,9 @@ string Canonicalize(const string& funcname, AttrSlice attrs, entries.push_back( strings::StrCat("_state_handle", "=", options.state_handle)); } - if (!options.executor_type.empty()) { - entries.push_back( - strings::StrCat("_executor_type", "=", options.executor_type)); + string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs); + if (!executor_type.empty()) { + entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type)); } std::sort(entries.begin(), entries.end()); return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 4d6d68e214..d4beca7e11 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -609,6 +609,13 @@ class FunctionLibraryRuntime { virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, FunctionLibraryRuntime** out_flr) = 0; + + // Returns the name of the executor class (in the sense of + // `ExecutorFactory::GetFactory()`) that will be used based on the given + // dynamic `options` and static `attrs`. If none is specified, this method + // will return an empty string, which leaves the decision up to the runtime. + static string ExecutorType(const InstantiateOptions& options, + AttrSlice attrs); }; // Returns a canonicalized string for the instantiation of the |