aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-06-09 10:39:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-09 10:41:54 -0700
commit119db15241e29587e0b6ab3912bff5ff63d123eb (patch)
treeec75f6f45b0eea6a96ad4765f35b20fb583ea83a /tensorflow/core/common_runtime/function_test.cc
parent898f9664488f0036ccc02bbb34379cb613f07a55 (diff)
Add a registration mechanism for experimental executor implementations.
Also add an option to the FunctionLibraryRuntime's `InstantiateOptions` that enables users to select a particular executor implementation when instantiating a function. PiperOrigin-RevId: 199920648
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc72
1 files changed, 68 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index f4f5198396..1e837e9a7e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
@@ -531,6 +532,69 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
}
}
+namespace {
+class DummyExecutorRegistrar {
+ public:
+ DummyExecutorRegistrar() {
+ ExecutorFactory::Register("DUMMY", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ return errors::Internal("This is a dummy.");
+ }
+ };
+};
+static DummyExecutorRegistrar registrar;
+} // namespace
+
+TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
+ Init({test::function::XTimesTwo()});
+
+ auto x = test::AsTensor<float>({1, 2, 3, 4});
+ Tensor y;
+
+ // Test that the default executor works.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test the explicit registration for the default executor.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DEFAULT";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test that a non-default executor factory can be invoked.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DUMMY";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that non-existent exector types trigger an error.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "UNKNOWN_EXECUTOR";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Not found: No executor factory registered for the given executor "
+ "type: UNKNOWN_EXECUTOR");
+ }
+}
+
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@@ -803,7 +867,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__6")
+ s.WithOpName("x4/x2/scale/_12__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -913,7 +977,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
"Not found: Function Foo is not defined.");
}
-TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
+TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) {
auto bad_x_times_two = FDH::Define(
// Name
"XTimesTwo",
@@ -1009,13 +1073,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__11")
+ s.WithOpName("scale/_6__cf__15")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__10")
+ s.WithOpName("Func/_1/sy/_5__cf__14")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(