Diffstat (limited to 'tensorflow/models/embedding/word2vec_kernels.cc')
-rw-r--r-- | tensorflow/models/embedding/word2vec_kernels.cc | 287
1 file changed, 287 insertions(+), 0 deletions(-)
diff --git a/tensorflow/models/embedding/word2vec_kernels.cc b/tensorflow/models/embedding/word2vec_kernels.cc
new file mode 100644
index 0000000000..f68139fc91
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec_kernels.cc
@@ -0,0 +1,287 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+class SkipgramOp : public OpKernel {
+ public:
+  explicit SkipgramOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx), rng_(&philox_) {
+    string filename;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
+    OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));
+
+    mutex_lock l(mu_);
+    example_pos_ = corpus_size_;
+    label_pos_ = corpus_size_;
+    label_limit_ = corpus_size_;
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor words_per_epoch(DT_INT64, TensorShape({}));
+    Tensor current_epoch(DT_INT32, TensorShape({}));
+    Tensor total_words_processed(DT_INT64, TensorShape({}));
+    Tensor examples(DT_INT32, TensorShape({batch_size_}));
+    auto Texamples = examples.flat<int32>();
+    Tensor labels(DT_INT32, TensorShape({batch_size_}));
+    auto Tlabels = labels.flat<int32>();
+    {
+      mutex_lock l(mu_);
+      for (int i = 0; i < batch_size_; ++i) {
+        NextExample(&Texamples(i), &Tlabels(i));
+      }
+      words_per_epoch.scalar<int64>()() = corpus_size_;
+      current_epoch.scalar<int32>()() = current_epoch_;
+      total_words_processed.scalar<int64>()() = total_words_processed_;
+    }
+    ctx->set_output(0, word_);
+    ctx->set_output(1, freq_);
+    ctx->set_output(2, words_per_epoch);
+    ctx->set_output(3, current_epoch);
+    ctx->set_output(4, total_words_processed);
+    ctx->set_output(5, examples);
+    ctx->set_output(6, labels);
+  }
+
+ private:
+  int32 batch_size_ = 0;
+  int32 window_size_ = 5;
+  float subsample_ = 1e-3;
+  int min_count_ = 5;
+  int32 vocab_size_ = 0;
+  Tensor word_;
+  Tensor freq_;
+  int32 corpus_size_ = 0;
+  std::vector<int32> corpus_;
+
+  mutex mu_;
+  random::PhiloxRandom philox_ GUARDED_BY(mu_);
+  random::SimplePhilox rng_ GUARDED_BY(mu_);
+  int32 current_epoch_ GUARDED_BY(mu_) = -1;
+  int64 total_words_processed_ GUARDED_BY(mu_) = 0;
+  int32 example_pos_ GUARDED_BY(mu_);
+  int32 label_pos_ GUARDED_BY(mu_);
+  int32 label_limit_ GUARDED_BY(mu_);
+
+  // {example_pos_, label_pos_} is the cursor for the next example.
+  // example_pos_ wraps around at the end of corpus_. For each
+  // example, we randomly generate [label_pos_, label_limit_) for
+  // labels.
+  void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    while (true) {
+      if (label_pos_ >= label_limit_) {
+        if (example_pos_ + 1 >= corpus_size_) {
+          ++current_epoch_;
+          example_pos_ = 0;
+        } else {
+          ++example_pos_;
+        }
+        ++total_words_processed_;
+        int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
+        if (subsample_ > 0) {
+          // See Eq. 5 in http://arxiv.org/abs/1310.4546
+          float keep_prob =
+              (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
+              (subsample_ * corpus_size_) / word_freq;
+          if (rng_.RandFloat() > keep_prob) continue;
+        }
+        const int32 skip = 1 + rng_.Uniform(window_size_);
+        label_pos_ = std::max<int32>(0, example_pos_ - skip);
+        label_limit_ = std::min<int32>(corpus_size_, example_pos_ + skip + 1);
+      }
+      if (example_pos_ != label_pos_) {
+        break;
+      }
+      ++label_pos_;
+    }
+    *example = corpus_[example_pos_];
+    *label = corpus_[label_pos_++];
+  }
+
+  Status Init(Env* env, const string& filename) {
+    string data;
+    TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
+    RE2 kWord("\\s*(\\S+)");
+    auto input = ToRegexpStringPiece(data);
+    string w;
+    corpus_size_ = 0;
+    std::unordered_map<string, int32> word_freq;
+    while (RE2::Consume(&input, kWord, &w)) {
+      ++(word_freq[w]);
+      ++corpus_size_;
+    }
+    if (corpus_size_ < window_size_ * 10) {
+      return errors::InvalidArgument("The text file ", filename,
+                                     " contains too little data: ",
+                                     corpus_size_, " words");
+    }
+    typedef std::pair<string, int32> WordFreq;
+    std::vector<WordFreq> ordered;
+    for (const auto& p : word_freq) {
+      if (p.second >= min_count_) ordered.push_back(p);
+    }
+    LOG(INFO) << "Data file: " << filename << " contains " << data.size()
+              << " bytes, " << corpus_size_ << " words, " << word_freq.size()
+              << " unique words, " << ordered.size()
+              << " unique frequent words.";
+    word_freq.clear();
+    std::sort(ordered.begin(), ordered.end(),
+              [](const WordFreq& x, const WordFreq& y) {
+                return x.second > y.second;
+              });
+    vocab_size_ = static_cast<int32>(1 + ordered.size());
+    Tensor word(DT_STRING, TensorShape({vocab_size_}));
+    Tensor freq(DT_INT32, TensorShape({vocab_size_}));
+    word.flat<string>()(0) = "UNK";
+    static const int32 kUnkId = 0;
+    std::unordered_map<string, int32> word_id;
+    int64 total_counted = 0;
+    for (std::size_t i = 0; i < ordered.size(); ++i) {
+      const auto& w = ordered[i].first;
+      auto id = i + 1;
+      word.flat<string>()(id) = w;
+      auto word_count = ordered[i].second;
+      freq.flat<int32>()(id) = word_count;
+      total_counted += word_count;
+      word_id[w] = id;
+    }
+    freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
+    word_ = word;
+    freq_ = freq;
+    corpus_.reserve(corpus_size_);
+    input = ToRegexpStringPiece(data);
+    while (RE2::Consume(&input, kWord, &w)) {
+      corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
+    }
+    return Status::OK();
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);
+
+class NegTrainOp : public OpKernel {
+ public:
+  explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    base_.Init(0, 0);
+
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));
+
+    std::vector<int32> vocab_count;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));
+
+    std::vector<float> vocab_weights;
+    vocab_weights.reserve(vocab_count.size());
+    for (const auto& f : vocab_count) {
+      float r = std::pow(static_cast<float>(f), 0.75f);
+      vocab_weights.push_back(r);
+    }
+    sampler_ = new random::DistributionSampler(vocab_weights);
+  }
+
+  ~NegTrainOp() { delete sampler_; }
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor w_in = ctx->mutable_input(0, false);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
+                errors::InvalidArgument("Must be a matrix"));
+    Tensor w_out = ctx->mutable_input(1, false);
+    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
+                errors::InvalidArgument("w_in.shape == w_out.shape"));
+    const Tensor& examples = ctx->input(2);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
+                errors::InvalidArgument("Must be a vector"));
+    const Tensor& labels = ctx->input(3);
+    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
+                errors::InvalidArgument("examples.shape == labels.shape"));
+    const Tensor& learning_rate = ctx->input(4);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
+                errors::InvalidArgument("Must be a scalar"));
+
+    auto Tw_in = w_in.matrix<float>();
+    auto Tw_out = w_out.matrix<float>();
+    auto Texamples = examples.flat<int32>();
+    auto Tlabels = labels.flat<int32>();
+    auto lr = learning_rate.scalar<float>()();
+    const int64 vocab_size = w_in.dim_size(0);
+    const int64 dims = w_in.dim_size(1);
+    const int64 batch_size = examples.dim_size(0);
+    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
+                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
+                                        " vs. ", sampler_->num()));
+
+    // Gradient accumulator for v_in.
+    Tensor buf(DT_FLOAT, TensorShape({dims}));
+    auto Tbuf = buf.flat<float>();
+
+    // Scalar buffer to hold sigmoid(+/- dot).
+    Tensor g_buf(DT_FLOAT, TensorShape({}));
+    auto g = g_buf.scalar<float>();
+
+    // The following loop needs 2 random 32-bit values per negative
+    // sample.  We reserve 8 values per sample just in case the
+    // underlying implementation changes.
+    auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
+    random::SimplePhilox srnd(&rnd);
+
+    for (int64 i = 0; i < batch_size; ++i) {
+      const int32 example = Texamples(i);
+      DCHECK(0 <= example && example < vocab_size) << example;
+      const int32 label = Tlabels(i);
+      DCHECK(0 <= label && label < vocab_size) << label;
+      auto v_in = Tw_in.chip<0>(example);
+
+      // Positive: example predicts label.
+      //   forward: x = v_in' * v_out
+      //            l = log(sigmoid(x))
+      //   backward: dl/dx = g = sigmoid(-x)
+      //             dl/d(v_in) = g * v_out'
+      //             dl/d(v_out) = v_in' * g
+      {
+        auto v_out = Tw_out.chip<0>(label);
+        auto dot = (v_in * v_out).sum();
+        g = (dot.exp() + 1.f).inverse();
+        Tbuf = v_out * (g() * lr);
+        v_out += v_in * (g() * lr);
+      }
+
+      // Negative samples:
+      //   forward: x = v_in' * v_sample
+      //            l = log(sigmoid(-x))
+      //   backward: dl/dx = g = -sigmoid(x)
+      //             dl/d(v_in) = g * v_out'
+      //             dl/d(v_out) = v_in' * g
+      for (int j = 0; j < num_samples_; ++j) {
+        const int sample = sampler_->Sample(&srnd);
+        if (sample == label) continue;  // Skip.
+        auto v_sample = Tw_out.chip<0>(sample);
+        auto dot = (v_in * v_sample).sum();
+        g = -((-dot).exp() + 1.f).inverse();
+        Tbuf += v_sample * (g() * lr);
+        v_sample += v_in * (g() * lr);
+      }
+
+      // Applies the gradient on v_in.
+      v_in += Tbuf;
+    }
+  }
+
+ private:
+  int32 num_samples_ = 0;
+  random::DistributionSampler* sampler_ = nullptr;
+  GuardedPhiloxRandom base_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);
+
+}  // end namespace tensorflow
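
A note on the subsampling branch in NextExample (the "See Eq. 5" comment): the expression appears to be the keep-probability variant used by the reference word2vec C implementation rather than a literal transcription of Eq. 5 in the paper. Reading it off the code, with f the corpus count of the current word (word_freq), N the corpus size (corpus_size_), and t the subsample attribute, the word is kept with probability

$$ P_{\text{keep}} = \left( \sqrt{\frac{f}{tN}} + 1 \right) \cdot \frac{tN}{f}. $$

Very frequent words (f much larger than tN) are kept with probability about sqrt(tN / f); for rare words (f <= tN) the expression exceeds 1 and the word is always kept, since RandFloat() draws from [0, 1).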
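
The NegTrainOp constructor raises each vocab_count entry to the 0.75 power before handing the weights to random::DistributionSampler, so a negative sample w is drawn from the smoothed unigram distribution of the paper:

$$ P(w) = \frac{\operatorname{count}(w)^{3/4}}{\sum_{w'} \operatorname{count}(w')^{3/4}}. $$

The weights are passed unnormalized; the normalization above is implicit in the sampler.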
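
The forward/backward comments in the training loop compress the following derivation. For one (example, label) pair with k negative samples s_1, ..., s_k, the objective being ascended is

$$ \ell = \log \sigma(v_{\text{in}}^{\top} v_{\text{out}}) + \sum_{j=1}^{k} \log \sigma(-v_{\text{in}}^{\top} v_{s_j}), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}. $$

Since d/dx log sigma(x) = sigma(-x), the positive term yields g = sigma(-x) = 1 / (e^x + 1), which is exactly g = (dot.exp() + 1.f).inverse(); each negative term yields g = -sigma(x) = -1 / (e^{-x} + 1), matching g = -((-dot).exp() + 1.f).inverse(). Output-side vectors are updated in place with v += lr * g * v_in, while the input-side gradient is accumulated in Tbuf across the positive and negative terms and applied to v_in once per example.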