Diffstat (limited to 'tensorflow/models/embedding/word2vec_kernels.cc')
-rw-r--r--  tensorflow/models/embedding/word2vec_kernels.cc | 287
1 file changed, 287 insertions(+), 0 deletions(-)
diff --git a/tensorflow/models/embedding/word2vec_kernels.cc b/tensorflow/models/embedding/word2vec_kernels.cc
new file mode 100644
index 0000000000..f68139fc91
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec_kernels.cc
@@ -0,0 +1,287 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+class SkipgramOp : public OpKernel {
+ public:
+ explicit SkipgramOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), rng_(&philox_) {
+ string filename;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
+ OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));
+
+ mutex_lock l(mu_);
+ example_pos_ = corpus_size_;
+ label_pos_ = corpus_size_;
+ label_limit_ = corpus_size_;
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor words_per_epoch(DT_INT64, TensorShape({}));
+ Tensor current_epoch(DT_INT32, TensorShape({}));
+ Tensor total_words_processed(DT_INT64, TensorShape({}));
+ Tensor examples(DT_INT32, TensorShape({batch_size_}));
+ auto Texamples = examples.flat<int32>();
+ Tensor labels(DT_INT32, TensorShape({batch_size_}));
+ auto Tlabels = labels.flat<int32>();
+ {
+ mutex_lock l(mu_);
+ for (int i = 0; i < batch_size_; ++i) {
+ NextExample(&Texamples(i), &Tlabels(i));
+ }
+ words_per_epoch.scalar<int64>()() = corpus_size_;
+ current_epoch.scalar<int32>()() = current_epoch_;
+ total_words_processed.scalar<int64>()() = total_words_processed_;
+ }
+ ctx->set_output(0, word_);
+ ctx->set_output(1, freq_);
+ ctx->set_output(2, words_per_epoch);
+ ctx->set_output(3, current_epoch);
+ ctx->set_output(4, total_words_processed);
+ ctx->set_output(5, examples);
+ ctx->set_output(6, labels);
+ }
+
+ private:
+ int32 batch_size_ = 0;
+ int32 window_size_ = 5;
+ float subsample_ = 1e-3;
+ int min_count_ = 5;
+ int32 vocab_size_ = 0;
+ Tensor word_;
+ Tensor freq_;
+ int32 corpus_size_ = 0;
+ std::vector<int32> corpus_;
+
+ mutex mu_;
+ random::PhiloxRandom philox_ GUARDED_BY(mu_);
+ random::SimplePhilox rng_ GUARDED_BY(mu_);
+ int32 current_epoch_ GUARDED_BY(mu_) = -1;
+ int64 total_words_processed_ GUARDED_BY(mu_) = 0;
+ int32 example_pos_ GUARDED_BY(mu_);
+ int32 label_pos_ GUARDED_BY(mu_);
+ int32 label_limit_ GUARDED_BY(mu_);
+
+ // {example_pos_, label_pos_} is the cursor for the next example.
+ // example_pos_ wraps around at the end of corpus_. For each
+ // example, we randomly generate labels from [label_pos_,
+ // label_limit_).
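+ //
+ // For example, with example_pos_ = 10 and a sampled skip of 2, labels
+ // are drawn from corpus_ positions [8, 13), skipping position 10
+ // itself.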
+ void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ while (true) {
+ if (label_pos_ >= label_limit_) {
+ if (example_pos_ + 1 >= corpus_size_) {
+ ++current_epoch_;
+ example_pos_ = 0;
+ } else {
+ ++example_pos_;
+ }
+ ++total_words_processed_;
+ int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
+ if (subsample_ > 0) {
+ // See Eq. 5 in http://arxiv.org/abs/1310.4546
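+ // E.g., with subsample_ = 1e-3 and corpus_size_ = 1e6, a word seen
+ // 1e5 times gets keep_prob ~= 0.11, while a word seen 1e3 times gets
+ // keep_prob = 2, i.e. it is always kept.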
+ float keep_prob =
+ (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
+ (subsample_ * corpus_size_) / word_freq;
+ if (rng_.RandFloat() > keep_prob) continue;
+ }
+ const int32 skip = 1 + rng_.Uniform(window_size_);
+ label_pos_ = std::max<int32>(0, example_pos_ - skip);
+ label_limit_ = std::min<int32>(corpus_size_, example_pos_ + skip + 1);
+ }
+ if (example_pos_ != label_pos_) {
+ break;
+ }
+ ++label_pos_;
+ }
+ *example = corpus_[example_pos_];
+ *label = corpus_[label_pos_++];
+ }
+
+ Status Init(Env* env, const string& filename) {
+ string data;
+ TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
+ RE2 kWord("\\s*(\\S+)");
+ auto input = ToRegexpStringPiece(data);
+ string w;
+ corpus_size_ = 0;
+ std::unordered_map<string, int32> word_freq;
+ while (RE2::Consume(&input, kWord, &w)) {
+ ++(word_freq[w]);
+ ++corpus_size_;
+ }
+ if (corpus_size_ < window_size_ * 10) {
+ return errors::InvalidArgument("The text file ", filename,
+ " contains too little data: ",
+ corpus_size_, " words");
+ }
+ typedef std::pair<string, int32> WordFreq;
+ std::vector<WordFreq> ordered;
+ for (const auto& p : word_freq) {
+ if (p.second >= min_count_) ordered.push_back(p);
+ }
+ LOG(INFO) << "Data file: " << filename << " contains " << data.size()
+ << " bytes, " << corpus_size_ << " words, " << word_freq.size()
+ << " unique words, " << ordered.size()
+ << " unique frequent words.";
+ word_freq.clear();
+ std::sort(ordered.begin(), ordered.end(),
+ [](const WordFreq& x, const WordFreq& y) {
+ return x.second > y.second;
+ });
+ vocab_size_ = static_cast<int32>(1 + ordered.size());
+ Tensor word(DT_STRING, TensorShape({vocab_size_}));
+ Tensor freq(DT_INT32, TensorShape({vocab_size_}));
+ word.flat<string>()(0) = "UNK";
+ static const int32 kUnkId = 0;
+ std::unordered_map<string, int32> word_id;
+ int64 total_counted = 0;
+ for (std::size_t i = 0; i < ordered.size(); ++i) {
+ const auto& w = ordered[i].first;
+ auto id = i + 1;
+ word.flat<string>()(id) = w;
+ auto word_count = ordered[i].second;
+ freq.flat<int32>()(id) = word_count;
+ total_counted += word_count;
+ word_id[w] = id;
+ }
+ freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
+ word_ = word;
+ freq_ = freq;
+ corpus_.reserve(corpus_size_);
+ input = ToRegexpStringPiece(data);
+ while (RE2::Consume(&input, kWord, &w)) {
+ corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
+ }
+ return Status::OK();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);
+
+class NegTrainOp : public OpKernel {
+ public:
+ explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ base_.Init(0, 0);
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));
+
+ std::vector<int32> vocab_count;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));
+
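+ // Negative samples are drawn in proportion to count^0.75, the
+ // smoothed unigram distribution used in the word2vec paper.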
+ std::vector<float> vocab_weights;
+ vocab_weights.reserve(vocab_count.size());
+ for (const auto& f : vocab_count) {
+ float r = std::pow(static_cast<float>(f), 0.75f);
+ vocab_weights.push_back(r);
+ }
+ sampler_ = new random::DistributionSampler(vocab_weights);
+ }
+
+ ~NegTrainOp() { delete sampler_; }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor w_in = ctx->mutable_input(0, false);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
+ errors::InvalidArgument("Must be a matrix"));
+ Tensor w_out = ctx->mutable_input(1, false);
+ OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
+ errors::InvalidArgument("w_in.shape == w_out.shape"));
+ const Tensor& examples = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
+ errors::InvalidArgument("Must be a vector"));
+ const Tensor& labels = ctx->input(3);
+ OP_REQUIRES(ctx, examples.shape() == labels.shape(),
+ errors::InvalidArgument("examples.shape == labels.shape"));
+ const Tensor& learning_rate = ctx->input(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
+ errors::InvalidArgument("Must be a scalar"));
+
+ auto Tw_in = w_in.matrix<float>();
+ auto Tw_out = w_out.matrix<float>();
+ auto Texamples = examples.flat<int32>();
+ auto Tlabels = labels.flat<int32>();
+ auto lr = learning_rate.scalar<float>()();
+ const int64 vocab_size = w_in.dim_size(0);
+ const int64 dims = w_in.dim_size(1);
+ const int64 batch_size = examples.dim_size(0);
+ OP_REQUIRES(ctx, vocab_size == sampler_->num(),
+ errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
+ " vs. ", sampler_->num()));
+
+ // Gradient accumulator for v_in.
+ Tensor buf(DT_FLOAT, TensorShape({dims}));
+ auto Tbuf = buf.flat<float>();
+
+ // Scalar buffer to hold sigmoid(+/- dot).
+ Tensor g_buf(DT_FLOAT, TensorShape({}));
+ auto g = g_buf.scalar<float>();
+
+ // The following loop needs 2 random 32-bit values per negative
+ // sample. We reserve 8 values per sample just in case the
+ // underlying implementation changes.
+ auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
+ random::SimplePhilox srnd(&rnd);
+
+ for (int64 i = 0; i < batch_size; ++i) {
+ const int32 example = Texamples(i);
+ DCHECK(0 <= example && example < vocab_size) << example;
+ const int32 label = Tlabels(i);
+ DCHECK(0 <= label && label < vocab_size) << label;
+ auto v_in = Tw_in.chip<0>(example);
+
+ // Positive: example predicts label.
+ // forward: x = v_in' * v_out
+ // l = log(sigmoid(x))
+ // backward: dl/dx = g = sigmoid(-x)
+ // dl/d(v_in) = g * v_out'
+ // dl/d(v_out) = v_in' * g
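+ //
+ // Note sigmoid(-x) = 1 / (1 + exp(x)), which is what
+ // (dot.exp() + 1.f).inverse() below evaluates.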
+ {
+ auto v_out = Tw_out.chip<0>(label);
+ auto dot = (v_in * v_out).sum();
+ g = (dot.exp() + 1.f).inverse();
+ Tbuf = v_out * (g() * lr);
+ v_out += v_in * (g() * lr);
+ }
+
+ // Negative samples:
+ // forward: x = v_in' * v_sample
+ // l = log(sigmoid(-x))
+ // backward: dl/dx = g = -sigmoid(x)
+ // dl/d(v_in) = g * v_sample'
+ // dl/d(v_sample) = v_in' * g
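+ //
+ // Here -sigmoid(x) = -1 / (1 + exp(-x)), which is what
+ // -((-dot).exp() + 1.f).inverse() below evaluates.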
+ for (int j = 0; j < num_samples_; ++j) {
+ const int sample = sampler_->Sample(&srnd);
+ if (sample == label) continue; // Skip.
+ auto v_sample = Tw_out.chip<0>(sample);
+ auto dot = (v_in * v_sample).sum();
+ g = -((-dot).exp() + 1.f).inverse();
+ Tbuf += v_sample * (g() * lr);
+ v_sample += v_in * (g() * lr);
+ }
+
+ // Applies the gradient on v_in.
+ v_in += Tbuf;
+ }
+ }
+
+ private:
+ int32 num_samples_ = 0;
+ random::DistributionSampler* sampler_ = nullptr;
+ GuardedPhiloxRandom base_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);
+
+} // end namespace tensorflow