aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/rnn/translate/data_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/rnn/translate/data_utils.py')
-rw-r--r--tensorflow/models/rnn/translate/data_utils.py264
1 files changed, 264 insertions, 0 deletions
diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py
new file mode 100644
index 0000000000..28bc54354c
--- /dev/null
+++ b/tensorflow/models/rnn/translate/data_utils.py
@@ -0,0 +1,264 @@
+"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
+
+import gzip
+import os
+import re
+import tarfile
+import urllib
+
+from tensorflow.python.platform import gfile
+
+# Special vocabulary symbols - we always put them at the start.
+_PAD = "_PAD"
+_GO = "_GO"
+_EOS = "_EOS"
+_UNK = "_UNK"
+_START_VOCAB = [_PAD, _GO, _EOS, _UNK]
+
+PAD_ID = 0
+GO_ID = 1
+EOS_ID = 2
+UNK_ID = 3
+
+# Regular expressions used to tokenize.
+_WORD_SPLIT = re.compile("([.,!?\"':;)(])")
+_DIGIT_RE = re.compile(r"\d")
+
+# URLs for WMT data.
+_WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
+_WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
+
+
+def maybe_download(directory, filename, url):
+ """Download filename from url unless it's already in directory."""
+ if not os.path.exists(directory):
+ print "Creating directory %s" % directory
+ os.mkdir(directory)
+ filepath = os.path.join(directory, filename)
+ if not os.path.exists(filepath):
+ print "Downloading %s to %s" % (url, filepath)
+ filepath, _ = urllib.urlretrieve(url, filepath)
+ statinfo = os.stat(filepath)
+ print "Succesfully downloaded", filename, statinfo.st_size, "bytes"
+ return filepath
+
+
+def gunzip_file(gz_path, new_path):
+ """Unzips from gz_path into new_path."""
+ print "Unpacking %s to %s" % (gz_path, new_path)
+ with gzip.open(gz_path, "rb") as gz_file:
+ with open(new_path, "w") as new_file:
+ for line in gz_file:
+ new_file.write(line)
+
+
+def get_wmt_enfr_train_set(directory):
+ """Download the WMT en-fr training corpus to directory unless it's there."""
+ train_path = os.path.join(directory, "giga-fren.release2")
+ if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")):
+ corpus_file = maybe_download(directory, "training-giga-fren.tar",
+ _WMT_ENFR_TRAIN_URL)
+ print "Extracting tar file %s" % corpus_file
+ with tarfile.open(corpus_file, "r") as corpus_tar:
+ corpus_tar.extractall(directory)
+ gunzip_file(train_path + ".fr.gz", train_path + ".fr")
+ gunzip_file(train_path + ".en.gz", train_path + ".en")
+ return train_path
+
+
+def get_wmt_enfr_dev_set(directory):
+ """Download the WMT en-fr training corpus to directory unless it's there."""
+ dev_name = "newstest2013"
+ dev_path = os.path.join(directory, dev_name)
+ if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
+ dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
+ print "Extracting tgz file %s" % dev_file
+ with tarfile.open(dev_file, "r:gz") as dev_tar:
+ fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
+ en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
+ fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix.
+ en_dev_file.name = dev_name + ".en"
+ dev_tar.extract(fr_dev_file, directory)
+ dev_tar.extract(en_dev_file, directory)
+ return dev_path
+
+
+def basic_tokenizer(sentence):
+ """Very basic tokenizer: split the sentence into a list of tokens."""
+ words = []
+ for space_separated_fragment in sentence.strip().split():
+ words.extend(re.split(_WORD_SPLIT, space_separated_fragment))
+ return [w for w in words if w]
+
+
+def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
+ tokenizer=None, normalize_digits=True):
+ """Create vocabulary file (if it does not exist yet) from data file.
+
+ Data file is assumed to contain one sentence per line. Each sentence is
+ tokenized and digits are normalized (if normalize_digits is set).
+ Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
+ We write it to vocabulary_path in a one-token-per-line format, so that later
+ token in the first line gets id=0, second line gets id=1, and so on.
+
+ Args:
+ vocabulary_path: path where the vocabulary will be created.
+ data_path: data file that will be used to create vocabulary.
+ max_vocabulary_size: limit on the size of the created vocabulary.
+ tokenizer: a function to use to tokenize each data sentence;
+ if None, basic_tokenizer will be used.
+ normalize_digits: Boolean; if true, all digits are replaced by 0s.
+ """
+ if not gfile.Exists(vocabulary_path):
+ print "Creating vocabulary %s from data %s" % (vocabulary_path, data_path)
+ vocab = {}
+ with gfile.GFile(data_path, mode="r") as f:
+ counter = 0
+ for line in f:
+ counter += 1
+ if counter % 100000 == 0: print " processing line %d" % counter
+ tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
+ for w in tokens:
+ word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w
+ if word in vocab:
+ vocab[word] += 1
+ else:
+ vocab[word] = 1
+ vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
+ if len(vocab_list) > max_vocabulary_size:
+ vocab_list = vocab_list[:max_vocabulary_size]
+ with gfile.GFile(vocabulary_path, mode="w") as vocab_file:
+ for w in vocab_list:
+ vocab_file.write(w + "\n")
+
+
+def initialize_vocabulary(vocabulary_path):
+ """Initialize vocabulary from file.
+
+ We assume the vocabulary is stored one-item-per-line, so a file:
+ dog
+ cat
+ will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
+ also return the reversed-vocabulary ["dog", "cat"].
+
+ Args:
+ vocabulary_path: path to the file containing the vocabulary.
+
+ Returns:
+ a pair: the vocabulary (a dictionary mapping string to integers), and
+ the reversed vocabulary (a list, which reverses the vocabulary mapping).
+
+ Raises:
+ ValueError: if the provided vocabulary_path does not exist.
+ """
+ if gfile.Exists(vocabulary_path):
+ rev_vocab = []
+ with gfile.GFile(vocabulary_path, mode="r") as f:
+ rev_vocab.extend(f.readlines())
+ rev_vocab = [line.strip() for line in rev_vocab]
+ vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
+ return vocab, rev_vocab
+ else:
+ raise ValueError("Vocabulary file %s not found.", vocabulary_path)
+
+
+def sentence_to_token_ids(sentence, vocabulary,
+ tokenizer=None, normalize_digits=True):
+ """Convert a string to list of integers representing token-ids.
+
+ For example, a sentence "I have a dog" may become tokenized into
+ ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
+ "a": 4, "dog": 7"} this function will return [1, 2, 4, 7].
+
+ Args:
+ sentence: a string, the sentence to convert to token-ids.
+ vocabulary: a dictionary mapping tokens to integers.
+ tokenizer: a function to use to tokenize each sentence;
+ if None, basic_tokenizer will be used.
+ normalize_digits: Boolean; if true, all digits are replaced by 0s.
+
+ Returns:
+ a list of integers, the token-ids for the sentence.
+ """
+ if tokenizer:
+ words = tokenizer(sentence)
+ else:
+ words = basic_tokenizer(sentence)
+ if not normalize_digits:
+ return [vocabulary.get(w, UNK_ID) for w in words]
+ # Normalize digits by 0 before looking words up in the vocabulary.
+ return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words]
+
+
+def data_to_token_ids(data_path, target_path, vocabulary_path,
+ tokenizer=None, normalize_digits=True):
+ """Tokenize data file and turn into token-ids using given vocabulary file.
+
+ This function loads data line-by-line from data_path, calls the above
+ sentence_to_token_ids, and saves the result to target_path. See comment
+ for sentence_to_token_ids on the details of token-ids format.
+
+ Args:
+ data_path: path to the data file in one-sentence-per-line format.
+ target_path: path where the file with token-ids will be created.
+ vocabulary_path: path to the vocabulary file.
+ tokenizer: a function to use to tokenize each sentence;
+ if None, basic_tokenizer will be used.
+ normalize_digits: Boolean; if true, all digits are replaced by 0s.
+ """
+ if not gfile.Exists(target_path):
+ print "Tokenizing data in %s" % data_path
+ vocab, _ = initialize_vocabulary(vocabulary_path)
+ with gfile.GFile(data_path, mode="r") as data_file:
+ with gfile.GFile(target_path, mode="w") as tokens_file:
+ counter = 0
+ for line in data_file:
+ counter += 1
+ if counter % 100000 == 0: print " tokenizing line %d" % counter
+ token_ids = sentence_to_token_ids(line, vocab, tokenizer,
+ normalize_digits)
+ tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
+
+
+def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size):
+ """Get WMT data into data_dir, create vocabularies and tokenize data.
+
+ Args:
+ data_dir: directory in which the data sets will be stored.
+ en_vocabulary_size: size of the English vocabulary to create and use.
+ fr_vocabulary_size: size of the French vocabulary to create and use.
+
+ Returns:
+ A tuple of 6 elements:
+ (1) path to the token-ids for English training data-set,
+ (2) path to the token-ids for French training data-set,
+ (3) path to the token-ids for English development data-set,
+ (4) path to the token-ids for French development data-set,
+ (5) path to the English vocabulary file,
+ (6) path to the French vocabluary file.
+ """
+ # Get wmt data to the specified directory.
+ train_path = get_wmt_enfr_train_set(data_dir)
+ dev_path = get_wmt_enfr_dev_set(data_dir)
+
+ # Create vocabularies of the appropriate sizes.
+ fr_vocab_path = os.path.join(data_dir, "vocab%d.fr" % fr_vocabulary_size)
+ en_vocab_path = os.path.join(data_dir, "vocab%d.en" % en_vocabulary_size)
+ create_vocabulary(fr_vocab_path, train_path + ".fr", fr_vocabulary_size)
+ create_vocabulary(en_vocab_path, train_path + ".en", en_vocabulary_size)
+
+ # Create token ids for the training data.
+ fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
+ en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
+ data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path)
+ data_to_token_ids(train_path + ".en", fr_train_ids_path, fr_vocab_path)
+
+ # Create token ids for the development data.
+ fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)
+ en_dev_ids_path = dev_path + (".ids%d.en" % en_vocabulary_size)
+ data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, fr_vocab_path)
+ data_to_token_ids(dev_path + ".en", en_dev_ids_path, en_vocab_path)
+
+ return (en_train_ids_path, fr_train_ids_path,
+ en_dev_ids_path, fr_dev_ids_path,
+ en_vocab_path, fr_vocab_path)