#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/util/guarded_philox_random.h" namespace tensorflow { class SkipgramOp : public OpKernel { public: explicit SkipgramOp(OpKernelConstruction* ctx) : OpKernel(ctx), rng_(&philox_) { string filename; OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename)); OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_)); OP_REQUIRES_OK(ctx, Init(ctx->env(), filename)); mutex_lock l(mu_); example_pos_ = corpus_size_; label_pos_ = corpus_size_; label_limit_ = corpus_size_; } void Compute(OpKernelContext* ctx) override { Tensor words_per_epoch(DT_INT64, TensorShape({})); Tensor current_epoch(DT_INT32, TensorShape({})); Tensor total_words_processed(DT_INT64, TensorShape({})); Tensor examples(DT_INT32, TensorShape({batch_size_})); auto Texamples = examples.flat(); Tensor labels(DT_INT32, TensorShape({batch_size_})); auto Tlabels = labels.flat(); { mutex_lock l(mu_); for (int i = 0; i < batch_size_; ++i) { NextExample(&Texamples(i), &Tlabels(i)); } words_per_epoch.scalar()() = corpus_size_; current_epoch.scalar()() = current_epoch_; total_words_processed.scalar()() = total_words_processed_; } ctx->set_output(0, word_); ctx->set_output(1, freq_); ctx->set_output(2, words_per_epoch); ctx->set_output(3, current_epoch); ctx->set_output(4, total_words_processed); ctx->set_output(5, examples); ctx->set_output(6, labels); } private: int32 batch_size_ = 0; int32 window_size_ = 5; float subsample_ = 1e-3; int min_count_ = 5; int32 vocab_size_ = 0; Tensor word_; Tensor freq_; int32 corpus_size_ = 0; std::vector corpus_; mutex mu_; random::PhiloxRandom philox_ GUARDED_BY(mu_); random::SimplePhilox rng_ GUARDED_BY(mu_); int32 current_epoch_ GUARDED_BY(mu_) = -1; int64 total_words_processed_ GUARDED_BY(mu_) = 0; int32 example_pos_ GUARDED_BY(mu_); int32 label_pos_ GUARDED_BY(mu_); int32 label_limit_ GUARDED_BY(mu_); // {example_pos_, label_pos_} is the cursor for the next example. // example_pos_ wrapps around at the end of corpus_. For each // example, we randomly generate [label_pos_, label_limit) for // labels. void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) { while (true) { if (label_pos_ >= label_limit_) { if (example_pos_ + 1 >= corpus_size_) { ++current_epoch_; example_pos_ = 0; } else { ++example_pos_; } ++total_words_processed_; int32 word_freq = freq_.flat()(corpus_[example_pos_]); if (subsample_ > 0) { // See Eq. 
        if (subsample_ > 0) {
          // See Eq. 5 in http://arxiv.org/abs/1310.4546
          float keep_prob =
              (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
              (subsample_ * corpus_size_) / word_freq;
          if (rng_.RandFloat() > keep_prob) continue;
        }
        const int32 skip = 1 + rng_.Uniform(window_size_);
        label_pos_ = std::max(0, example_pos_ - skip);
        label_limit_ = std::min(corpus_size_, example_pos_ + skip + 1);
      }
      if (example_pos_ != label_pos_) {
        break;
      }
      ++label_pos_;
    }
    *example = corpus_[example_pos_];
    *label = corpus_[label_pos_++];
  }

  Status Init(Env* env, const string& filename) {
    string data;
    TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
    RE2 kWord("\\s*(\\S+)");
    auto input = ToRegexpStringPiece(data);
    string w;
    corpus_size_ = 0;
    std::unordered_map<string, int32> word_freq;
    while (RE2::Consume(&input, kWord, &w)) {
      ++(word_freq[w]);
      ++corpus_size_;
    }
    if (corpus_size_ < window_size_ * 10) {
      return errors::InvalidArgument("The text file ", filename,
                                     " contains too little data: ",
                                     corpus_size_, " words");
    }
    typedef std::pair<string, int32> WordFreq;
    std::vector<WordFreq> ordered;
    for (const auto& p : word_freq) {
      if (p.second >= min_count_) ordered.push_back(p);
    }
    LOG(INFO) << "Data file: " << filename << " contains " << data.size()
              << " bytes, " << corpus_size_ << " words, " << word_freq.size()
              << " unique words, " << ordered.size()
              << " unique frequent words.";
    word_freq.clear();
    std::sort(ordered.begin(), ordered.end(),
              [](const WordFreq& x, const WordFreq& y) {
                return x.second > y.second;
              });
    vocab_size_ = static_cast<int32>(1 + ordered.size());
    Tensor word(DT_STRING, TensorShape({vocab_size_}));
    Tensor freq(DT_INT32, TensorShape({vocab_size_}));
    word.flat<string>()(0) = "UNK";
    static const int32 kUnkId = 0;
    std::unordered_map<string, int32> word_id;
    int64 total_counted = 0;
    for (std::size_t i = 0; i < ordered.size(); ++i) {
      const auto& w = ordered[i].first;
      auto id = i + 1;
      word.flat<string>()(id) = w;
      auto word_count = ordered[i].second;
      freq.flat<int32>()(id) = word_count;
      total_counted += word_count;
      word_id[w] = id;
    }
    freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
    word_ = word;
    freq_ = freq;
    corpus_.reserve(corpus_size_);
    input = ToRegexpStringPiece(data);
    while (RE2::Consume(&input, kWord, &w)) {
      corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
    }
    return Status::OK();
  }
};

REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);
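// NegTrainOp applies one SGD update of the skip-gram negative-sampling
// objective from the paper referenced above (http://arxiv.org/abs/1310.4546):
//
//   l = log sigmoid(v_in . v_out)
//       + sum_{j=1..num_negative_samples} log sigmoid(-v_in . v_sample_j)
//
// where v_out is the output vector of the true context word and each
// v_sample_j is the output vector of a word drawn from a noise distribution
// proportional to count(w)^0.75 (built in the constructor below). The
// per-term gradients are spelled out in the comments inside Compute().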
class NegTrainOp : public OpKernel {
 public:
  explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    base_.Init(0, 0);

    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));

    std::vector<int32> vocab_count;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));

    std::vector<float> vocab_weights;
    vocab_weights.reserve(vocab_count.size());
    for (const auto& f : vocab_count) {
      float r = std::pow(static_cast<float>(f), 0.75f);
      vocab_weights.push_back(r);
    }
    sampler_ = new random::DistributionSampler(vocab_weights);
  }

  ~NegTrainOp() { delete sampler_; }

  void Compute(OpKernelContext* ctx) override {
    Tensor w_in = ctx->mutable_input(0, false);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
                errors::InvalidArgument("Must be a matrix"));
    Tensor w_out = ctx->mutable_input(1, false);
    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
                errors::InvalidArgument("w_in.shape == w_out.shape"));
    const Tensor& examples = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
                errors::InvalidArgument("Must be a vector"));
    const Tensor& labels = ctx->input(3);
    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
                errors::InvalidArgument("examples.shape == labels.shape"));
    const Tensor& learning_rate = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
                errors::InvalidArgument("Must be a scalar"));

    auto Tw_in = w_in.matrix<float>();
    auto Tw_out = w_out.matrix<float>();
    auto Texamples = examples.flat<int32>();
    auto Tlabels = labels.flat<int32>();
    auto lr = learning_rate.scalar<float>()();
    const int64 vocab_size = w_in.dim_size(0);
    const int64 dims = w_in.dim_size(1);
    const int64 batch_size = examples.dim_size(0);
    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
                                        " vs. ", sampler_->num()));

    // Gradient accumulator for v_in.
    Tensor buf(DT_FLOAT, TensorShape({dims}));
    auto Tbuf = buf.flat<float>();

    // Scalar buffer to hold sigmoid(+/- dot).
    Tensor g_buf(DT_FLOAT, TensorShape({}));
    auto g = g_buf.scalar<float>();

    // The following loop needs 2 random 32-bit values per negative sample.
    // We reserve 8 values per sample just in case the underlying
    // implementation changes.
    auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
    random::SimplePhilox srnd(&rnd);

    for (int64 i = 0; i < batch_size; ++i) {
      const int32 example = Texamples(i);
      DCHECK(0 <= example && example < vocab_size) << example;
      const int32 label = Tlabels(i);
      DCHECK(0 <= label && label < vocab_size) << label;
      auto v_in = Tw_in.chip<0>(example);

      // Positive: example predicts label.
      //   forward: x = v_in' * v_out
      //            l = log(sigmoid(x))
      //   backward: dl/dx = g = sigmoid(-x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      {
        auto v_out = Tw_out.chip<0>(label);
        auto dot = (v_in * v_out).sum();
        g = (dot.exp() + 1.f).inverse();
        Tbuf = v_out * (g() * lr);
        v_out += v_in * (g() * lr);
      }

      // Negative samples:
      //   forward: x = v_in' * v_sample
      //            l = log(sigmoid(-x))
      //   backward: dl/dx = g = -sigmoid(x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      for (int j = 0; j < num_samples_; ++j) {
        const int sample = sampler_->Sample(&srnd);
        if (sample == label) continue;  // Skip.
        auto v_sample = Tw_out.chip<0>(sample);
        auto dot = (v_in * v_sample).sum();
        g = -((-dot).exp() + 1.f).inverse();
        Tbuf += v_sample * (g() * lr);
        v_sample += v_in * (g() * lr);
      }

      // Applies the gradient on v_in.
      v_in += Tbuf;
    }
  }

 private:
  int32 num_samples_ = 0;
  random::DistributionSampler* sampler_ = nullptr;
  GuardedPhiloxRandom base_;
};

REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);

}  // end namespace tensorflow
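// The kernels above assume matching "Skipgram" and "NegTrain" op definitions
// registered in a separate ops file. The sketch below is reconstructed from
// the attrs, inputs, and outputs the kernels actually use; the input/output
// names are illustrative assumptions and the real registration may differ
// (e.g. it may carry defaults and .Doc() strings), so treat this as a guide,
// not as the canonical definition.
namespace tensorflow {

REGISTER_OP("Skipgram")
    .Output("vocab_word: string")
    .Output("vocab_freq: int32")
    .Output("words_per_epoch: int64")
    .Output("current_epoch: int32")
    .Output("total_words_processed: int64")
    .Output("examples: int32")
    .Output("labels: int32")
    .SetIsStateful()
    .Attr("filename: string")
    .Attr("batch_size: int")
    .Attr("window_size: int")
    .Attr("min_count: int")
    .Attr("subsample: float");

REGISTER_OP("NegTrain")
    .Input("w_in: Ref(float)")
    .Input("w_out: Ref(float)")
    .Input("examples: int32")
    .Input("labels: int32")
    .Input("lr: float")
    .SetIsStateful()
    .Attr("vocab_count: list(int)")
    .Attr("num_negative_samples: int");

}  // namespace tensorflow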