diff options
author | Derek Murray <mrry@google.com> | 2018-06-09 10:39:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-09 10:41:54 -0700 |
commit | 119db15241e29587e0b6ab3912bff5ff63d123eb (patch) | |
tree | ec75f6f45b0eea6a96ad4765f35b20fb583ea83a /tensorflow/core/common_runtime/function_test.cc | |
parent | 898f9664488f0036ccc02bbb34379cb613f07a55 (diff) |
Add a registration mechanism for experimental executor implementations.
Also add an option to the FunctionLibraryRuntime's `InstantiateOptions` that
enables users to select a particular executor implementation when instantiating
a function.
PiperOrigin-RevId: 199920648
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 72 |
1 files changed, 68 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index f4f5198396..1e837e9a7e 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" @@ -531,6 +532,69 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { } } +namespace { +class DummyExecutorRegistrar { + public: + DummyExecutorRegistrar() { + ExecutorFactory::Register("DUMMY", new Factory()); + } + + private: + class Factory : public ExecutorFactory { + Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) override { + return errors::Internal("This is a dummy."); + } + }; +}; +static DummyExecutorRegistrar registrar; +} // namespace + +TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) { + Init({test::function::XTimesTwo()}); + + auto x = test::AsTensor<float>({1, 2, 3, 4}); + Tensor y; + + // Test that the default executor works. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = ""; + TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, + options, {x}, {&y})); + test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); + } + + // Test the explicit registration for the default executor. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = "DEFAULT"; + TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, + options, {x}, {&y})); + test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); + } + + // Test that a non-default executor factory can be invoked. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = "DUMMY"; + HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options, + {x}, {&y}), + "Internal: This is a dummy."); + } + + // Test that non-existent exector types trigger an error. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = "UNKNOWN_EXECUTOR"; + HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options, + {x}, {&y}), + "Not found: No executor factory registered for the given executor " + "type: UNKNOWN_EXECUTOR"); + } +} + TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); @@ -803,7 +867,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const<float>( - s.WithOpName("x4/x2/scale/_12__cf__6") + s.WithOpName("x4/x2/scale/_12__cf__10") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -913,7 +977,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { "Not found: Function Foo is not defined."); } -TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { +TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) { auto bad_x_times_two = FDH::Define( // Name "XTimesTwo", @@ -1009,13 +1073,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_6__cf__11") + s.WithOpName("scale/_6__cf__15") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_5__cf__10") + s.WithOpName("Func/_1/sy/_5__cf__14") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( |