diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-27 03:23:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-27 03:26:34 -0700 |
commit | 1c055f0679ea6cdae28b3c78c3bf98cb40f00e13 (patch) | |
tree | d5ae34143aa113c0f63f3f479331d943db677077 /tensorflow/core/kernels/lookup_util.cc | |
parent | 7555534be3c6138cbcca138556fe4dbf4cc6b8ce (diff) |
Avoid reading the input file twice for InitializableLookupTable in combination with HashTable.
Before this CL, TextFileLineIterator::total_size() was called for HashTable::DoPrepare, even though HashTable::DoPrepare ignores the size parameter.
In order to have a result ready for TextFileLineIterator::total_size(), Init() called GetNumLinesInTextFile(), which read the whole file, only for the result to be thrown away :-/
This CL:
- adds DoLazyPrepare, which takes a functor that is called to get the size only if it is needed.
- adds HashTable::DoLazyPrepare, which does not call this functor.
- modifies TextFileLineIterator::Init() to no longer call GetNumLinesInTextFile() when vocab_size is given as -1.
- modifies TextFileLineIterator::total_size() to call GetNumLinesInTextFile() lazily on the first call, if vocab_size_ was passed as -1.
PiperOrigin-RevId: 190593744
Diffstat (limited to 'tensorflow/core/kernels/lookup_util.cc')
-rw-r--r-- | tensorflow/core/kernels/lookup_util.cc | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index c7ce1c3747..27031d9216 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -75,9 +75,6 @@ class TextFileLineIterator Status Init(const string& filename, int64 vocab_size, char delimiter, DataType key_dtype, int64 key_index, DataType value_dtype, int64 value_index, Env* env) { - if (vocab_size == -1) { - TF_RETURN_IF_ERROR(GetNumLinesInTextFile(env, filename, &vocab_size)); - } filename_ = filename; vocab_size_ = vocab_size; delimiter_ = delimiter; @@ -85,6 +82,7 @@ class TextFileLineIterator value_ = Tensor(value_dtype, TensorShape({})); key_index_ = key_index; value_index_ = value_index; + env_ = env; status_ = env->NewRandomAccessFile(filename_, &file_); if (!status_.ok()) return status_; @@ -103,15 +101,15 @@ class TextFileLineIterator string line; status_ = input_buffer_->ReadLine(&line); if (!status_.ok()) { - if (errors::IsOutOfRange(status_) && next_id_ != vocab_size_) { + if (errors::IsOutOfRange(status_) && next_id_ != total_size()) { status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_, - ": expected ", vocab_size_, + ": expected ", total_size(), " but got ", next_id_); } valid_ = false; return; } - if (next_id_ >= vocab_size_) { + if (vocab_size_ != -1 && next_id_ >= vocab_size_) { LOG(WARNING) << "Truncated " << filename_ << " before its end at " << vocab_size_ << " records."; LOG(WARNING) << "next_id_ : " << next_id_; @@ -162,7 +160,18 @@ class TextFileLineIterator Status status() const override { return status_; } - int64 total_size() const override { return vocab_size_; } + int64 total_size() const override { + if (vocab_size_ == -1) { + int64 new_size; + Status status = GetNumLinesInTextFile(env_, filename_, &new_size); + if (!status.ok()) { + LOG(WARNING) << "Unable to get line count: " << status; + new_size = -1; + } + *const_cast<int64*>(&vocab_size_) = new_size; + } + return 
vocab_size_; + } private: Tensor key_; @@ -170,6 +179,7 @@ class TextFileLineIterator bool valid_; // true if the iterator points to an existing range. int64 key_index_; int64 value_index_; + Env* env_; int64 next_id_; int64 vocab_size_; string filename_; |