diff options
author | Derek Murray <mrry@google.com> | 2018-06-09 10:39:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-09 10:41:54 -0700 |
commit | 119db15241e29587e0b6ab3912bff5ff63d123eb (patch) | |
tree | ec75f6f45b0eea6a96ad4765f35b20fb583ea83a | |
parent | 898f9664488f0036ccc02bbb34379cb613f07a55 (diff) |
Add a registration mechanism for experimental executor implementations.
Also add an option to the FunctionLibraryRuntime's `InstantiateOptions` that
enables users to select a particular executor implementation when instantiating
a function.
PiperOrigin-RevId: 199920648
-rw-r--r-- | tensorflow/core/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor_factory.cc | 85 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor_factory.h | 51 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 72 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/kernel_benchmark_testlib.cc | 18 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/kernel_benchmark_testlib.h | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/function.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 6 |
11 files changed, 267 insertions, 22 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5ff65f4f72..f17f39099a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2633,6 +2633,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/dma_helper.h", "common_runtime/eigen_thread_pool.h", "common_runtime/executor.h", + "common_runtime/executor_factory.h", "common_runtime/graph_optimizer.h", "common_runtime/local_device.h", "common_runtime/lower_if_op.h", @@ -2682,6 +2683,7 @@ tf_cuda_library( "common_runtime/device_resolver_local.cc", "common_runtime/device_set.cc", "common_runtime/executor.cc", + "common_runtime/executor_factory.cc", "common_runtime/function.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 585d777e81..f7f2cdc14f 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -23,6 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/costmodel_manager.h" +#include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/pending_counts.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/allocation_description.pb.h" @@ -2764,4 +2765,30 @@ Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; } +namespace { + +class DefaultExecutorRegistrar { + public: + DefaultExecutorRegistrar() { + Factory* factory = new Factory; + ExecutorFactory::Register("", factory); + ExecutorFactory::Register("DEFAULT", factory); + } + + private: + class Factory : public ExecutorFactory { + Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) override { + Executor* ret = nullptr; + TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret)); + out_executor->reset(ret); + return Status::OK(); + } + }; +}; +static DefaultExecutorRegistrar registrar; + +} // namespace + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/executor_factory.cc b/tensorflow/core/common_runtime/executor_factory.cc new file mode 100644 index 0000000000..ee7c7c3a73 --- /dev/null +++ b/tensorflow/core/common_runtime/executor_factory.cc @@ -0,0 +1,85 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/executor_factory.h" + +#include <unordered_map> + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +static mutex executor_factory_lock(LINKER_INITIALIZED); + +typedef std::unordered_map<string, ExecutorFactory*> ExecutorFactories; +ExecutorFactories* executor_factories() { + static ExecutorFactories* factories = new ExecutorFactories; + return factories; +} + +} // namespace + +void ExecutorFactory::Register(const string& executor_type, + ExecutorFactory* factory) { + mutex_lock l(executor_factory_lock); + if (!executor_factories()->insert({executor_type, factory}).second) { + LOG(FATAL) << "Two executor factories are being registered " + << "under" << executor_type; + } +} + +namespace { +const string RegisteredFactoriesErrorMessageLocked() + SHARED_LOCKS_REQUIRED(executor_factory_lock) { + std::vector<string> factory_types; + for (const auto& executor_factory : *executor_factories()) { + factory_types.push_back(executor_factory.first); + } + return strings::StrCat("Registered factories are {", + str_util::Join(factory_types, ", "), "}."); +} +} // namespace + +Status ExecutorFactory::GetFactory(const string& executor_type, + ExecutorFactory** out_factory) { + tf_shared_lock l(executor_factory_lock); + + auto iter = executor_factories()->find(executor_type); + if (iter == executor_factories()->end()) { + return errors::NotFound( + "No executor factory registered for the given executor type: ", + executor_type, " ", RegisteredFactoriesErrorMessageLocked()); + } + + *out_factory = iter->second; + return Status::OK(); +} + +Status NewExecutor(const string& executor_type, + const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) { + ExecutorFactory* factory = nullptr; + TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory)); + return factory->NewExecutor(params, std::move(graph), out_executor); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/executor_factory.h b/tensorflow/core/common_runtime/executor_factory.h new file mode 100644 index 0000000000..f81bb080eb --- /dev/null +++ b/tensorflow/core/common_runtime/executor_factory.h @@ -0,0 +1,51 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_ + +#include <string> + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Executor; +class Graph; +struct LocalExecutorParams; + +class ExecutorFactory { + public: + virtual Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) = 0; + virtual ~ExecutorFactory() {} + + static void Register(const string& executor_type, ExecutorFactory* factory); + static Status GetFactory(const string& executor_type, + ExecutorFactory** out_factory); +}; + +Status NewExecutor(const string& executor_type, + const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_ diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index b24969613c..7697103faf 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -464,8 +464,8 @@ BENCHMARK(BM_executor)->ArgPair(1024, 1024); static void BM_FeedInputFetchOutput(int iters) { Graph* g = new Graph(OpRegistry::Global()); // z = x + y: x and y are provided as benchmark inputs. z is the - // output of the benchmark. Conceptually, the caller is "a", the - // benchmark is "b". + // output of the benchmark. Conceptually, the caller is ALICE, the + // benchmark is BOB. Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB); Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB); Node* sum = test::graph::Add(g, x, y); diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 5d9be70522..68d37ddbcd 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" @@ -215,6 +216,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned. FunctionBody* func_graph = nullptr; Executor* exec = nullptr; + string executor_type; ~Item() { delete this->func_graph; @@ -549,6 +551,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( item->func_graph = fbody; item->overlay_lib = options.overlay_lib; item->instantiation_counter = 1; + item->executor_type = options.executor_type; items_.emplace(next_handle_, std::unique_ptr<Item>(item)); next_handle_++; } @@ -623,10 +626,12 @@ void PruneFunctionBody(Graph* g) { Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { const FunctionBody* fbody; const FunctionLibraryDefinition* lib_def; + string executor_type; { mutex_lock l(mu_); fbody = (*item)->func_graph; lib_def = (*item)->overlay_lib; + executor_type = (*item)->executor_type; } if (!lib_def) { lib_def = base_lib_def_; @@ -656,17 +661,14 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { DeleteNonCachedKernel(kernel); }; Graph* graph = g.get(); - Executor* exec; - TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec)); - + std::unique_ptr<Executor> exec; + TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec)); { // Guard item since it is already inserted in items_. mutex_lock l(mu_); - if ((*item)->exec) { - delete exec; - } else { + if ((*item)->exec == nullptr) { (*item)->graph = graph; - (*item)->exec = exec; + (*item)->exec = exec.release(); } } return Status::OK(); diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index f4f5198396..1e837e9a7e 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" @@ -531,6 +532,69 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { } } +namespace { +class DummyExecutorRegistrar { + public: + DummyExecutorRegistrar() { + ExecutorFactory::Register("DUMMY", new Factory()); + } + + private: + class Factory : public ExecutorFactory { + Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) override { + return errors::Internal("This is a dummy."); + } + }; +}; +static DummyExecutorRegistrar registrar; +} // namespace + +TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) { + Init({test::function::XTimesTwo()}); + + auto x = test::AsTensor<float>({1, 2, 3, 4}); + Tensor y; + + // Test that the default executor works. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = ""; + TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, + options, {x}, {&y})); + test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); + } + + // Test the explicit registration for the default executor. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = "DEFAULT"; + TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, + options, {x}, {&y})); + test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); + } + + // Test that a non-default executor factory can be invoked. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = "DUMMY"; + HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options, + {x}, {&y}), + "Internal: This is a dummy."); + } + + // Test that non-existent exector types trigger an error. + { + FunctionLibraryRuntime::InstantiateOptions options; + options.executor_type = "UNKNOWN_EXECUTOR"; + HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options, + {x}, {&y}), + "Not found: No executor factory registered for the given executor " + "type: UNKNOWN_EXECUTOR"); + } +} + TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); @@ -803,7 +867,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const<float>( - s.WithOpName("x4/x2/scale/_12__cf__6") + s.WithOpName("x4/x2/scale/_12__cf__10") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -913,7 +977,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { "Not found: Function Foo is not defined."); } -TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { +TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) { auto bad_x_times_two = FDH::Define( // Name "XTimesTwo", @@ -1009,13 +1073,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_6__cf__11") + s.WithOpName("scale/_6__cf__15") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_5__cf__10") + s.WithOpName("Func/_1/sy/_5__cf__14") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 7de1b80e2d..1f585a8c24 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -18,6 +18,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -43,7 +44,7 @@ namespace test { // TODO(hongm): Convert `g` and `init` to using std::unique_ptr. Benchmark::Benchmark(const string& device, Graph* g, const SessionOptions* options, Graph* init, - Rendezvous* rendez) { + Rendezvous* rendez, const char* executor_type) { SessionOptions default_options; if (!options) { options = &default_options; @@ -86,23 +87,26 @@ Benchmark::Benchmark(const string& device, Graph* g, }; if (init) { - Executor* init_exec; - TF_CHECK_OK( - NewLocalExecutor(params, std::unique_ptr<Graph>(init), &init_exec)); + std::unique_ptr<Executor> init_exec; + TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init), + &init_exec)); Executor::Args args; args.rendezvous = rendez_; args.runner = runner; TF_CHECK_OK(init_exec->Run(args)); - delete init_exec; } - TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr<Graph>(g), &exec_)); + TF_CHECK_OK( + NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_)); } Benchmark::~Benchmark() { if (device_) { rendez_->Unref(); - delete exec_; + // We delete `exec_` before `device_` because the `exec_` destructor may + // run kernel destructors that may attempt to access state borrowed from + // `device_`, such as the resource manager. + exec_.reset(); delete device_; delete pool_; } diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h index 3a7b3a5ace..995a15a299 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h @@ -39,7 +39,7 @@ class Benchmark { // "init", and one reference on "rendez" (if not null). Benchmark(const string& device, Graph* g, const SessionOptions* options = nullptr, Graph* init = nullptr, - Rendezvous* rendez = nullptr); + Rendezvous* rendez = nullptr, const char* executor_type = ""); ~Benchmark(); // Executes the graph for "iters" times. @@ -57,7 +57,7 @@ class Benchmark { thread::ThreadPool* pool_ = nullptr; Device* device_ = nullptr; Rendezvous* rendez_ = nullptr; - Executor* exec_ = nullptr; + std::unique_ptr<Executor> exec_; TF_DISALLOW_COPY_AND_ASSIGN(Benchmark); }; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 647c66099c..88d9d65f5a 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -815,6 +815,10 @@ string Canonicalize(const string& funcname, AttrSlice attrs, entries.push_back( strings::StrCat("_state_handle", "=", options.state_handle)); } + if (!options.executor_type.empty()) { + entries.push_back( + strings::StrCat("_executor_type", "=", options.executor_type)); + } std::sort(entries.begin(), entries.end()); return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 872906756a..8e607b927c 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -450,6 +450,12 @@ class FunctionLibraryRuntime { // state (in stateful kernels); and two functions with different // values for `state_handle` will have independent state. string state_handle; + + // This interface is EXPERIMENTAL and subject to change. + // + // Instatiates the function using an executor of the given type. If empty, + // the default TensorFlow executor will be used. + string executor_type; }; typedef uint64 Handle; virtual Status Instantiate(const string& function_name, AttrSlice attrs, |