#include "tensorflow/core/framework/op.h"
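
// Op definitions for the word2vec example: a Skipgram op that produces
// (example, label) batches from a text corpus, and a NegTrain op that
// performs training via negative sampling.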
namespace tensorflow {

REGISTER_OP("Skipgram")
    .Output("vocab_word: string")
    .Output("vocab_freq: int32")
    .Output("words_per_epoch: int64")
    .Output("current_epoch: int32")
    .Output("total_words_processed: int64")
    .Output("examples: int32")
    .Output("labels: int32")
    .Attr("filename: string")
    .Attr("batch_size: int")
    .Attr("window_size: int = 5")
    .Attr("min_count: int = 5")
    .Attr("subsample: float = 1e-3")
    .Doc(R"doc(
Parses a text file and creates a batch of examples.

vocab_word: A vector of words in the corpus.
vocab_freq: Frequencies of words. Sorted in non-ascending order.
words_per_epoch: Number of words per epoch in the data file.
current_epoch: The current epoch number.
total_words_processed: The total number of words processed so far.
examples: A vector of word ids.
labels: A vector of word ids.
filename: The corpus's text file name.
batch_size: The size of the produced batch.
window_size: The number of words to predict to the left and right of the target.
min_count: The minimum number of occurrences a word must have to be included in
  the vocabulary.
subsample: Threshold for word occurrence. Words that appear with higher
  frequency will be randomly down-sampled. Set to 0 to disable.
)doc");

REGISTER_OP("NegTrain")
.Input("w_in: Ref(float)")
.Input("w_out: Ref(float)")
.Input("examples: int32")
.Input("labels: int32")
.Input("lr: float")
.Attr("vocab_count: list(int)")
.Attr("num_negative_samples: int")
.Doc(R"doc(
Training via negative sampling.

w_in: The input word embeddings.
w_out: The output word embeddings.
examples: A vector of word ids.
labels: A vector of word ids.
lr: The learning rate.
vocab_count: Count of each word in the vocabulary.
num_negative_samples: Number of negative samples per example.
)doc");

} // end namespace tensorflow