# encoding: utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib import learn
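# NOTE: This example targets the early tf.contrib.learn (formerly skflow) API
# and needs an older TensorFlow release that still ships learn.ops and
# TensorFlowEstimator.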
# Get training data
# This dataset can be downloaded from http://www.statmt.org/europarl/v6/fr-en.tgz
ENGLISH_CORPUS = "europarl-v6.fr-en.en"
FRENCH_CORPUS = "europarl-v6.fr-en.fr"
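# The two corpora are sentence-aligned: line i of the English file is the
# translation of line i of the French file, which is what lets izip pair
# them up below.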
def read_iterator(filename):
  # Single pass over the file, yielding one stripped line at a time.
  with open(filename) as f:
    for line in f:
      yield line.strip()


def repeated_read_iterator(filename):
  # Re-open and re-read the file forever, so training can run many epochs.
  while True:
    with open(filename) as f:
      for line in f:
        yield line.strip()


def split_train_test(data, partition=0.2, random_seed=42):
  # Tag each item with a partition id: 0 = train (~80%), 1 = test (~20%).
  rnd = np.random.RandomState(random_seed)
  for item in data:
    if rnd.uniform() > partition:
      yield (0, item)
    else:
      yield (1, item)


def save_partitions(data, filenames):
  # Write each (partition_id, item) pair to the file for that partition.
  files = [open(filename, 'w') for filename in filenames]
  for partition, item in data:
    files[partition].write(item + '\n')
  for f in files:
    f.close()


def loop_iterator(data):
  # Cycle over an in-memory iterable forever.
  while True:
    for item in data:
      yield item
if not (os.path.exists('train.data') and os.path.exists('test.data')):
  english_data = read_iterator(ENGLISH_CORPUS)
  french_data = read_iterator(FRENCH_CORPUS)
  # Join the aligned English/French sentences with a ';;;' separator.
  # itertools.izip is Python 2 only, matching the Python 2-era API used here;
  # on Python 3 the built-in zip would take its place.
  parallel_data = ('%s;;;%s' % (eng, fr)
                   for eng, fr in itertools.izip(english_data, french_data))
  save_partitions(split_train_test(parallel_data), ['train.data', 'test.data'])
def Xy(data):
  # Split each ';;;'-joined line back into parallel (English, French) streams.
  def split_lines(data):
    for item in data:
      yield item.split(';;;')
  X, y = itertools.tee(split_lines(data))
  return (item[0] for item in X), (item[1] for item in y)
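# The training stream repeats forever (fit() keeps pulling batches from it);
# the test stream is consumed in a single pass.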
X_train, y_train = Xy(repeated_read_iterator('train.data'))
X_test, y_test = Xy(read_iterator('test.data'))
# Translation model

MAX_DOCUMENT_LENGTH = 30
HIDDEN_SIZE = 100


def translate_model(X, y):
  # One-hot encode each byte id over the 256 possible byte values.
  byte_list = learn.ops.one_hot_matrix(X, 256)
  # Split into encoder inputs, decoder inputs and decoder targets.
  in_X, in_y, out_y = learn.ops.seq2seq_inputs(
      byte_list, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH)
  # GRU cell whose outputs are projected back to a 256-way byte distribution.
  cell = tf.nn.rnn_cell.OutputProjectionWrapper(
      tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
  decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, cell)
  return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
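# Text is handled at the byte level: ByteProcessor maps each string to a
# fixed-length sequence of byte ids in [0, 255] (padded or truncated to
# MAX_DOCUMENT_LENGTH), which is why the model above works over 256 classes.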
vocab_processor = learn.preprocessing.ByteProcessor(
    max_document_length=MAX_DOCUMENT_LENGTH)
x_iter = vocab_processor.transform(X_train)
y_iter = vocab_processor.transform(y_train)
xpred = np.array(list(vocab_processor.transform(X_test))[:20])
ygold = list(y_test)[:20]
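# Keep a small fixed sample of 20 test sentences so we can print example
# translations after each training pass below.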
PATH = '/tmp/tf_examples/ntm/'
if os.path.exists(PATH):
  # Resume from a previously saved model.
  translator = learn.TensorFlowEstimator.restore(PATH)
else:
  translator = learn.TensorFlowEstimator(
      model_fn=translate_model,
      n_classes=256,
      optimizer='Adam', learning_rate=0.01, batch_size=128,
      continue_training=True)
# Alternate forever between training, checkpointing and printing sample
# translations for the held-out sentences.
while True:
  translator.fit(x_iter, y_iter, logdir=PATH)
  translator.save(PATH)

  predictions = translator.predict(xpred, axis=2)
  xpred_inp = vocab_processor.reverse(xpred)
  text_outputs = vocab_processor.reverse(predictions)
  for inp_data, input_text, pred, output_text, gold in zip(
      xpred, xpred_inp, predictions, text_outputs, ygold):
    print('English: %s. French (pred): %s, French (gold): %s' %
          (input_text, output_text, gold.decode('utf-8')))
    print(inp_data, pred)