author    Mingsheng Hong <hongm@google.com>                2018-02-13 14:52:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-02-13 14:59:15 -0800
commit    858ce1bccd9dd084ed9cb35eb5629e3a349cc7c2 (patch)
tree      c0289531dd44c562678587a725a5e408a26f173a /tensorflow/core/common_runtime
parent    89e7941fe6fe85af4fb7fe5499871d6b9d1c36ab (diff)
Code cleanup: Made Executor-related API take std::unique_ptr<const Graph> instead of const Graph* as input.
PiperOrigin-RevId: 185592574
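For callers, the effect of this change is that graph ownership now transfers through the NewLocalExecutor signature itself, instead of via a raw pointer that the executor deleted internally. A minimal caller-side sketch of the new contract (the graph-building step and the params setup are hypothetical, mirroring the function_test.cc hunk below):

  // Build the graph in a unique_ptr, then move it into the executor; no
  // release() on the caller side and no delete inside ExecutorImpl.
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  // ... add nodes to *g (hypothetical) ...
  LocalExecutorParams params;
  // ... fill in params.device, params.create_kernel, params.delete_kernel ...
  Executor* exec = nullptr;
  TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec));
  std::unique_ptr<Executor> owned_exec(exec);  // the executor now owns the graph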
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc              2
-rw-r--r--  tensorflow/core/common_runtime/executor.cc                   22
-rw-r--r--  tensorflow/core/common_runtime/executor.h                     8
-rw-r--r--  tensorflow/core/common_runtime/function.cc                    2
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc               6
-rw-r--r--  tensorflow/core/common_runtime/graph_runner.cc               14
-rw-r--r--  tensorflow/core/common_runtime/kernel_benchmark_testlib.cc    6
7 files changed, 31 insertions, 29 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index df6f4b8877..ecbffcbf6c 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1250,7 +1250,7 @@ Status DirectSession::GetOrCreateExecutors(
item->device = device;
Executor* executor;
TF_RETURN_IF_ERROR(
- NewLocalExecutor(params, partition_graph.release(), &executor));
+ NewLocalExecutor(params, std::move(partition_graph), &executor));
item->executor.reset(executor);
}
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 6998cbecee..b06b75d658 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -332,8 +332,8 @@ class GraphView {
class ExecutorImpl : public Executor {
public:
- ExecutorImpl(const LocalExecutorParams& p, const Graph* g)
- : params_(p), graph_(g), gview_() {
+ ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
+ : params_(p), graph_(std::move(g)), gview_() {
CHECK(p.create_kernel != nullptr);
CHECK(p.delete_kernel != nullptr);
}
@@ -348,7 +348,6 @@ class ExecutorImpl : public Executor {
for (auto fiter : frame_info_) {
delete fiter.second;
}
- delete graph_;
}
Status Initialize();
@@ -412,7 +411,7 @@ class ExecutorImpl : public Executor {
// Owned.
LocalExecutorParams params_;
- const Graph* graph_;
+ std::unique_ptr<const Graph> graph_;
GraphView gview_;
// A cached value of params_
@@ -605,11 +604,11 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending,
}
Status ExecutorImpl::Initialize() {
- gview_.Initialize(graph_);
+ gview_.Initialize(graph_.get());
// Build the information about frames in this subgraph.
ControlFlowInfo cf_info;
- TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_, &cf_info));
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));
// Cache this value so we make this virtual function call once, rather
// that O(# steps * # nodes per step) times.
@@ -676,9 +675,9 @@ Status ExecutorImpl::Initialize() {
// Initialize PendingCounts only after item->pending_id is initialized for
// all nodes.
- InitializePending(graph_, cf_info);
+ InitializePending(graph_.get(), cf_info);
- return gview_.SetAllocAttrs(graph_, params_.device);
+ return gview_.SetAllocAttrs(graph_.get(), params_.device);
}
Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
@@ -1415,7 +1414,7 @@ void ExecutorImpl::InitializePending(const Graph* graph,
}
void ExecutorState::RunAsync(Executor::DoneCallback done) {
- const Graph* graph = impl_->graph_;
+ const Graph* graph = impl_->graph_.get();
TaggedNodeSeq ready;
// Ask the device to fill in the device context map.
@@ -2606,9 +2605,10 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
} // end namespace
-Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
+Status NewLocalExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
Executor** executor) {
- ExecutorImpl* impl = new ExecutorImpl(params, graph);
+ ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
const Status s = impl->Initialize();
if (s.ok()) {
*executor = impl;
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 3fd932da5b..adf80a2417 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -122,9 +122,8 @@ class Executor {
// Creates an Executor that computes the given "graph".
//
-// If successful, returns the constructed executor in "*executor". The
-// caller keeps the ownership of "device". The returned executor takes
-// the ownership of "graph". Otherwise, returns an error status.
+// If successful, returns the constructed executor in "*executor". Otherwise,
+// returns an error status.
//
// "params" provides a set of context for the executor. We expect that
// different context would provide different implementations.
@@ -143,7 +142,8 @@ struct LocalExecutorParams {
Executor::Args::NodeOutputsCallback node_outputs_cb;
};
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
- const Graph* graph, Executor** executor);
+ std::unique_ptr<const Graph> graph,
+ Executor** executor);
// A class to help run multiple executors in parallel and wait until
// all of them are complete.
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index d349d2bb12..b941819838 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -631,7 +631,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
};
Graph* graph = g.get();
Executor* exec;
- TF_RETURN_IF_ERROR(NewLocalExecutor(params, g.release(), &exec));
+ TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec));
{
// Guard item since it is already inserted in items_.
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 8b05146299..63ad0d231c 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -71,11 +71,11 @@ class FunctionTest : public ::testing::Test {
arg_types_ = result.arg_types;
ret_types_ = result.ret_types;
- Graph* g = new Graph(OpRegistry::Global());
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
- TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));
+ TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
const int version = g->versions().producer();
LocalExecutorParams params;
@@ -89,7 +89,7 @@ class FunctionTest : public ::testing::Test {
DeleteNonCachedKernel(kernel);
};
Executor* exec;
- TF_CHECK_OK(NewLocalExecutor(params, g, &exec));
+ TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec));
exec_.reset(exec);
}
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index a21304f7ef..f1082a6003 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -156,21 +156,21 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
// should not be running expensive operators.
auto runner = [](Executor::Args::Closure c) { c(); };
- // Take ownership and pass to NewLocalExecutor
- Graph* g = graph_to_run.release();
-
LocalExecutorParams params;
// The ownership of the output tensors are bound to this device's lifetime.
params.device = cpu_device_.get();
params.function_library = function_library;
- params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) {
- return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef,
- g->versions().producer(), kernel);
+ const int producer = graph_to_run->versions().producer();
+ params.create_kernel = [this, producer](const NodeDef& ndef,
+ OpKernel** kernel) {
+ return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, producer,
+ kernel);
};
params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
Executor* executor;
- TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &executor));
+ TF_RETURN_IF_ERROR(
+ NewLocalExecutor(params, std::move(graph_to_run), &executor));
std::unique_ptr<Executor> executor_unref(executor);
Executor::Args args;
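One detail in the graph_runner.cc hunk above: because the graph is now moved into NewLocalExecutor, the create_kernel lambda can no longer capture a pointer to it, so the producer version is copied into a plain int before the move and captured by value. A standalone sketch of that capture-before-move pattern (hypothetical Widget/Consume names, not TensorFlow code):

  #include <functional>
  #include <memory>
  #include <utility>

  struct Widget { int version = 0; };

  void Consume(std::unique_ptr<Widget> w) { /* takes ownership of w */ }

  std::function<int()> MakeVersionGetter(std::unique_ptr<Widget> w) {
    const int version = w->version;        // copy out what the callback needs
    Consume(std::move(w));                 // ownership leaves this scope
    return [version] { return version; };  // captures the int, never the Widget
  }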
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
index 420dfe338e..64d8849475 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -39,6 +39,7 @@ limitations under the License.
namespace tensorflow {
namespace test {
+// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
Benchmark::Benchmark(const string& device, Graph* g,
const SessionOptions* options, Graph* init,
Rendezvous* rendez) {
@@ -85,7 +86,8 @@ Benchmark::Benchmark(const string& device, Graph* g,
if (init) {
Executor* init_exec;
- TF_CHECK_OK(NewLocalExecutor(params, init, &init_exec));
+ TF_CHECK_OK(
+ NewLocalExecutor(params, std::unique_ptr<Graph>(init), &init_exec));
Executor::Args args;
args.rendezvous = rendez_;
args.runner = runner;
@@ -93,7 +95,7 @@ Benchmark::Benchmark(const string& device, Graph* g,
delete init_exec;
}
- TF_CHECK_OK(NewLocalExecutor(params, g, &exec_));
+ TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr<Graph>(g), &exec_));
}
Benchmark::~Benchmark() {
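As the TODO in kernel_benchmark_testlib.cc notes, Benchmark still receives raw Graph* arguments; wrapping them in std::unique_ptr<Graph> at the NewLocalExecutor call sites hands ownership to the executor, so whoever constructed g and init must not delete or reuse those pointers afterwards. A hedged sketch of that boundary (BuildTestGraph is a hypothetical helper):

  Graph* g = BuildTestGraph();  // hypothetical: returns a heap-allocated Graph
  Executor* exec = nullptr;
  TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr<Graph>(g), &exec));
  // From here on the executor owns the graph; do not delete or mutate g.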