aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/embedding/word2vec_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/embedding/word2vec_ops.cc')
-rw-r--r--tensorflow/models/embedding/word2vec_ops.cc56
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/models/embedding/word2vec_ops.cc b/tensorflow/models/embedding/word2vec_ops.cc
new file mode 100644
index 0000000000..abe03baaf4
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec_ops.cc
@@ -0,0 +1,56 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Skipgram")
+ .Output("vocab_word: string")
+ .Output("vocab_freq: int32")
+ .Output("words_per_epoch: int64")
+ .Output("current_epoch: int32")
+ .Output("total_words_processed: int64")
+ .Output("examples: int32")
+ .Output("labels: int32")
+ .Attr("filename: string")
+ .Attr("batch_size: int")
+ .Attr("window_size: int = 5")
+ .Attr("min_count: int = 5")
+ .Attr("subsample: float = 1e-3")
+ .Doc(R"doc(
+Parses a text file and creates a batch of examples.
+
+vocab_word: A vector of words in the corpus.
+vocab_freq: Frequencies of words. Sorted in the non-ascending order.
+words_per_epoch: Number of words per epoch in the data file.
+current_epoch: The current epoch number.
+total_words_processed: The total number of words processed so far.
+examples: A vector of word ids.
+labels: A vector of word ids.
+filename: The corpus's text file name.
+batch_size: The size of produced batch.
+window_size: The number of words to predict to the left and right of the target.
+min_count: The minimum number of word occurrences for it to be included in the
+ vocabulary.
+subsample: Threshold for word occurrence. Words that appear with higher
+ frequency will be randomly down-sampled. Set to 0 to disable.
+)doc");
+
+REGISTER_OP("NegTrain")
+ .Input("w_in: Ref(float)")
+ .Input("w_out: Ref(float)")
+ .Input("examples: int32")
+ .Input("labels: int32")
+ .Input("lr: float")
+ .Attr("vocab_count: list(int)")
+ .Attr("num_negative_samples: int")
+ .Doc(R"doc(
+Training via negative sampling.
+
+w_in: input word embedding.
+w_out: output word embedding.
+examples: A vector of word ids.
+labels: A vector of word ids.
+vocab_count: Count of words in the vocabulary.
+num_negative_samples: Number of negative samples per exaple.
+)doc");
+
+} // end namespace tensorflow