diff options
Diffstat (limited to 'tensorflow/models/embedding/word2vec_ops.cc')
-rw-r--r-- | tensorflow/models/embedding/word2vec_ops.cc | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/models/embedding/word2vec_ops.cc b/tensorflow/models/embedding/word2vec_ops.cc new file mode 100644 index 0000000000..abe03baaf4 --- /dev/null +++ b/tensorflow/models/embedding/word2vec_ops.cc @@ -0,0 +1,56 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("Skipgram") + .Output("vocab_word: string") + .Output("vocab_freq: int32") + .Output("words_per_epoch: int64") + .Output("current_epoch: int32") + .Output("total_words_processed: int64") + .Output("examples: int32") + .Output("labels: int32") + .Attr("filename: string") + .Attr("batch_size: int") + .Attr("window_size: int = 5") + .Attr("min_count: int = 5") + .Attr("subsample: float = 1e-3") + .Doc(R"doc( +Parses a text file and creates a batch of examples. + +vocab_word: A vector of words in the corpus. +vocab_freq: Frequencies of words. Sorted in the non-ascending order. +words_per_epoch: Number of words per epoch in the data file. +current_epoch: The current epoch number. +total_words_processed: The total number of words processed so far. +examples: A vector of word ids. +labels: A vector of word ids. +filename: The corpus's text file name. +batch_size: The size of produced batch. +window_size: The number of words to predict to the left and right of the target. +min_count: The minimum number of word occurrences for it to be included in the + vocabulary. +subsample: Threshold for word occurrence. Words that appear with higher + frequency will be randomly down-sampled. Set to 0 to disable. +)doc"); + +REGISTER_OP("NegTrain") + .Input("w_in: Ref(float)") + .Input("w_out: Ref(float)") + .Input("examples: int32") + .Input("labels: int32") + .Input("lr: float") + .Attr("vocab_count: list(int)") + .Attr("num_negative_samples: int") + .Doc(R"doc( +Training via negative sampling. + +w_in: input word embedding. +w_out: output word embedding. +examples: A vector of word ids. +labels: A vector of word ids. +vocab_count: Count of words in the vocabulary. +num_negative_samples: Number of negative samples per exaple. +)doc"); + +} // end namespace tensorflow |