aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-06-09 10:39:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-09 10:41:54 -0700
commit119db15241e29587e0b6ab3912bff5ff63d123eb (patch)
treeec75f6f45b0eea6a96ad4765f35b20fb583ea83a
parent898f9664488f0036ccc02bbb34379cb613f07a55 (diff)
Add a registration mechanism for experimental executor implementations.
Also add an option to the FunctionLibraryRuntime's `InstantiateOptions` that enables users to select a particular executor implementation when instantiating a function. PiperOrigin-RevId: 199920648
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/common_runtime/executor.cc27
-rw-r--r--tensorflow/core/common_runtime/executor_factory.cc85
-rw-r--r--tensorflow/core/common_runtime/executor_factory.h51
-rw-r--r--tensorflow/core/common_runtime/executor_test.cc4
-rw-r--r--tensorflow/core/common_runtime/function.cc16
-rw-r--r--tensorflow/core/common_runtime/function_test.cc72
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.cc18
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.h4
-rw-r--r--tensorflow/core/framework/function.cc4
-rw-r--r--tensorflow/core/framework/function.h6
11 files changed, 267 insertions, 22 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5ff65f4f72..f17f39099a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2633,6 +2633,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/dma_helper.h",
"common_runtime/eigen_thread_pool.h",
"common_runtime/executor.h",
+ "common_runtime/executor_factory.h",
"common_runtime/graph_optimizer.h",
"common_runtime/local_device.h",
"common_runtime/lower_if_op.h",
@@ -2682,6 +2683,7 @@ tf_cuda_library(
"common_runtime/device_resolver_local.cc",
"common_runtime/device_set.cc",
"common_runtime/executor.cc",
+ "common_runtime/executor_factory.cc",
"common_runtime/function.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 585d777e81..f7f2cdc14f 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/costmodel_manager.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
@@ -2764,4 +2765,30 @@ Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
+namespace {
+
+class DefaultExecutorRegistrar {
+ public:
+ DefaultExecutorRegistrar() {
+ Factory* factory = new Factory;
+ ExecutorFactory::Register("", factory);
+ ExecutorFactory::Register("DEFAULT", factory);
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ Executor* ret = nullptr;
+ TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
+ out_executor->reset(ret);
+ return Status::OK();
+ }
+ };
+};
+static DefaultExecutorRegistrar registrar;
+
+} // namespace
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor_factory.cc b/tensorflow/core/common_runtime/executor_factory.cc
new file mode 100644
index 0000000000..ee7c7c3a73
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor_factory.cc
@@ -0,0 +1,85 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/executor_factory.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+static mutex executor_factory_lock(LINKER_INITIALIZED);
+
+typedef std::unordered_map<string, ExecutorFactory*> ExecutorFactories;
+ExecutorFactories* executor_factories() {
+ static ExecutorFactories* factories = new ExecutorFactories;
+ return factories;
+}
+
+} // namespace
+
+void ExecutorFactory::Register(const string& executor_type,
+ ExecutorFactory* factory) {
+ mutex_lock l(executor_factory_lock);
+ if (!executor_factories()->insert({executor_type, factory}).second) {
+ LOG(FATAL) << "Two executor factories are being registered "
+ << "under" << executor_type;
+ }
+}
+
+namespace {
+const string RegisteredFactoriesErrorMessageLocked()
+ SHARED_LOCKS_REQUIRED(executor_factory_lock) {
+ std::vector<string> factory_types;
+ for (const auto& executor_factory : *executor_factories()) {
+ factory_types.push_back(executor_factory.first);
+ }
+ return strings::StrCat("Registered factories are {",
+ str_util::Join(factory_types, ", "), "}.");
+}
+} // namespace
+
+Status ExecutorFactory::GetFactory(const string& executor_type,
+ ExecutorFactory** out_factory) {
+ tf_shared_lock l(executor_factory_lock);
+
+ auto iter = executor_factories()->find(executor_type);
+ if (iter == executor_factories()->end()) {
+ return errors::NotFound(
+ "No executor factory registered for the given executor type: ",
+ executor_type, " ", RegisteredFactoriesErrorMessageLocked());
+ }
+
+ *out_factory = iter->second;
+ return Status::OK();
+}
+
+Status NewExecutor(const string& executor_type,
+ const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) {
+ ExecutorFactory* factory = nullptr;
+ TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
+ return factory->NewExecutor(params, std::move(graph), out_executor);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor_factory.h b/tensorflow/core/common_runtime/executor_factory.h
new file mode 100644
index 0000000000..f81bb080eb
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor_factory.h
@@ -0,0 +1,51 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class Executor;
+class Graph;
+struct LocalExecutorParams;
+
+class ExecutorFactory {
+ public:
+ virtual Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) = 0;
+ virtual ~ExecutorFactory() {}
+
+ static void Register(const string& executor_type, ExecutorFactory* factory);
+ static Status GetFactory(const string& executor_type,
+ ExecutorFactory** out_factory);
+};
+
+Status NewExecutor(const string& executor_type,
+ const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc
index b24969613c..7697103faf 100644
--- a/tensorflow/core/common_runtime/executor_test.cc
+++ b/tensorflow/core/common_runtime/executor_test.cc
@@ -464,8 +464,8 @@ BENCHMARK(BM_executor)->ArgPair(1024, 1024);
static void BM_FeedInputFetchOutput(int iters) {
Graph* g = new Graph(OpRegistry::Global());
// z = x + y: x and y are provided as benchmark inputs. z is the
- // output of the benchmark. Conceptually, the caller is "a", the
- // benchmark is "b".
+ // output of the benchmark. Conceptually, the caller is ALICE, the
+ // benchmark is BOB.
Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
Node* sum = test::graph::Add(g, x, y);
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 5d9be70522..68d37ddbcd 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
@@ -215,6 +216,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
+ string executor_type;
~Item() {
delete this->func_graph;
@@ -549,6 +551,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
item->instantiation_counter = 1;
+ item->executor_type = options.executor_type;
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
@@ -623,10 +626,12 @@ void PruneFunctionBody(Graph* g) {
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
const FunctionBody* fbody;
const FunctionLibraryDefinition* lib_def;
+ string executor_type;
{
mutex_lock l(mu_);
fbody = (*item)->func_graph;
lib_def = (*item)->overlay_lib;
+ executor_type = (*item)->executor_type;
}
if (!lib_def) {
lib_def = base_lib_def_;
@@ -656,17 +661,14 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
DeleteNonCachedKernel(kernel);
};
Graph* graph = g.get();
- Executor* exec;
- TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec));
-
+ std::unique_ptr<Executor> exec;
+ TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
{
// Guard item since it is already inserted in items_.
mutex_lock l(mu_);
- if ((*item)->exec) {
- delete exec;
- } else {
+ if ((*item)->exec == nullptr) {
(*item)->graph = graph;
- (*item)->exec = exec;
+ (*item)->exec = exec.release();
}
}
return Status::OK();
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index f4f5198396..1e837e9a7e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
@@ -531,6 +532,69 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
}
}
+namespace {
+class DummyExecutorRegistrar {
+ public:
+ DummyExecutorRegistrar() {
+ ExecutorFactory::Register("DUMMY", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ return errors::Internal("This is a dummy.");
+ }
+ };
+};
+static DummyExecutorRegistrar registrar;
+} // namespace
+
+TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
+ Init({test::function::XTimesTwo()});
+
+ auto x = test::AsTensor<float>({1, 2, 3, 4});
+ Tensor y;
+
+ // Test that the default executor works.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test the explicit registration for the default executor.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DEFAULT";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test that a non-default executor factory can be invoked.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DUMMY";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that non-existent executor types trigger an error.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "UNKNOWN_EXECUTOR";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Not found: No executor factory registered for the given executor "
+ "type: UNKNOWN_EXECUTOR");
+ }
+}
+
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@@ -803,7 +867,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__6")
+ s.WithOpName("x4/x2/scale/_12__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -913,7 +977,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
"Not found: Function Foo is not defined.");
}
-TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
+TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) {
auto bad_x_times_two = FDH::Define(
// Name
"XTimesTwo",
@@ -1009,13 +1073,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__11")
+ s.WithOpName("scale/_6__cf__15")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__10")
+ s.WithOpName("Func/_1/sy/_5__cf__14")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
index 7de1b80e2d..1f585a8c24 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -43,7 +44,7 @@ namespace test {
// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
Benchmark::Benchmark(const string& device, Graph* g,
const SessionOptions* options, Graph* init,
- Rendezvous* rendez) {
+ Rendezvous* rendez, const char* executor_type) {
SessionOptions default_options;
if (!options) {
options = &default_options;
@@ -86,23 +87,26 @@ Benchmark::Benchmark(const string& device, Graph* g,
};
if (init) {
- Executor* init_exec;
- TF_CHECK_OK(
- NewLocalExecutor(params, std::unique_ptr<Graph>(init), &init_exec));
+ std::unique_ptr<Executor> init_exec;
+ TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init),
+ &init_exec));
Executor::Args args;
args.rendezvous = rendez_;
args.runner = runner;
TF_CHECK_OK(init_exec->Run(args));
- delete init_exec;
}
- TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr<Graph>(g), &exec_));
+ TF_CHECK_OK(
+ NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_));
}
Benchmark::~Benchmark() {
if (device_) {
rendez_->Unref();
- delete exec_;
+ // We delete `exec_` before `device_` because the `exec_` destructor may
+ // run kernel destructors that may attempt to access state borrowed from
+ // `device_`, such as the resource manager.
+ exec_.reset();
delete device_;
delete pool_;
}
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
index 3a7b3a5ace..995a15a299 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -39,7 +39,7 @@ class Benchmark {
// "init", and one reference on "rendez" (if not null).
Benchmark(const string& device, Graph* g,
const SessionOptions* options = nullptr, Graph* init = nullptr,
- Rendezvous* rendez = nullptr);
+ Rendezvous* rendez = nullptr, const char* executor_type = "");
~Benchmark();
// Executes the graph for "iters" times.
@@ -57,7 +57,7 @@ class Benchmark {
thread::ThreadPool* pool_ = nullptr;
Device* device_ = nullptr;
Rendezvous* rendez_ = nullptr;
- Executor* exec_ = nullptr;
+ std::unique_ptr<Executor> exec_;
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
};
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 647c66099c..88d9d65f5a 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -815,6 +815,10 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
entries.push_back(
strings::StrCat("_state_handle", "=", options.state_handle));
}
+ if (!options.executor_type.empty()) {
+ entries.push_back(
+ strings::StrCat("_executor_type", "=", options.executor_type));
+ }
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 872906756a..8e607b927c 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -450,6 +450,12 @@ class FunctionLibraryRuntime {
// state (in stateful kernels); and two functions with different
// values for `state_handle` will have independent state.
string state_handle;
+
+ // This interface is EXPERIMENTAL and subject to change.
+ //
+ // Instantiates the function using an executor of the given type. If empty,
+ // the default TensorFlow executor will be used.
+ string executor_type;
};
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,