diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2016-10-20 14:33:13 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-20 16:04:19 -0700 |
commit | c852668c41fc82b5ba2e440b18d3adea6847b54f (patch) | |
tree | fb3ac6f7277782f5f2ddbfdb0cd5628ca1c6f5d1 /tensorflow/cc/training | |
parent | b09ab769296f361435aa1401db14f302937b6fec (diff) |
Add c++ queue runner.
Change: 136769119
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 119 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 79 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 215 |
3 files changed, 413 insertions, 0 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc new file mode 100644 index 0000000000..81f49c5dcf --- /dev/null +++ b/tensorflow/cc/training/queue_runner.cc @@ -0,0 +1,119 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/training/queue_runner.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +QueueRunner::QueueRunner() : started_(false) {} + +QueueRunner::QueueRunner(const QueueRunnerDef& queue_runner_def) + : started_(false) { + TF_CHECK_OK(Init(queue_runner_def)); +} + +Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { + if (started_.load()) { + return Status(error::ALREADY_EXISTS, "QueueRunner is already running."); + } + queue_name_ = queue_runner_def.queue_name(); + enqueue_op_names_.insert(enqueue_op_names_.end(), + queue_runner_def.enqueue_op_name().begin(), + queue_runner_def.enqueue_op_name().end()); + runs_ = enqueue_op_names_.size(); + if (runs_ == 0) { + return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run."); + } + close_op_name_ = queue_runner_def.close_op_name(); + cancel_op_name_ = queue_runner_def.cancel_op_name(); + if (queue_runner_def.queue_closed_exception_types_size() == 0) { + queue_closed_exception_types_.insert(error::OUT_OF_RANGE); + } else { + for (const auto& code : queue_runner_def.queue_closed_exception_types()) { + queue_closed_exception_types_.insert(static_cast<int>(code)); + } + } + + thread_pool_.reset( + new thread::ThreadPool(Env::Default(), queue_name_, runs_)); + should_stop_ = false; + return Status::OK(); +} + +QueueRunner::~QueueRunner() { + should_stop_ = true; + Join(); +} + +Status QueueRunner::Start(Session* sess) { + if (runs_ == 0) { + return Status( + error::INVALID_ARGUMENT, + "No enqueue ops to run. You may want to Init the QueueRunner first."); + } + started_ = true; + for (const string& enqueue_op : enqueue_op_names_) { + thread_pool_->Schedule( + std::bind(&QueueRunner::Run, this, sess, enqueue_op)); + } + return Status::OK(); +} + +Status QueueRunner::Join() { + thread_pool_.reset(); + started_ = false; + return status_; +} + +void QueueRunner::Run(Session* sess, const string& enqueue_op) { + bool decremented = false; + while (!should_stop_) { + std::vector<Tensor> outputs; + auto status = sess->Run({}, {}, {enqueue_op}, &outputs); + if (status.ok()) { + continue; + } else if (queue_closed_exception_types_.count( + static_cast<int>(status.code())) > 0) { + mutex_lock l(mu_); + runs_--; + decremented = true; + should_stop_ = true; + + // If all enqueue ops have finished, run the close op. + if (runs_ == 0 && !close_op_name_.empty()) { + std::vector<Tensor> outputs; + auto s = sess->Run({}, {}, {close_op_name_}, &outputs); + if (!s.ok()) { + status_ = status; + } + } + } else { + mutex_lock l(mu_); + should_stop_ = true; + // Only record the first failure status. + if (status_.ok()) { + status_ = status; + } + } + } + + if (!decremented) { + mutex_lock l(mu_); + runs_--; + } +} + +} // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h new file mode 100644 index 0000000000..7eeab8bd45 --- /dev/null +++ b/tensorflow/cc/training/queue_runner.h @@ -0,0 +1,79 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ + +#include <memory> +#include <string> +#include <unordered_set> +#include <vector> +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/queue_runner.pb.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { + +// QueueRunner class imitates the behavior of the python version of QueueRunner +// which creates a thread for each enqueue op, runs close op on completion. +class QueueRunner { + public: + QueueRunner(); + + // The constructor initializes the class from the proto. + // TODO(yuefengz): we may want to initialize from queues and ops in the + // future. + explicit QueueRunner(const QueueRunnerDef& queue_runner_def); + + // The destructor would join all the threads. + ~QueueRunner(); + + // Initializes the instance with the QueueRunnerDef proto. + Status Init(const QueueRunnerDef& queue_runner_def); + + // Starts the queue runner with the given session. + Status Start(Session* sess); + + // Joins all the threads. Returns okay if all threads run successfully; + // otherwise returns the first captured failure status. + Status Join(); + + private: + // The Run function for each thread. + void Run(Session* sess, const string& enqueue_op); + + string queue_name_; + std::vector<string> enqueue_op_names_; + string close_op_name_; + // The cancel op is not being called currently. + string cancel_op_name_; + // code::Code casted to int to avoid a hash function. + std::unordered_set<int> queue_closed_exception_types_; + + std::unique_ptr<thread::ThreadPool> thread_pool_; + bool should_stop_; + std::atomic<bool> started_; + mutex mu_; + // TODO(yuefengz): implement c++ coordinator. + int runs_ = 0; + Status status_; +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc new file mode 100644 index 0000000000..8719677274 --- /dev/null +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/training/queue_runner.h" +#include <string> +#include <vector> +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/queue_runner.pb.h" +#include "tensorflow/core/public/session.h" + +namespace { + +using ::tensorflow::DataType; +using ::tensorflow::error::Code; +using ::tensorflow::GraphDef; +using ::tensorflow::ops::Assign; +using ::tensorflow::ops::Const; +using ::tensorflow::ops::CountUpTo; +using ::tensorflow::ops::FIFOQueue; +using ::tensorflow::ops::InputList; +using ::tensorflow::ops::QueueClose; +using ::tensorflow::ops::QueueDequeue; +using ::tensorflow::ops::QueueEnqueue; +using ::tensorflow::ops::Square; +using ::tensorflow::ops::Variable; +using ::tensorflow::QueueRunner; +using ::tensorflow::QueueRunnerDef; +using ::tensorflow::Scope; +using ::tensorflow::Session; +using ::tensorflow::SessionOptions; +using ::tensorflow::Tensor; +using ::tensorflow::TensorShape; + +constexpr char kAssignOpName[] = "assign"; +constexpr char kCountUpToOpName[] = "count"; +constexpr char kIllegalOpName1[] = "would fail"; +constexpr char kIllegalOpName2[] = "fail again"; +constexpr char kQueueName[] = "unit_test"; +constexpr char kSquareOpName[] = "square"; +constexpr char kVarOpName[] = "var"; + +GraphDef BuildSimpleGraph() { + Scope root = Scope::NewRootScope(); + auto init_value = Const(root, 0); + auto var = Variable(root.WithOpName(kVarOpName), TensorShape({}), + DataType::DT_INT32); + auto assign = Assign(root.WithOpName(kAssignOpName), var, init_value); + auto count = CountUpTo(root.WithOpName(kCountUpToOpName), var, 10); + Square(root.WithOpName(kSquareOpName), var); // NOLINT + + GraphDef graph_def; + TF_EXPECT_OK(root.ToGraphDef(&graph_def)); + return graph_def; +} + +QueueRunnerDef BuildQueueRunnerDef( + const std::string& queue_name, const std::vector<std::string>& enqueue_ops, + const std::string& close_op, + const std::vector<Code>& queue_closed_error_codes) { + QueueRunnerDef queue_runner_def; + *queue_runner_def.mutable_queue_name() = kQueueName; + for (const std::string& enqueue_op : enqueue_ops) { + *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op; + } + *queue_runner_def.mutable_close_op_name() = close_op; + for (const auto& error_code : queue_closed_error_codes) { + *queue_runner_def.mutable_queue_closed_exception_types()->Add() = + error_code; + } + return queue_runner_def; +} + +std::unique_ptr<Session> BuildSessionAndInitVariable( + const GraphDef& graph_def) { + SessionOptions options; + std::unique_ptr<Session> session(NewSession(options)); + TF_CHECK_OK(session->Create(graph_def)); + + std::vector<Tensor> nothing; + TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, ¬hing)); + return session; +} + +TEST(QueueRunnerTest, BasicTest) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, {}); + + QueueRunner qr(queue_runner_def); + qr.Start(session.get()); + TF_EXPECT_OK(qr.Join()); + + std::vector<Tensor> outputs; + TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); + int square_value = *outputs[0].scalar<int>().data(); + EXPECT_EQ(square_value, 100); +} + +TEST(QueueRunnerTest, QueueClosedCode) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, + {Code::OUT_OF_RANGE, Code::CANCELLED}); + + QueueRunner qr(queue_runner_def); + qr.Start(session.get()); + TF_EXPECT_OK(qr.Join()); + + std::vector<Tensor> outputs; + TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); + int square_value = *outputs[0].scalar<int>().data(); + EXPECT_EQ(square_value, 100); +} + +TEST(QueueRunnerDef, CatchErrorInJoin) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, {}); + + QueueRunner qr(queue_runner_def); + qr.Start(session.get()); + EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND); +} + +TEST(QueueRunnerTest, RealEnqueueDequeue) { + Scope root = Scope::NewRootScope(); + auto q0 = FIFOQueue(root.WithOpName("q0"), {DataType::DT_INT32}); + auto ten = Const(root, 10); + auto enqueue0 = QueueEnqueue(root.WithOpName("enqueue0"), q0, {ten}); + auto close0 = QueueClose(root.WithOpName("close0"), q0); + auto q1 = FIFOQueue(root.WithOpName("q1"), {DataType::DT_INT32}); + auto dequeue0 = + QueueDequeue(root.WithOpName("dequeue0"), q0, {DataType::DT_INT32}); + auto enqueue1 = QueueEnqueue(root.WithOpName("enqueue1"), q1, {dequeue0[0]}); + auto dequeue1 = + QueueDequeue(root.WithOpName("dequeue1"), q1, {DataType::DT_INT32}); + auto close1 = QueueClose(root.WithOpName("close1"), q1); + + GraphDef graph_def; + TF_EXPECT_OK(root.ToGraphDef(&graph_def)); + + SessionOptions options; + std::unique_ptr<Session> session(NewSession(options)); + TF_CHECK_OK(session->Create(graph_def)); + + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {"enqueue1"}, "close1", {}); + QueueRunner qr; + qr.Init(queue_runner_def); + TF_CHECK_OK(qr.Start(session.get())); + + std::vector<Tensor> outputs; + TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs)); + TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs)); + TF_EXPECT_OK(session->Run({}, {}, {"close0"}, &outputs)); + + TF_EXPECT_OK(qr.Join()); + std::vector<Tensor> dq1; + TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq1)); + EXPECT_EQ(*dq1[0].scalar<int>().data(), 10); + std::vector<Tensor> dq2; + TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq2)); + EXPECT_EQ(*dq2[0].scalar<int>().data(), 10); + + EXPECT_EQ(session->Run({}, {"dequeue1"}, {}, &dq1).code(), + Code::OUT_OF_RANGE); +} + +TEST(QueueRunnerTest, EmptyEnqueueOps) { + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, {}); + + QueueRunner qr; + EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT); +} + +TEST(QueueRunnerTest, InitAfterStart) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kCountUpToOpName, {}); + + QueueRunner qr; + TF_EXPECT_OK(qr.Init(queue_runner_def)); + TF_EXPECT_OK(qr.Start(session.get())); + EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::ALREADY_EXISTS); +} + +} // namespace |