aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2016-10-20 14:33:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-20 16:04:19 -0700
commitc852668c41fc82b5ba2e440b18d3adea6847b54f (patch)
treefb3ac6f7277782f5f2ddbfdb0cd5628ca1c6f5d1 /tensorflow/cc/training
parentb09ab769296f361435aa1401db14f302937b6fec (diff)
Add c++ queue runner.
Change: 136769119
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- tensorflow/cc/training/queue_runner.cc 119
-rw-r--r-- tensorflow/cc/training/queue_runner.h 79
-rw-r--r-- tensorflow/cc/training/queue_runner_test.cc 215
3 files changed, 413 insertions, 0 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
new file mode 100644
index 0000000000..81f49c5dcf
--- /dev/null
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -0,0 +1,119 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/training/queue_runner.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+QueueRunner::QueueRunner() : started_(false) {}
+
+QueueRunner::QueueRunner(const QueueRunnerDef& queue_runner_def)
+ : started_(false) {
+ TF_CHECK_OK(Init(queue_runner_def));
+}
+
+// Configures the runner from `queue_runner_def`.
+//
+// Copies the queue name, enqueue/close/cancel op names and the set of status
+// codes that mean "queue closed" (defaulting to OUT_OF_RANGE when the proto
+// lists none), then creates a thread pool with one thread per enqueue op.
+//
+// Every call replaces the previous configuration instead of adding to it, so
+// Init() may be called again on a runner that is not running.
+//
+// Returns:
+//   ALREADY_EXISTS   if the runner has been started and not yet joined.
+//   INVALID_ARGUMENT if the definition contains no enqueue ops.
+//   OK               otherwise.
+Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
+  if (started_.load()) {
+    return Status(error::ALREADY_EXISTS, "QueueRunner is already running.");
+  }
+  queue_name_ = queue_runner_def.queue_name();
+  // assign() rather than insert(): a repeated Init() must not keep the op
+  // names of an earlier definition around.
+  enqueue_op_names_.assign(queue_runner_def.enqueue_op_name().begin(),
+                           queue_runner_def.enqueue_op_name().end());
+  runs_ = static_cast<int>(enqueue_op_names_.size());
+  if (runs_ == 0) {
+    return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run.");
+  }
+  close_op_name_ = queue_runner_def.close_op_name();
+  cancel_op_name_ = queue_runner_def.cancel_op_name();
+
+  // Drop codes left over from an earlier Init() before loading the new ones.
+  queue_closed_exception_types_.clear();
+  if (queue_runner_def.queue_closed_exception_types_size() == 0) {
+    queue_closed_exception_types_.insert(error::OUT_OF_RANGE);
+  } else {
+    for (const auto& code : queue_runner_def.queue_closed_exception_types()) {
+      queue_closed_exception_types_.insert(static_cast<int>(code));
+    }
+  }
+
+  // A failure recorded by a previous run must not leak into the next Join().
+  status_ = Status::OK();
+
+  thread_pool_.reset(
+      new thread::ThreadPool(Env::Default(), queue_name_, runs_));
+  should_stop_ = false;
+  return Status::OK();
+}
+
+// Sets should_stop_ so that the Run() loops exit after their current
+// iteration, then calls Join(). The status returned by Join() is discarded.
+QueueRunner::~QueueRunner() {
+ should_stop_ = true;
+ Join();
+}
+
+// Schedules one Run() task per enqueue op on the thread pool, all using
+// `sess`. The caller keeps ownership of `sess`; the tasks use the raw pointer.
+//
+// Returns:
+//   ALREADY_EXISTS      if the runner was started and has not been joined;
+//                       starting twice would schedule duplicate tasks and
+//                       corrupt the count of outstanding runs.
+//   INVALID_ARGUMENT    if there are no enqueue ops to run.
+//   FAILED_PRECONDITION if the thread pool was destroyed by Join() and Init()
+//                       has not been called since.
+//   OK                  once all tasks have been scheduled.
+Status QueueRunner::Start(Session* sess) {
+  if (started_.load()) {
+    return Status(error::ALREADY_EXISTS, "QueueRunner is already running.");
+  }
+  if (runs_ == 0) {
+    return Status(
+        error::INVALID_ARGUMENT,
+        "No enqueue ops to run. You may want to Init the QueueRunner first.");
+  }
+  // Init() followed by Join() leaves runs_ > 0 but no pool to schedule on.
+  if (thread_pool_ == nullptr) {
+    return Status(error::FAILED_PRECONDITION,
+                  "QueueRunner has no thread pool. Init the QueueRunner again "
+                  "before calling Start.");
+  }
+  started_ = true;
+  for (const string& enqueue_op : enqueue_op_names_) {
+    // Capture the op name by value so each task owns its own copy.
+    thread_pool_->Schedule(
+        [this, sess, enqueue_op]() { Run(sess, enqueue_op); });
+  }
+  return Status::OK();
+}
+
+// Releases the thread pool, marks the runner as no longer started and returns
+// the status recorded by the Run() tasks.
+Status QueueRunner::Join() {
+  thread_pool_.reset(nullptr);
+  started_.store(false);
+  return status_;
+}
+
+// Body of one worker task: runs `enqueue_op` on `sess` repeatedly until
+// should_stop_ is set.
+//
+// A non-OK enqueue status always sets should_stop_. What else happens depends
+// on the status code:
+//   * A code listed in queue_closed_exception_types_ counts as a normal
+//     finish. The outstanding-run counter is decremented and, if this was the
+//     last outstanding run and a close op is configured, the close op is run.
+//     A failing close op is recorded in status_.
+//   * Any other code is a failure and is recorded in status_.
+// Only the first failure is kept in status_.
+//
+// A task that leaves the loop without having seen a queue-closed status
+// decrements the outstanding-run counter on exit.
+void QueueRunner::Run(Session* sess, const string& enqueue_op) {
+  bool decremented = false;
+  while (!should_stop_) {
+    std::vector<Tensor> outputs;
+    const Status status = sess->Run({}, {}, {enqueue_op}, &outputs);
+    if (status.ok()) {
+      continue;
+    }
+
+    mutex_lock l(mu_);
+    if (queue_closed_exception_types_.count(
+            static_cast<int>(status.code())) > 0) {
+      runs_--;
+      decremented = true;
+      should_stop_ = true;
+
+      // If all enqueue ops have finished, run the close op.
+      if (runs_ == 0 && !close_op_name_.empty()) {
+        std::vector<Tensor> close_outputs;
+        const Status close_status =
+            sess->Run({}, {}, {close_op_name_}, &close_outputs);
+        // Record the close op's own failure, not the queue-closed status of
+        // the enqueue op, and never overwrite an earlier failure.
+        if (!close_status.ok() && status_.ok()) {
+          status_ = close_status;
+        }
+      }
+    } else {
+      should_stop_ = true;
+      // Only record the first failure status.
+      if (status_.ok()) {
+        status_ = status;
+      }
+    }
+  }
+
+  if (!decremented) {
+    mutex_lock l(mu_);
+    runs_--;
+  }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
new file mode 100644
index 0000000000..7eeab8bd45
--- /dev/null
+++ b/tensorflow/cc/training/queue_runner.h
@@ -0,0 +1,79 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
+#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/protobuf/queue_runner.pb.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+// QueueRunner imitates the behavior of the Python version of QueueRunner: it
+// creates a thread for each enqueue op and runs the close op on completion.
+//
+// Typical use: construct (or default-construct and Init()), Start() with a
+// session, then Join() to wait for the threads and collect the final status.
+class QueueRunner {
+ public:
+  // Creates a runner that has no enqueue ops yet; call Init() before Start().
+  QueueRunner();
+
+  // The constructor initializes the class from the proto.
+  // TODO(yuefengz): we may want to initialize from queues and ops in the
+  // future.
+  explicit QueueRunner(const QueueRunnerDef& queue_runner_def);
+
+  // The destructor asks the threads to stop and joins all of them.
+  ~QueueRunner();
+
+  // Initializes the instance with the QueueRunnerDef proto. Fails if the
+  // runner is already running or if the proto contains no enqueue ops.
+  Status Init(const QueueRunnerDef& queue_runner_def);
+
+  // Starts the queue runner with the given session. The session is not owned
+  // by the runner. Fails if there are no enqueue ops to run.
+  Status Start(Session* sess);
+
+  // Joins all the threads. Returns OK if all threads ran successfully;
+  // otherwise returns the first captured failure status.
+  Status Join();
+
+ private:
+  // The Run function for each thread: runs `enqueue_op` on `sess` in a loop.
+  void Run(Session* sess, const string& enqueue_op);
+
+  string queue_name_;
+  std::vector<string> enqueue_op_names_;
+  string close_op_name_;
+  // The cancel op is not being called currently.
+  string cancel_op_name_;
+  // Status codes that signal a closed queue; code::Code is cast to int to
+  // avoid needing a hash function.
+  std::unordered_set<int> queue_closed_exception_types_;
+
+  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  // Polled by every worker thread without holding mu_ and written by other
+  // threads, so it must be atomic. Initialized so that a default-constructed
+  // runner never holds an indeterminate value.
+  std::atomic<bool> should_stop_{false};
+  std::atomic<bool> started_;
+  // Guards runs_ and status_ while worker threads are active.
+  mutex mu_;
+  // TODO(yuefengz): implement c++ coordinator.
+  // Number of enqueue tasks that have not finished yet.
+  int runs_ = 0;
+  // First failure captured by a worker thread, or OK.
+  Status status_;
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
new file mode 100644
index 0000000000..8719677274
--- /dev/null
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -0,0 +1,215 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/training/queue_runner.h"
+#include <string>
+#include <vector>
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/queue_runner.pb.h"
+#include "tensorflow/core/public/session.h"
+
+namespace {
+
+using ::tensorflow::DataType;
+using ::tensorflow::error::Code;
+using ::tensorflow::GraphDef;
+using ::tensorflow::ops::Assign;
+using ::tensorflow::ops::Const;
+using ::tensorflow::ops::CountUpTo;
+using ::tensorflow::ops::FIFOQueue;
+using ::tensorflow::ops::InputList;
+using ::tensorflow::ops::QueueClose;
+using ::tensorflow::ops::QueueDequeue;
+using ::tensorflow::ops::QueueEnqueue;
+using ::tensorflow::ops::Square;
+using ::tensorflow::ops::Variable;
+using ::tensorflow::QueueRunner;
+using ::tensorflow::QueueRunnerDef;
+using ::tensorflow::Scope;
+using ::tensorflow::Session;
+using ::tensorflow::SessionOptions;
+using ::tensorflow::Tensor;
+using ::tensorflow::TensorShape;
+
+constexpr char kAssignOpName[] = "assign";
+constexpr char kCountUpToOpName[] = "count";
+constexpr char kIllegalOpName1[] = "would fail";
+constexpr char kIllegalOpName2[] = "fail again";
+constexpr char kQueueName[] = "unit_test";
+constexpr char kSquareOpName[] = "square";
+constexpr char kVarOpName[] = "var";
+
+// Builds the graph shared by the counter-based tests: an int32 scalar variable
+// (kVarOpName), an op assigning it the constant 0 (kAssignOpName), an op
+// counting it up to a limit of 10 (kCountUpToOpName) and an op squaring it
+// (kSquareOpName).
+GraphDef BuildSimpleGraph() {
+  Scope scope = Scope::NewRootScope();
+  auto zero = Const(scope, 0);
+  auto counter = Variable(scope.WithOpName(kVarOpName), TensorShape({}),
+                          DataType::DT_INT32);
+  // The ops below are created only for their effect on the graph; the tests
+  // refer to them by name.
+  Assign(scope.WithOpName(kAssignOpName), counter, zero);      // NOLINT
+  CountUpTo(scope.WithOpName(kCountUpToOpName), counter, 10);  // NOLINT
+  Square(scope.WithOpName(kSquareOpName), counter);            // NOLINT
+
+  GraphDef result;
+  TF_EXPECT_OK(scope.ToGraphDef(&result));
+  return result;
+}
+
+// Assembles a QueueRunnerDef proto from its parts.
+//
+//   queue_name               name stored in the proto's queue_name field.
+//   enqueue_ops              names of the enqueue ops, in order.
+//   close_op                 name of the close op.
+//   queue_closed_error_codes status codes to treat as "queue closed"; may be
+//                            empty.
+QueueRunnerDef BuildQueueRunnerDef(
+    const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
+    const std::string& close_op,
+    const std::vector<Code>& queue_closed_error_codes) {
+  QueueRunnerDef queue_runner_def;
+  // Honor the caller's queue name instead of always using kQueueName.
+  *queue_runner_def.mutable_queue_name() = queue_name;
+  for (const std::string& enqueue_op : enqueue_ops) {
+    *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
+  }
+  *queue_runner_def.mutable_close_op_name() = close_op;
+  for (const auto& error_code : queue_closed_error_codes) {
+    *queue_runner_def.mutable_queue_closed_exception_types()->Add() =
+        error_code;
+  }
+  return queue_runner_def;
+}
+
+// Creates a session with default options, loads `graph_def` into it and runs
+// the op named kAssignOpName once. Both steps are wrapped in TF_CHECK_OK.
+std::unique_ptr<Session> BuildSessionAndInitVariable(
+    const GraphDef& graph_def) {
+  std::unique_ptr<Session> sess(NewSession(SessionOptions()));
+  TF_CHECK_OK(sess->Create(graph_def));
+
+  std::vector<Tensor> unused_outputs;
+  TF_CHECK_OK(sess->Run({}, {}, {kAssignOpName}, &unused_outputs));
+  return sess;
+}
+
+// Two runner threads both run the count-up op. Once the count-up op stops
+// succeeding the runner finishes, Join() reports OK, and squaring the variable
+// yields 100.
+TEST(QueueRunnerTest, BasicTest) {
+  GraphDef graph_def = BuildSimpleGraph();
+  auto session = BuildSessionAndInitVariable(graph_def);
+
+  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+      kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, {});
+
+  QueueRunner qr(queue_runner_def);
+  // A failed Start() would make the rest of the test meaningless.
+  TF_CHECK_OK(qr.Start(session.get()));
+  TF_EXPECT_OK(qr.Join());
+
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
+  // Stop here instead of indexing an empty vector if the fetch failed.
+  ASSERT_EQ(outputs.size(), 1);
+  int square_value = *outputs[0].scalar<int>().data();
+  EXPECT_EQ(square_value, 100);
+}
+
+// Same flow as the basic test with a single enqueue op, but the
+// "queue closed" codes are given explicitly (OUT_OF_RANGE and CANCELLED)
+// instead of relying on the default.
+TEST(QueueRunnerTest, QueueClosedCode) {
+  GraphDef graph_def = BuildSimpleGraph();
+  auto session = BuildSessionAndInitVariable(graph_def);
+
+  QueueRunnerDef queue_runner_def =
+      BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName,
+                          {Code::OUT_OF_RANGE, Code::CANCELLED});
+
+  QueueRunner qr(queue_runner_def);
+  // A failed Start() would make the rest of the test meaningless.
+  TF_CHECK_OK(qr.Start(session.get()));
+  TF_EXPECT_OK(qr.Join());
+
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
+  // Stop here instead of indexing an empty vector if the fetch failed.
+  ASSERT_EQ(outputs.size(), 1);
+  int square_value = *outputs[0].scalar<int>().data();
+  EXPECT_EQ(square_value, 100);
+}
+
+// Enqueue op names that do not exist in the graph make the runner threads
+// fail; Join() must surface that failure as NOT_FOUND.
+TEST(QueueRunnerTest, CatchErrorInJoin) {
+  GraphDef graph_def = BuildSimpleGraph();
+  auto session = BuildSessionAndInitVariable(graph_def);
+
+  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+      kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, {});
+
+  QueueRunner qr(queue_runner_def);
+  // Scheduling itself must succeed; the error is only expected from Join().
+  TF_CHECK_OK(qr.Start(session.get()));
+  EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND);
+}
+
+// Exercises the runner with real queue ops. The graph has two FIFO queues:
+// the test feeds q0 by hand, and the runner repeatedly runs "enqueue1", which
+// dequeues from q0 and enqueues into q1. After q0 is closed the runner is
+// joined, q1 must hold exactly the two values that were fed, and a third
+// dequeue must report OUT_OF_RANGE.
+TEST(QueueRunnerTest, RealEnqueueDequeue) {
+  Scope root = Scope::NewRootScope();
+  auto q0 = FIFOQueue(root.WithOpName("q0"), {DataType::DT_INT32});
+  auto ten = Const(root, 10);
+  auto enqueue0 = QueueEnqueue(root.WithOpName("enqueue0"), q0, {ten});
+  auto close0 = QueueClose(root.WithOpName("close0"), q0);
+  auto q1 = FIFOQueue(root.WithOpName("q1"), {DataType::DT_INT32});
+  auto dequeue0 =
+      QueueDequeue(root.WithOpName("dequeue0"), q0, {DataType::DT_INT32});
+  auto enqueue1 = QueueEnqueue(root.WithOpName("enqueue1"), q1, {dequeue0[0]});
+  auto dequeue1 =
+      QueueDequeue(root.WithOpName("dequeue1"), q1, {DataType::DT_INT32});
+  auto close1 = QueueClose(root.WithOpName("close1"), q1);
+
+  GraphDef graph_def;
+  TF_EXPECT_OK(root.ToGraphDef(&graph_def));
+
+  SessionOptions options;
+  std::unique_ptr<Session> session(NewSession(options));
+  TF_CHECK_OK(session->Create(graph_def));
+
+  QueueRunnerDef queue_runner_def =
+      BuildQueueRunnerDef(kQueueName, {"enqueue1"}, "close1", {});
+  QueueRunner qr;
+  // Do not silently continue with an uninitialized runner.
+  TF_CHECK_OK(qr.Init(queue_runner_def));
+  TF_CHECK_OK(qr.Start(session.get()));
+
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
+  TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
+  TF_EXPECT_OK(session->Run({}, {}, {"close0"}, &outputs));
+
+  TF_EXPECT_OK(qr.Join());
+  std::vector<Tensor> dq1;
+  TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq1));
+  // Stop here instead of indexing an empty vector if the fetch failed.
+  ASSERT_EQ(dq1.size(), 1);
+  EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
+  std::vector<Tensor> dq2;
+  TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq2));
+  ASSERT_EQ(dq2.size(), 1);
+  EXPECT_EQ(*dq2[0].scalar<int>().data(), 10);
+
+  EXPECT_EQ(session->Run({}, {"dequeue1"}, {}, &dq1).code(),
+            Code::OUT_OF_RANGE);
+}
+
+// Init() must reject a definition without enqueue ops with INVALID_ARGUMENT.
+TEST(QueueRunnerTest, EmptyEnqueueOps) {
+  const QueueRunnerDef def_without_enqueue_ops =
+      BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, {});
+
+  QueueRunner runner;
+  const auto init_status = runner.Init(def_without_enqueue_ops);
+  EXPECT_EQ(init_status.code(), Code::INVALID_ARGUMENT);
+}
+
+// Calling Init() on a runner that has been started and not yet joined must
+// fail with ALREADY_EXISTS.
+TEST(QueueRunnerTest, InitAfterStart) {
+  const GraphDef simple_graph = BuildSimpleGraph();
+  auto sess = BuildSessionAndInitVariable(simple_graph);
+  const QueueRunnerDef runner_def =
+      BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kCountUpToOpName, {});
+
+  QueueRunner runner;
+  TF_EXPECT_OK(runner.Init(runner_def));
+  TF_EXPECT_OK(runner.Start(sess.get()));
+
+  const auto second_init_code = runner.Init(runner_def).code();
+  EXPECT_EQ(second_init_code, Code::ALREADY_EXISTS);
+}
+
+} // namespace