diff options
author | 2015-11-06 16:27:58 -0800 | |
---|---|---|
committer | 2015-11-06 16:27:58 -0800 | |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/kernels/reader_base.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/kernels/reader_base.cc')
-rw-r--r-- | tensorflow/core/kernels/reader_base.cc | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/reader_base.cc b/tensorflow/core/kernels/reader_base.cc new file mode 100644 index 0000000000..06211efb38 --- /dev/null +++ b/tensorflow/core/kernels/reader_base.cc @@ -0,0 +1,156 @@ +#include "tensorflow/core/kernels/reader_base.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +// ReaderBase ------------------------------------------------------ + +ReaderBase::ReaderBase(const string& name) : name_(name) {} + +int64 ReaderBase::NumRecordsProduced() { + mutex_lock lock(mu_); + return num_records_produced_; +} + +int64 ReaderBase::NumWorkUnitsCompleted() { + mutex_lock lock(mu_); + return work_finished_; +} + +Status ReaderBase::Reset() { + mutex_lock lock(mu_); + return ResetLocked(); +} + +Status ReaderBase::ResetLocked() { + work_started_ = 0; + work_finished_ = 0; + num_records_produced_ = 0; + work_.clear(); + return Status::OK(); +} + +Status ReaderBase::SerializeState(string* state) { + mutex_lock lock(mu_); + return SerializeStateLocked(state); +} + +Status ReaderBase::SerializeStateLocked(string* state) { + return errors::Unimplemented("Reader SerializeState"); +} + +Status ReaderBase::RestoreState(const string& state) { + mutex_lock lock(mu_); + Status status = RestoreStateLocked(state); + if (!status.ok()) { + ResetLocked(); + } + return status; +} + +Status ReaderBase::RestoreStateLocked(const string& state) { + return errors::Unimplemented("Reader RestoreState"); +} + +void ReaderBase::Read(QueueInterface* queue, string* key, string* value, + OpKernelContext* context) { + mutex_lock lock(mu_); + while (true) { + if (!work_in_progress()) { + GetNextWorkLocked(queue, context); + if (!context->status().ok()) return; + } + + bool produced = false; + bool at_end = false; + Status status = ReadLocked(key, value, &produced, &at_end); + + if (!at_end && status.ok() && !produced) { + status = errors::Internal( + "ReadLocked() for ", name(), + " must set *at_end=true, *produced=true, or return an error."); + } + if (!status.ok() && produced) { + status = errors::Internal("ReadLocked() for ", name(), + " set *produced=true *and* returned an error: ", + status.ToString()); + } + if (status.ok() && at_end) { + status = OnWorkFinishedLocked(); + work_finished_ = work_started_; + } + if (!status.ok()) { + context->SetStatus(status); + return; + } + if (produced) { + ++num_records_produced_; + return; + } + } +} + +void ReaderBase::GetNextWorkLocked(QueueInterface* queue, + OpKernelContext* context) { + Notification n; + queue->TryDequeue( + context, [this, context, &n](const QueueInterface::Tuple& tuple) { + if (context->status().ok()) { + if (tuple.size() != 1) { + context->SetStatus( + errors::InvalidArgument("Expected single component queue")); + } else if (tuple[0].dtype() != DT_STRING) { + context->SetStatus(errors::InvalidArgument( + "Expected queue with single string component")); + } else if (tuple[0].NumElements() != 1) { + context->SetStatus(errors::InvalidArgument( + "Expected to dequeue a one-element string tensor")); + } else { + work_ = tuple[0].flat<string>()(0); + ++work_started_; + Status status = OnWorkStartedLocked(); + if (!status.ok()) { + context->SetStatus(status); + --work_started_; + } + } + } + n.Notify(); + }); + n.WaitForNotification(); +} + +void ReaderBase::SaveBaseState(ReaderBaseState* state) const { + state->Clear(); + state->set_work_started(work_started_); + state->set_work_finished(work_finished_); + state->set_num_records_produced(num_records_produced_); + state->set_current_work(work_); +} + +Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { + work_started_ = state.work_started(); + work_finished_ = state.work_finished(); + num_records_produced_ = state.num_records_produced(); + work_ = state.current_work(); + if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) { + return errors::InvalidArgument( + "Unexpected negative value when restoring in ", name(), ": ", + state.ShortDebugString()); + } + if (work_started_ > work_finished_) { + return errors::InvalidArgument( + "Inconsistent work started vs. finished when restoring in ", name(), + ": ", state.ShortDebugString()); + } + return Status::OK(); +} + +} // namespace tensorflow |