aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-10 08:17:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-10 08:20:50 -0700
commitafcc1a4452de7898391683f7cbb16ff548f839a1 (patch)
treee65547621501d7a9d9ccba50939391f77ecc9323 /tensorflow/core/common_runtime
parent1ae0a45a5de65ab4ae6def232da016e7ee32773c (diff)
Allow the executor type for a function to be specified as an attr on a function.
This change complements the existing `InstantiateOptions::executor_type` option, which takes precedence over the attr if both are provided. It enables the choice of executor to be separated from both the calling op implementation and the function definition, which simplifies the use of custom executors in operations that take a function as an attr (e.g.) `tf.data` and the functional control-flow ops. PiperOrigin-RevId: 216532778
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/function.cc2
-rw-r--r--tensorflow/core/common_runtime/function_test.cc38
2 files changed, 35 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 472865ca43..e0e5f4a215 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -551,7 +551,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
item->instantiation_counter = 1;
- item->executor_type = options.executor_type;
+ item->executor_type = ExecutorType(options, attrs);
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 7bab9be9a6..716167132b 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -584,7 +584,28 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
"Internal: This is a dummy.");
}
- // Test that non-existent exector types trigger an error.
+ // Test that a non-default executor factory can be invoked via an attr.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "DUMMY"}},
+ options, {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that a non-default executor factory specified via an
+ // `InstantiateOptions` supersedes the attr when both are present.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DUMMY";
+ HasError(
+ InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
+ options, {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that non-existent executor types trigger an error.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "UNKNOWN_EXECUTOR";
@@ -593,6 +614,15 @@ TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
"Not found: No executor factory registered for the given executor "
"type: UNKNOWN_EXECUTOR");
}
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ HasError(
+ InstantiateAndRun(flr0_, "XTimesTwo",
+ {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
+ options, {x}, {&y}),
+ "Not found: No executor factory registered for the given executor "
+ "type: UNKNOWN_EXECUTOR");
+ }
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
@@ -869,7 +899,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__10")
+ s.WithOpName("x4/x2/scale/_12__cf__13")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -1076,13 +1106,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__15")
+ s.WithOpName("scale/_6__cf__18")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__14")
+ s.WithOpName("Func/_1/sy/_5__cf__17")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(