aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/function.cc2
-rw-r--r--tensorflow/core/common_runtime/function_test.cc38
-rw-r--r--tensorflow/core/framework/function.cc24
-rw-r--r--tensorflow/core/framework/function.h7
4 files changed, 62 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 472865ca43..e0e5f4a215 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -551,7 +551,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
item->instantiation_counter = 1;
- item->executor_type = options.executor_type;
+ item->executor_type = ExecutorType(options, attrs);
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 7bab9be9a6..716167132b 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -584,7 +584,28 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
"Internal: This is a dummy.");
}
- // Test that non-existent exector types trigger an error.
+ // Test that a non-default executor factory can be invoked via an attr.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "DUMMY"}},
+ options, {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that a non-default executor factory specified via an
+ // `InstantiateOptions` supersedes the attr when both are present.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DUMMY";
+ HasError(
+ InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
+ options, {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that non-existent executor types trigger an error.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "UNKNOWN_EXECUTOR";
@@ -593,6 +614,15 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
"Not found: No executor factory registered for the given executor "
"type: UNKNOWN_EXECUTOR");
}
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ HasError(
+ InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
+ options, {x}, {&y}),
+ "Not found: No executor factory registered for the given executor "
+ "type: UNKNOWN_EXECUTOR");
+ }
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
@@ -869,7 +899,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__10")
+ s.WithOpName("x4/x2/scale/_12__cf__13")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -1076,13 +1106,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__15")
+ s.WithOpName("scale/_6__cf__18")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__14")
+ s.WithOpName("Func/_1/sy/_5__cf__17")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 20f957190b..aa2f274752 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -796,12 +796,28 @@ uint64 FunctionDefHash(const FunctionDef& fdef) {
return h;
}
+static constexpr const char* const kExecutorAttr = "_executor";
+
+/* static */
+string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
+ AttrSlice attrs) {
+ if (!options.executor_type.empty()) {
+ return options.executor_type;
+ } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
+ return executor_attr->s();
+ } else {
+ return string();
+ }
+}
+
string Canonicalize(const string& funcname, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options) {
std::vector<string> entries;
entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1));
for (auto p : attrs) {
- entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ if (p.first != kExecutorAttr) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
}
if (!options.target.empty()) {
entries.push_back(
@@ -815,9 +831,9 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
entries.push_back(
strings::StrCat("_state_handle", "=", options.state_handle));
}
- if (!options.executor_type.empty()) {
- entries.push_back(
- strings::StrCat("_executor_type", "=", options.executor_type));
+ string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
+ if (!executor_type.empty()) {
+ entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
}
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 4d6d68e214..d4beca7e11 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -609,6 +609,13 @@ class FunctionLibraryRuntime {
virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
FunctionLibraryRuntime** out_flr) = 0;
+
+ // Returns the name of the executor class (in the sense of
+ // `ExecutorFactory::GetFactory()`) that will be used based on the given
+ // dynamic `options` and static `attrs`. If none is specified, this method
+ // will return an empty string, which leaves the decision up to the runtime.
+ static string ExecutorType(const InstantiateOptions& options,
+ AttrSlice attrs);
};
// Returns a canonicalized string for the instantiation of the