aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/lookup_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-27 03:23:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 03:26:34 -0700
commit1c055f0679ea6cdae28b3c78c3bf98cb40f00e13 (patch)
treed5ae34143aa113c0f63f3f479331d943db677077 /tensorflow/core/kernels/lookup_util.cc
parent7555534be3c6138cbcca138556fe4dbf4cc6b8ce (diff)
Avoid reading the input file twice for InitializableLookupTable in combination with HashTable.
Before this cl, TextFileLineIterator::total_size() was called for HashTable::DoPrepare, even though HashTable::DoPrepare ignores the size parameter. In order to have a result ready for TextFileLineIterator::total_size(), Init() called GetNumLinesInTextFile(), which read the whole file. Just to throw away the result :-/ This cl: - adds a DoLazyPrepare, that gets a functor to get the size, only if needed. - add HashTable::DoLazyPrepare which does not call this functor. - modify TextFileLineIterator::Init() to not call GetNumLinesInTextFile() anymore, when vocab_size was given as -1. - modify TextFileLineIterator::total_size() to call GetNumLinesInTextFile() lazily on the first call, if vocab_size_ was passed as -1. PiperOrigin-RevId: 190593744
Diffstat (limited to 'tensorflow/core/kernels/lookup_util.cc')
-rw-r--r--tensorflow/core/kernels/lookup_util.cc24
1 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
index c7ce1c3747..27031d9216 100644
--- a/tensorflow/core/kernels/lookup_util.cc
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -75,9 +75,6 @@ class TextFileLineIterator
Status Init(const string& filename, int64 vocab_size, char delimiter,
DataType key_dtype, int64 key_index, DataType value_dtype,
int64 value_index, Env* env) {
- if (vocab_size == -1) {
- TF_RETURN_IF_ERROR(GetNumLinesInTextFile(env, filename, &vocab_size));
- }
filename_ = filename;
vocab_size_ = vocab_size;
delimiter_ = delimiter;
@@ -85,6 +82,7 @@ class TextFileLineIterator
value_ = Tensor(value_dtype, TensorShape({}));
key_index_ = key_index;
value_index_ = value_index;
+ env_ = env;
status_ = env->NewRandomAccessFile(filename_, &file_);
if (!status_.ok()) return status_;
@@ -103,15 +101,15 @@ class TextFileLineIterator
string line;
status_ = input_buffer_->ReadLine(&line);
if (!status_.ok()) {
- if (errors::IsOutOfRange(status_) && next_id_ != vocab_size_) {
+ if (errors::IsOutOfRange(status_) && next_id_ != total_size()) {
status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
- ": expected ", vocab_size_,
+ ": expected ", total_size(),
" but got ", next_id_);
}
valid_ = false;
return;
}
- if (next_id_ >= vocab_size_) {
+ if (vocab_size_ != -1 && next_id_ >= vocab_size_) {
LOG(WARNING) << "Truncated " << filename_ << " before its end at "
<< vocab_size_ << " records.";
LOG(WARNING) << "next_id_ : " << next_id_;
@@ -162,7 +160,18 @@ class TextFileLineIterator
Status status() const override { return status_; }
- int64 total_size() const override { return vocab_size_; }
+ int64 total_size() const override {
+ if (vocab_size_ == -1) {
+ int64 new_size;
+ Status status = GetNumLinesInTextFile(env_, filename_, &new_size);
+ if (!status.ok()) {
+ LOG(WARNING) << "Unable to get line count: " << status;
+ new_size = -1;
+ }
+ *const_cast<int64*>(&vocab_size_) = new_size;
+ }
+ return vocab_size_;
+ }
private:
Tensor key_;
@@ -170,6 +179,7 @@ class TextFileLineIterator
bool valid_; // true if the iterator points to an existing range.
int64 key_index_;
int64 value_index_;
+ Env* env_;
int64 next_id_;
int64 vocab_size_;
string filename_;