aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-10 08:17:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-10 08:20:50 -0700
commitafcc1a4452de7898391683f7cbb16ff548f839a1 (patch)
treee65547621501d7a9d9ccba50939391f77ecc9323
parent1ae0a45a5de65ab4ae6def232da016e7ee32773c (diff)
Allow the executor type for a function to be specified as an attr on a function.
This change complements the existing `InstantiateOptions::executor_type` option, which takes precedence over the attr if both are provided. It enables the choice of executor to be separated from both the calling op implementation and the function definition, which simplifies the use of custom executors in operations that take a function as an attr (e.g.) `tf.data` and the functional control-flow ops. PiperOrigin-RevId: 216532778
-rw-r--r--tensorflow/core/common_runtime/function.cc2
-rw-r--r--tensorflow/core/common_runtime/function_test.cc38
-rw-r--r--tensorflow/core/framework/function.cc24
-rw-r--r--tensorflow/core/framework/function.h7
4 files changed, 62 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 472865ca43..e0e5f4a215 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -551,7 +551,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
item->instantiation_counter = 1;
- item->executor_type = options.executor_type;
+ item->executor_type = ExecutorType(options, attrs);
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 7bab9be9a6..716167132b 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -584,7 +584,28 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
"Internal: This is a dummy.");
}
- // Test that non-existent exector types trigger an error.
+ // Test that a non-default executor factory can be invoked via an attr.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "DUMMY"}},
+ options, {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that a non-default executor factory specified via an
+ // `InstantiateOptions` supersedes the attr when both are present.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DUMMY";
+ HasError(
+ InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
+ options, {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that non-existent executor types trigger an error.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "UNKNOWN_EXECUTOR";
@@ -593,6 +614,15 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
"Not found: No executor factory registered for the given executor "
"type: UNKNOWN_EXECUTOR");
}
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ HasError(
+ InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
+ options, {x}, {&y}),
+ "Not found: No executor factory registered for the given executor "
+ "type: UNKNOWN_EXECUTOR");
+ }
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
@@ -869,7 +899,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__10")
+ s.WithOpName("x4/x2/scale/_12__cf__13")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -1076,13 +1106,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__15")
+ s.WithOpName("scale/_6__cf__18")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__14")
+ s.WithOpName("Func/_1/sy/_5__cf__17")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 20f957190b..aa2f274752 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -796,12 +796,28 @@ uint64 FunctionDefHash(const FunctionDef& fdef) {
return h;
}
+static constexpr const char* const kExecutorAttr = "_executor";
+
+/* static */
+string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
+ AttrSlice attrs) {
+ if (!options.executor_type.empty()) {
+ return options.executor_type;
+ } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
+ return executor_attr->s();
+ } else {
+ return string();
+ }
+}
+
string Canonicalize(const string& funcname, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options) {
std::vector<string> entries;
entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1));
for (auto p : attrs) {
- entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ if (p.first != kExecutorAttr) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
}
if (!options.target.empty()) {
entries.push_back(
@@ -815,9 +831,9 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
entries.push_back(
strings::StrCat("_state_handle", "=", options.state_handle));
}
- if (!options.executor_type.empty()) {
- entries.push_back(
- strings::StrCat("_executor_type", "=", options.executor_type));
+ string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
+ if (!executor_type.empty()) {
+ entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
}
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 4d6d68e214..d4beca7e11 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -609,6 +609,13 @@ class FunctionLibraryRuntime {
virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
FunctionLibraryRuntime** out_flr) = 0;
+
+ // Returns the name of the executor class (in the sense of
+ // `ExecutorFactory::GetFactory()`) that will be used based on the given
+ // dynamic `options` and static `attrs`. If none is specified, this method
+ // will return an empty string, which leaves the decision up to the runtime.
+ static string ExecutorType(const InstantiateOptions& options,
+ AttrSlice attrs);
};
// Returns a canonicalized string for the instantiation of the