aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/reader_base.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/kernels/reader_base.cc
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/core/kernels/reader_base.cc')
-rw-r--r--tensorflow/core/kernels/reader_base.cc156
1 files changed, 156 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/reader_base.cc b/tensorflow/core/kernels/reader_base.cc
new file mode 100644
index 0000000000..06211efb38
--- /dev/null
+++ b/tensorflow/core/kernels/reader_base.cc
@@ -0,0 +1,156 @@
+#include "tensorflow/core/kernels/reader_base.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+// ReaderBase ------------------------------------------------------
+
+ReaderBase::ReaderBase(const string& name) : name_(name) {}
+
+int64 ReaderBase::NumRecordsProduced() {
+ mutex_lock lock(mu_);
+ return num_records_produced_;
+}
+
+int64 ReaderBase::NumWorkUnitsCompleted() {
+ mutex_lock lock(mu_);
+ return work_finished_;
+}
+
+Status ReaderBase::Reset() {
+ mutex_lock lock(mu_);
+ return ResetLocked();
+}
+
+Status ReaderBase::ResetLocked() {
+ work_started_ = 0;
+ work_finished_ = 0;
+ num_records_produced_ = 0;
+ work_.clear();
+ return Status::OK();
+}
+
+Status ReaderBase::SerializeState(string* state) {
+ mutex_lock lock(mu_);
+ return SerializeStateLocked(state);
+}
+
+Status ReaderBase::SerializeStateLocked(string* state) {
+ return errors::Unimplemented("Reader SerializeState");
+}
+
+Status ReaderBase::RestoreState(const string& state) {
+ mutex_lock lock(mu_);
+ Status status = RestoreStateLocked(state);
+ if (!status.ok()) {
+ ResetLocked();
+ }
+ return status;
+}
+
+Status ReaderBase::RestoreStateLocked(const string& state) {
+ return errors::Unimplemented("Reader RestoreState");
+}
+
+void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
+ OpKernelContext* context) {
+ mutex_lock lock(mu_);
+ while (true) {
+ if (!work_in_progress()) {
+ GetNextWorkLocked(queue, context);
+ if (!context->status().ok()) return;
+ }
+
+ bool produced = false;
+ bool at_end = false;
+ Status status = ReadLocked(key, value, &produced, &at_end);
+
+ if (!at_end && status.ok() && !produced) {
+ status = errors::Internal(
+ "ReadLocked() for ", name(),
+ " must set *at_end=true, *produced=true, or return an error.");
+ }
+ if (!status.ok() && produced) {
+ status = errors::Internal("ReadLocked() for ", name(),
+ " set *produced=true *and* returned an error: ",
+ status.ToString());
+ }
+ if (status.ok() && at_end) {
+ status = OnWorkFinishedLocked();
+ work_finished_ = work_started_;
+ }
+ if (!status.ok()) {
+ context->SetStatus(status);
+ return;
+ }
+ if (produced) {
+ ++num_records_produced_;
+ return;
+ }
+ }
+}
+
+void ReaderBase::GetNextWorkLocked(QueueInterface* queue,
+ OpKernelContext* context) {
+ Notification n;
+ queue->TryDequeue(
+ context, [this, context, &n](const QueueInterface::Tuple& tuple) {
+ if (context->status().ok()) {
+ if (tuple.size() != 1) {
+ context->SetStatus(
+ errors::InvalidArgument("Expected single component queue"));
+ } else if (tuple[0].dtype() != DT_STRING) {
+ context->SetStatus(errors::InvalidArgument(
+ "Expected queue with single string component"));
+ } else if (tuple[0].NumElements() != 1) {
+ context->SetStatus(errors::InvalidArgument(
+ "Expected to dequeue a one-element string tensor"));
+ } else {
+ work_ = tuple[0].flat<string>()(0);
+ ++work_started_;
+ Status status = OnWorkStartedLocked();
+ if (!status.ok()) {
+ context->SetStatus(status);
+ --work_started_;
+ }
+ }
+ }
+ n.Notify();
+ });
+ n.WaitForNotification();
+}
+
+void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
+ state->Clear();
+ state->set_work_started(work_started_);
+ state->set_work_finished(work_finished_);
+ state->set_num_records_produced(num_records_produced_);
+ state->set_current_work(work_);
+}
+
+Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
+ work_started_ = state.work_started();
+ work_finished_ = state.work_finished();
+ num_records_produced_ = state.num_records_produced();
+ work_ = state.current_work();
+ if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
+ return errors::InvalidArgument(
+ "Unexpected negative value when restoring in ", name(), ": ",
+ state.ShortDebugString());
+ }
+ if (work_started_ > work_finished_) {
+ return errors::InvalidArgument(
+ "Inconsistent work started vs. finished when restoring in ", name(),
+ ": ", state.ShortDebugString());
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow