diff options
Diffstat (limited to 'tensorflow/models')
21 files changed, 3097 insertions, 0 deletions
diff --git a/tensorflow/models/embedding/BUILD b/tensorflow/models/embedding/BUILD new file mode 100644 index 0000000000..0fb164b05e --- /dev/null +++ b/tensorflow/models/embedding/BUILD @@ -0,0 +1,74 @@ +# Description: +# TensorFlow model for word2vec + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("/tensorflow/tensorflow", "tf_gen_op_wrapper_py") + +py_binary( + name = "word2vec", + srcs = [ + "word2vec.py", + ], + deps = [ + ":gen_word2vec", + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + ], +) + +py_binary( + name = "word2vec_optimized", + srcs = [ + "word2vec_optimized.py", + ], + deps = [ + ":gen_word2vec", + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + ], +) + +cc_library( + name = "word2vec_ops", + srcs = [ + "word2vec_ops.cc", + ], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +cc_library( + name = "word2vec_kernels", + srcs = [ + "word2vec_kernels.cc", + ], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core", + ], + alwayslink = 1, +) + +tf_gen_op_wrapper_py( + name = "gen_word2vec", + out = "gen_word2vec.py", + deps = [":word2vec_ops"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/embedding/__init__.py b/tensorflow/models/embedding/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/models/embedding/__init__.py diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py new file mode 100644 index 0000000000..4ebf3d6f27 --- /dev/null +++ b/tensorflow/models/embedding/word2vec.py @@ -0,0 +1,503 @@ +"""Multi-threaded word2vec mini-batched skip-gram model. + +Trains the model described in: +(Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space +ICLR 2013. +http://arxiv.org/abs/1301.3781 +This model does traditional minibatching. + +The key ops used are: +* placeholder for feeding in tensors for each example. +* embedding_lookup for fetching rows from the embedding matrix. +* sigmoid_cross_entropy_with_logits to calculate the loss. +* GradientDescentOptimizer for optimizing the loss. +* skipgram custom op that does input processing. +""" + +import sys +import threading +import time + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.embedding import gen_word2vec as word2vec + +flags = tf.app.flags + +flags.DEFINE_string("save_path", None, "Directory to write the model and " + "training summaries.") +flags.DEFINE_string("train_data", None, "Training text file. " + "E.g., unzipped file http://mattmahoney.net/dc/text8.zip.") +flags.DEFINE_string( + "eval_data", None, "File consisting of analogies of four tokens." + "embedding 2 - embedding 1 + embedding 3 should be close " + "to embedding 4." + "E.g. https://word2vec.googlecode.com/svn/trunk/questions-words.txt.") +flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.") +flags.DEFINE_integer( + "epochs_to_train", 15, + "Number of epochs to train. Each epoch processes the training data once " + "completely.") +flags.DEFINE_float("learning_rate", 0.2, "Initial learning rate.") +flags.DEFINE_integer("num_neg_samples", 100, + "Negative samples per training example.") +flags.DEFINE_integer("batch_size", 16, + "Number of training examples processed per step " + "(size of a minibatch).") +flags.DEFINE_integer("concurrent_steps", 12, + "The number of concurrent training steps.") +flags.DEFINE_integer("window_size", 5, + "The number of words to predict to the left and right " + "of the target word.") +flags.DEFINE_integer("min_count", 5, + "The minimum number of word occurrences for it to be " + "included in the vocabulary.") +flags.DEFINE_float("subsample", 1e-3, + "Subsample threshold for word occurrence. Words that appear " + "with higher frequency will be randomly down-sampled. Set " + "to 0 to disable.") +flags.DEFINE_boolean( + "interactive", False, + "If true, enters an IPython interactive session to play with the trained " + "model. E.g., try model.analogy('france', 'paris', 'russia') and " + "model.nearby(['proton', 'elephant', 'maxwell']") +flags.DEFINE_integer("statistics_interval", 5, + "Print statistics every n seconds.") +flags.DEFINE_integer("summary_interval", 5, + "Save training summary to file every n seconds (rounded " + "up to statistics interval.") +flags.DEFINE_integer("checkpoint_interval", 600, + "Checkpoint the model (i.e. save the parameters) every n " + "seconds (rounded up to statistics interval.") + +FLAGS = flags.FLAGS + + +class Options(object): + """Options used by our word2vec model.""" + + def __init__(self): + # Model options. + + # Embedding dimension. + self.emb_dim = FLAGS.embedding_size + + # Training options. + # The training text file. + self.train_data = FLAGS.train_data + + # Number of negative samples per example. + self.num_samples = FLAGS.num_neg_samples + + # The initial learning rate. + self.learning_rate = FLAGS.learning_rate + + # Number of epochs to train. After these many epochs, the learning + # rate decays linearly to zero and the training stops. + self.epochs_to_train = FLAGS.epochs_to_train + + # Concurrent training steps. + self.concurrent_steps = FLAGS.concurrent_steps + + # Number of examples for one training step. + self.batch_size = FLAGS.batch_size + + # The number of words to predict to the left and right of the target word. + self.window_size = FLAGS.window_size + + # The minimum number of word occurrences for it to be included in the + # vocabulary. + self.min_count = FLAGS.min_count + + # Subsampling threshold for word occurrence. + self.subsample = FLAGS.subsample + + # How often to print statistics. + self.statistics_interval = FLAGS.statistics_interval + + # How often to write to the summary file (rounds up to the nearest + # statistics_interval). + self.summary_interval = FLAGS.summary_interval + + # How often to write checkpoints (rounds up to the nearest statistics + # interval). + self.checkpoint_interval = FLAGS.checkpoint_interval + + # Where to write out summaries. + self.save_path = FLAGS.save_path + + # Eval options. + # The text file for eval. + self.eval_data = FLAGS.eval_data + + +class Word2Vec(object): + """Word2Vec model (Skipgram).""" + + def __init__(self, options, session): + self._options = options + self._session = session + self._word2id = {} + self._id2word = [] + self.build_graph() + self.build_eval_graph() + self.save_vocab() + self._read_analogies() + + def _read_analogies(self): + """Reads through the analogy question file. + + Returns: + questions: a [n, 4] numpy array containing the analogy question's + word ids. + questions_skipped: questions skipped due to unknown words. + """ + questions = [] + questions_skipped = 0 + with open(self._options.eval_data) as analogy_f: + for line in analogy_f: + if line.startswith(":"): # Skip comments. + continue + words = line.strip().lower().split(" ") + ids = [self._word2id.get(w.strip()) for w in words] + if None in ids or len(ids) != 4: + questions_skipped += 1 + else: + questions.append(np.array(ids)) + print "Eval analogy file: ", self._options.eval_data + print "Questions: ", len(questions) + print "Skipped: ", questions_skipped + self._analogy_questions = np.array(questions, dtype=np.int32) + + def forward(self, examples, labels): + """Build the graph for the forward pass.""" + opts = self._options + + # Declare all variables we need. + # Embedding: [vocab_size, emb_dim] + init_width = 0.5 / opts.emb_dim + emb = tf.Variable( + tf.random_uniform( + [opts.vocab_size, opts.emb_dim], -init_width, init_width), + name="emb") + self._emb = emb + + # Softmax weight: [vocab_size, emb_dim]. Transposed. + sm_w_t = tf.Variable( + tf.zeros([opts.vocab_size, opts.emb_dim]), + name="sm_w_t") + + # Softmax bias: [emb_dim]. + sm_b = tf.Variable(tf.zeros([opts.vocab_size]), name="sm_b") + + # Global step: scalar, i.e., shape []. + self.global_step = tf.Variable(0, name="global_step") + + # Nodes to compute the nce loss w/ candidate sampling. + labels_matrix = tf.reshape( + tf.cast(labels, + dtype=tf.int64), + [opts.batch_size, 1]) + + # Negative sampling. + sampled_ids, _, _ = (tf.nn.fixed_unigram_candidate_sampler( + true_classes=labels_matrix, + num_true=1, + num_sampled=opts.num_samples, + unique=True, + range_max=opts.vocab_size, + distortion=0.75, + unigrams=opts.vocab_counts.tolist())) + + # Embeddings for examples: [batch_size, emb_dim] + example_emb = tf.nn.embedding_lookup(emb, examples) + + # Weights for labels: [batch_size, emb_dim] + true_w = tf.nn.embedding_lookup(sm_w_t, labels) + # Biases for labels: [batch_size, 1] + true_b = tf.nn.embedding_lookup(sm_b, labels) + + # Weights for sampled ids: [num_sampled, emb_dim] + sampled_w = tf.nn.embedding_lookup(sm_w_t, sampled_ids) + # Biases for sampled ids: [num_sampled, 1] + sampled_b = tf.nn.embedding_lookup(sm_b, sampled_ids) + + # True logits: [batch_size, 1] + true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b + + # Sampled logits: [batch_size, num_sampled] + # We replicate sampled noise lables for all examples in the batch + # using the matmul. + sampled_b_vec = tf.reshape(sampled_b, [opts.num_samples]) + sampled_logits = tf.matmul(example_emb, + sampled_w, + transpose_b=True) + sampled_b_vec + return true_logits, sampled_logits + + def nce_loss(self, true_logits, sampled_logits): + """Build the graph for the NCE loss.""" + + # cross-entropy(logits, labels) + opts = self._options + true_xent = tf.nn.sigmoid_cross_entropy_with_logits( + true_logits, tf.ones_like(true_logits)) + sampled_xent = tf.nn.sigmoid_cross_entropy_with_logits( + sampled_logits, tf.zeros_like(sampled_logits)) + + # NCE-loss is the sum of the true and noise (sampled words) + # contributions, averaged over the batch. + nce_loss_tensor = (tf.reduce_sum(true_xent) + + tf.reduce_sum(sampled_xent)) / opts.batch_size + return nce_loss_tensor + + def optimize(self, loss): + """Build the graph to optimize the loss function.""" + + # Optimizer nodes. + # Linear learning rate decay. + opts = self._options + words_to_train = float(opts.words_per_epoch * opts.epochs_to_train) + lr = opts.learning_rate * tf.maximum( + 0.0001, 1.0 - tf.cast(self._words, tf.float32) / words_to_train) + self._lr = lr + optimizer = tf.train.GradientDescentOptimizer(lr) + train = optimizer.minimize(loss, + global_step=self.global_step, + gate_gradients=optimizer.GATE_NONE) + self._train = train + + def build_eval_graph(self): + """Build the eval graph.""" + # Eval graph + + # Each analogy task is to predict the 4th word (d) given three + # words: a, b, c. E.g., a=italy, b=rome, c=france, we should + # predict d=paris. + + # The eval feeds three vectors of word ids for a, b, c, each of + # which is of size N, where N is the number of analogies we want to + # evaluate in one batch. + analogy_a = tf.placeholder(dtype=tf.int32) # [N] + analogy_b = tf.placeholder(dtype=tf.int32) # [N] + analogy_c = tf.placeholder(dtype=tf.int32) # [N] + + # Normalized word embeddings of shape [vocab_size, emb_dim]. + nemb = tf.nn.l2_normalize(self._emb, 1) + + # Each row of a_emb, b_emb, c_emb is a word's embedding vector. + # They all have the shape [N, emb_dim] + a_emb = tf.gather(nemb, analogy_a) # a's embs + b_emb = tf.gather(nemb, analogy_b) # b's embs + c_emb = tf.gather(nemb, analogy_c) # c's embs + + # We expect that d's embedding vectors on the unit hyper-sphere is + # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim]. + target = c_emb + (b_emb - a_emb) + + # Compute cosine distance between each pair of target and vocab. + # dist has shape [N, vocab_size]. + dist = tf.matmul(target, nemb, transpose_b=True) + + # For each question (row in dist), find the top 4 words. + _, pred_idx = tf.nn.top_k(dist, 4) + + # Nodes for computing neighbors for a given word according to + # their cosine distance. + nearby_word = tf.placeholder(dtype=tf.int32) # word id + nearby_emb = tf.gather(nemb, nearby_word) + nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True) + nearby_val, nearby_idx = tf.nn.top_k(nearby_dist, + min(1000, self._options.vocab_size)) + + # Nodes in the construct graph which are used by training and + # evaluation to run/feed/fetch. + self._analogy_a = analogy_a + self._analogy_b = analogy_b + self._analogy_c = analogy_c + self._analogy_pred_idx = pred_idx + self._nearby_word = nearby_word + self._nearby_val = nearby_val + self._nearby_idx = nearby_idx + + def build_graph(self): + """Build the graph for the full model.""" + opts = self._options + # The training data. A text file. + (words, counts, words_per_epoch, self._epoch, self._words, examples, + labels) = word2vec.skipgram(filename=opts.train_data, + batch_size=opts.batch_size, + window_size=opts.window_size, + min_count=opts.min_count, + subsample=opts.subsample) + (opts.vocab_words, opts.vocab_counts, + opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch]) + opts.vocab_size = len(opts.vocab_words) + print "Data file: ", opts.train_data + print "Vocab size: ", opts.vocab_size - 1, " + UNK" + print "Words per epoch: ", opts.words_per_epoch + self._examples = examples + self._labels = labels + self._id2word = opts.vocab_words + for i, w in enumerate(self._id2word): + self._word2id[w] = i + true_logits, sampled_logits = self.forward(examples, labels) + loss = self.nce_loss(true_logits, sampled_logits) + tf.scalar_summary("NCE loss", loss) + self._loss = loss + self.optimize(loss) + + # Properly initialize all variables. + tf.initialize_all_variables().run() + + self.saver = tf.train.Saver() + + def save_vocab(self): + """Save the vocabulary to a file so the model can be reloaded.""" + opts = self._options + with open(opts.save_path + "/vocab.txt", "w") as f: + for i in xrange(opts.vocab_size): + f.write(opts.vocab_words[i] + " " + str(opts.vocab_counts[i]) + "\n") + + def _train_thread_body(self): + initial_epoch, = self._session.run([self._epoch]) + while True: + _, epoch = self._session.run([self._train, self._epoch]) + if epoch != initial_epoch: + break + + def train(self): + """Train the model.""" + opts = self._options + + initial_epoch, initial_words = self._session.run([self._epoch, self._words]) + + summary_op = tf.merge_all_summaries() + summary_writer = tf.train.SummaryWriter(opts.save_path, + graph_def=self._session.graph_def) + workers = [] + for _ in xrange(opts.concurrent_steps): + t = threading.Thread(target=self._train_thread_body) + t.start() + workers.append(t) + + last_words, last_time, last_summary_time = initial_words, time.time(), 0 + last_checkpoint_time = 0 + while True: + time.sleep(opts.statistics_interval) # Reports our progress once a while. + (epoch, step, loss, words, lr) = self._session.run( + [self._epoch, self.global_step, self._loss, self._words, self._lr]) + now = time.time() + last_words, last_time, rate = words, now, (words - last_words) / ( + now - last_time) + print("Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r" % + (epoch, step, lr, loss, rate)), + sys.stdout.flush() + if now - last_summary_time > opts.summary_interval: + summary_str = self._session.run(summary_op) + summary_writer.add_summary(summary_str, step) + last_summary_time = now + if now - last_checkpoint_time > opts.checkpoint_interval: + self.saver.save(self._session, + opts.save_path + "model", + global_step=step) + last_checkpoint_time = now + if epoch != initial_epoch: + break + + for t in workers: + t.join() + + return epoch + + def _predict(self, analogy): + """Predict the top 4 answers for analogy questions.""" + idx, = self._session.run([self._analogy_pred_idx], { + self._analogy_a: analogy[:, 0], + self._analogy_b: analogy[:, 1], + self._analogy_c: analogy[:, 2] + }) + return idx + + def eval(self): + """Evaluate analogy questions and reports accuracy.""" + + # How many questions we get right at precision@1. + correct = 0 + + total = self._analogy_questions.shape[0] + start = 0 + while start < total: + limit = start + 2500 + sub = self._analogy_questions[start:limit, :] + idx = self._predict(sub) + start = limit + for question in xrange(sub.shape[0]): + for j in xrange(4): + if idx[question, j] == sub[question, 3]: + # Bingo! We predicted correctly. E.g., [italy, rome, france, paris]. + correct += 1 + break + elif idx[question, j] in sub[question, :3]: + # We need to skip words already in the question. + continue + else: + # The correct label is not the precision@1 + break + print + print "Eval %4d/%d accuracy = %4.1f%%" % (correct, total, + correct * 100.0 / total) + + def analogy(self, w0, w1, w2): + """Predict word w3 as in w0:w1 vs w2:w3.""" + wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]]) + idx = self._predict(wid) + for c in [self._id2word[i] for i in idx[0, :]]: + if c not in [w0, w1, w2]: + return c + return "unknown" + + def nearby(self, words, num=20): + """Prints out nearby words given a list of words.""" + ids = np.array([self._word2id.get(x, 0) for x in words]) + vals, idx = self._session.run( + [self._nearby_val, self._nearby_idx], {self._nearby_word: ids}) + for i in xrange(len(words)): + print "\n%s\n=====================================" % (words[i]) + for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]): + print "%-20s %6.4f" % (self._id2word[neighbor], distance) + + +def _start_shell(local_ns=None): + # An interactive shell is useful for debugging/development. + import IPython + user_ns = {} + if local_ns: + user_ns.update(local_ns) + user_ns.update(globals()) + IPython.start_ipython(argv=[], user_ns=user_ns) + + +def main(_): + """Train a word2vec model.""" + opts = Options() + with tf.Graph().as_default(), tf.Session() as session: + model = Word2Vec(opts, session) + for _ in xrange(opts.epochs_to_train): + model.train() # Process one epoch + model.eval() # Eval analogies. + # Perform a final save. + model.saver.save(session, + opts.save_path + "model", + global_step=model.global_step) + if FLAGS.interactive: + # E.g., + # [0]: model.analogy('france', 'paris', 'russia') + # [1]: model.nearby(['proton', 'elephant', 'maxwell']) + _start_shell(locals()) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/models/embedding/word2vec_kernels.cc b/tensorflow/models/embedding/word2vec_kernels.cc new file mode 100644 index 0000000000..f68139fc91 --- /dev/null +++ b/tensorflow/models/embedding/word2vec_kernels.cc @@ -0,0 +1,287 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { + +class SkipgramOp : public OpKernel { + public: + explicit SkipgramOp(OpKernelConstruction* ctx) + : OpKernel(ctx), rng_(&philox_) { + string filename; + OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_)); + OP_REQUIRES_OK(ctx, Init(ctx->env(), filename)); + + mutex_lock l(mu_); + example_pos_ = corpus_size_; + label_pos_ = corpus_size_; + label_limit_ = corpus_size_; + } + + void Compute(OpKernelContext* ctx) override { + Tensor words_per_epoch(DT_INT64, TensorShape({})); + Tensor current_epoch(DT_INT32, TensorShape({})); + Tensor total_words_processed(DT_INT64, TensorShape({})); + Tensor examples(DT_INT32, TensorShape({batch_size_})); + auto Texamples = examples.flat<int32>(); + Tensor labels(DT_INT32, TensorShape({batch_size_})); + auto Tlabels = labels.flat<int32>(); + { + mutex_lock l(mu_); + for (int i = 0; i < batch_size_; ++i) { + NextExample(&Texamples(i), &Tlabels(i)); + } + words_per_epoch.scalar<int64>()() = corpus_size_; + current_epoch.scalar<int32>()() = current_epoch_; + total_words_processed.scalar<int64>()() = total_words_processed_; + } + ctx->set_output(0, word_); + ctx->set_output(1, freq_); + ctx->set_output(2, words_per_epoch); + ctx->set_output(3, current_epoch); + ctx->set_output(4, total_words_processed); + ctx->set_output(5, examples); + ctx->set_output(6, labels); + } + + private: + int32 batch_size_ = 0; + int32 window_size_ = 5; + float subsample_ = 1e-3; + int min_count_ = 5; + int32 vocab_size_ = 0; + Tensor word_; + Tensor freq_; + int32 corpus_size_ = 0; + std::vector<int32> corpus_; + + mutex mu_; + random::PhiloxRandom philox_ GUARDED_BY(mu_); + random::SimplePhilox rng_ GUARDED_BY(mu_); + int32 current_epoch_ GUARDED_BY(mu_) = -1; + int64 total_words_processed_ GUARDED_BY(mu_) = 0; + int32 example_pos_ GUARDED_BY(mu_); + int32 label_pos_ GUARDED_BY(mu_); + int32 label_limit_ GUARDED_BY(mu_); + + // {example_pos_, label_pos_} is the cursor for the next example. + // example_pos_ wrapps around at the end of corpus_. For each + // example, we randomly generate [label_pos_, label_limit) for + // labels. + void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + while (true) { + if (label_pos_ >= label_limit_) { + if (example_pos_ + 1 >= corpus_size_) { + ++current_epoch_; + example_pos_ = 0; + } else { + ++example_pos_; + } + ++total_words_processed_; + int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]); + if (subsample_ > 0) { + // See Eq. 5 in http://arxiv.org/abs/1310.4546 + float keep_prob = + (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) * + (subsample_ * corpus_size_) / word_freq; + if (rng_.RandFloat() > keep_prob) continue; + } + const int32 skip = 1 + rng_.Uniform(window_size_); + label_pos_ = std::max<int32>(0, example_pos_ - skip); + label_limit_ = std::min<int32>(corpus_size_, example_pos_ + skip + 1); + } + if (example_pos_ != label_pos_) { + break; + } + ++label_pos_; + } + *example = corpus_[example_pos_]; + *label = corpus_[label_pos_++]; + } + + Status Init(Env* env, const string& filename) { + string data; + TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data)); + RE2 kWord("\\s*(\\S+)"); + auto input = ToRegexpStringPiece(data); + string w; + corpus_size_ = 0; + std::unordered_map<string, int32> word_freq; + while (RE2::Consume(&input, kWord, &w)) { + ++(word_freq[w]); + ++corpus_size_; + } + if (corpus_size_ < window_size_ * 10) { + return errors::InvalidArgument("The text file ", filename, + " contains too little data: ", + corpus_size_, " words"); + } + typedef std::pair<string, int32> WordFreq; + std::vector<WordFreq> ordered; + for (const auto& p : word_freq) { + if (p.second >= min_count_) ordered.push_back(p); + } + LOG(INFO) << "Data file: " << filename << " contains " << data.size() + << " bytes, " << corpus_size_ << " words, " << word_freq.size() + << " unique words, " << ordered.size() + << " unique frequent words."; + word_freq.clear(); + std::sort(ordered.begin(), ordered.end(), + [](const WordFreq& x, const WordFreq& y) { + return x.second > y.second; + }); + vocab_size_ = static_cast<int32>(1 + ordered.size()); + Tensor word(DT_STRING, TensorShape({vocab_size_})); + Tensor freq(DT_INT32, TensorShape({vocab_size_})); + word.flat<string>()(0) = "UNK"; + static const int32 kUnkId = 0; + std::unordered_map<string, int32> word_id; + int64 total_counted = 0; + for (std::size_t i = 0; i < ordered.size(); ++i) { + const auto& w = ordered[i].first; + auto id = i + 1; + word.flat<string>()(id) = w; + auto word_count = ordered[i].second; + freq.flat<int32>()(id) = word_count; + total_counted += word_count; + word_id[w] = id; + } + freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted; + word_ = word; + freq_ = freq; + corpus_.reserve(corpus_size_); + input = ToRegexpStringPiece(data); + while (RE2::Consume(&input, kWord, &w)) { + corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId)); + } + return Status::OK(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp); + +class NegTrainOp : public OpKernel { + public: + explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + base_.Init(0, 0); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_)); + + std::vector<int32> vocab_count; + OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count)); + + std::vector<float> vocab_weights; + vocab_weights.reserve(vocab_count.size()); + for (const auto& f : vocab_count) { + float r = std::pow(static_cast<float>(f), 0.75f); + vocab_weights.push_back(r); + } + sampler_ = new random::DistributionSampler(vocab_weights); + } + + ~NegTrainOp() { delete sampler_; } + + void Compute(OpKernelContext* ctx) override { + Tensor w_in = ctx->mutable_input(0, false); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()), + errors::InvalidArgument("Must be a matrix")); + Tensor w_out = ctx->mutable_input(1, false); + OP_REQUIRES(ctx, w_in.shape() == w_out.shape(), + errors::InvalidArgument("w_in.shape == w_out.shape")); + const Tensor& examples = ctx->input(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()), + errors::InvalidArgument("Must be a vector")); + const Tensor& labels = ctx->input(3); + OP_REQUIRES(ctx, examples.shape() == labels.shape(), + errors::InvalidArgument("examples.shape == labels.shape")); + const Tensor& learning_rate = ctx->input(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()), + errors::InvalidArgument("Must be a scalar")); + + auto Tw_in = w_in.matrix<float>(); + auto Tw_out = w_out.matrix<float>(); + auto Texamples = examples.flat<int32>(); + auto Tlabels = labels.flat<int32>(); + auto lr = learning_rate.scalar<float>()(); + const int64 vocab_size = w_in.dim_size(0); + const int64 dims = w_in.dim_size(1); + const int64 batch_size = examples.dim_size(0); + OP_REQUIRES(ctx, vocab_size == sampler_->num(), + errors::InvalidArgument("vocab_size mismatches: ", vocab_size, + " vs. ", sampler_->num())); + + // Gradient accumulator for v_in. + Tensor buf(DT_FLOAT, TensorShape({dims})); + auto Tbuf = buf.flat<float>(); + + // Scalar buffer to hold sigmoid(+/- dot). + Tensor g_buf(DT_FLOAT, TensorShape({})); + auto g = g_buf.scalar<float>(); + + // The following loop needs 2 random 32-bit values per negative + // sample. We reserve 8 values per sample just in case the + // underlying implementation changes. + auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8); + random::SimplePhilox srnd(&rnd); + + for (int64 i = 0; i < batch_size; ++i) { + const int32 example = Texamples(i); + DCHECK(0 <= example && example < vocab_size) << example; + const int32 label = Tlabels(i); + DCHECK(0 <= label && label < vocab_size) << label; + auto v_in = Tw_in.chip<0>(example); + + // Positive: example predicts label. + // forward: x = v_in' * v_out + // l = log(sigmoid(x)) + // backward: dl/dx = g = sigmoid(-x) + // dl/d(v_in) = g * v_out' + // dl/d(v_out) = v_in' * g + { + auto v_out = Tw_out.chip<0>(label); + auto dot = (v_in * v_out).sum(); + g = (dot.exp() + 1.f).inverse(); + Tbuf = v_out * (g() * lr); + v_out += v_in * (g() * lr); + } + + // Negative samples: + // forward: x = v_in' * v_sample + // l = log(sigmoid(-x)) + // backward: dl/dx = g = -sigmoid(x) + // dl/d(v_in) = g * v_out' + // dl/d(v_out) = v_in' * g + for (int j = 0; j < num_samples_; ++j) { + const int sample = sampler_->Sample(&srnd); + if (sample == label) continue; // Skip. + auto v_sample = Tw_out.chip<0>(sample); + auto dot = (v_in * v_sample).sum(); + g = -((-dot).exp() + 1.f).inverse(); + Tbuf += v_sample * (g() * lr); + v_sample += v_in * (g() * lr); + } + + // Applies the gradient on v_in. + v_in += Tbuf; + } + } + + private: + int32 num_samples_ = 0; + random::DistributionSampler* sampler_ = nullptr; + GuardedPhiloxRandom base_; +}; + +REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp); + +} // end namespace tensorflow diff --git a/tensorflow/models/embedding/word2vec_ops.cc b/tensorflow/models/embedding/word2vec_ops.cc new file mode 100644 index 0000000000..abe03baaf4 --- /dev/null +++ b/tensorflow/models/embedding/word2vec_ops.cc @@ -0,0 +1,56 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("Skipgram") + .Output("vocab_word: string") + .Output("vocab_freq: int32") + .Output("words_per_epoch: int64") + .Output("current_epoch: int32") + .Output("total_words_processed: int64") + .Output("examples: int32") + .Output("labels: int32") + .Attr("filename: string") + .Attr("batch_size: int") + .Attr("window_size: int = 5") + .Attr("min_count: int = 5") + .Attr("subsample: float = 1e-3") + .Doc(R"doc( +Parses a text file and creates a batch of examples. + +vocab_word: A vector of words in the corpus. +vocab_freq: Frequencies of words. Sorted in the non-ascending order. +words_per_epoch: Number of words per epoch in the data file. +current_epoch: The current epoch number. +total_words_processed: The total number of words processed so far. +examples: A vector of word ids. +labels: A vector of word ids. +filename: The corpus's text file name. +batch_size: The size of produced batch. +window_size: The number of words to predict to the left and right of the target. +min_count: The minimum number of word occurrences for it to be included in the + vocabulary. +subsample: Threshold for word occurrence. Words that appear with higher + frequency will be randomly down-sampled. Set to 0 to disable. +)doc"); + +REGISTER_OP("NegTrain") + .Input("w_in: Ref(float)") + .Input("w_out: Ref(float)") + .Input("examples: int32") + .Input("labels: int32") + .Input("lr: float") + .Attr("vocab_count: list(int)") + .Attr("num_negative_samples: int") + .Doc(R"doc( +Training via negative sampling. + +w_in: input word embedding. +w_out: output word embedding. +examples: A vector of word ids. +labels: A vector of word ids. +vocab_count: Count of words in the vocabulary. +num_negative_samples: Number of negative samples per exaple. +)doc"); + +} // end namespace tensorflow diff --git a/tensorflow/models/embedding/word2vec_optimized.py b/tensorflow/models/embedding/word2vec_optimized.py new file mode 100644 index 0000000000..23e7645a0b --- /dev/null +++ b/tensorflow/models/embedding/word2vec_optimized.py @@ -0,0 +1,405 @@ +"""Multi-threaded word2vec unbatched skip-gram model. + +Trains the model described in: +(Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space +ICLR 2013. +http://arxiv.org/abs/1301.3781 +This model does true SGD (i.e. no minibatching). To do this efficiently, custom +ops are used to sequentially process data within a 'batch'. + +The key ops used are: +* skipgram custom op that does input processing. +* neg_train custom op that efficiently calculates and applies the gradient using + true SGD. +""" + +import sys +import threading +import time + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.embedding import gen_word2vec as word2vec + +flags = tf.app.flags + +flags.DEFINE_string("save_path", None, "Directory to write the model.") +flags.DEFINE_string( + "train_data", None, + "Training data. E.g., unzipped file http://mattmahoney.net/dc/text8.zip.") +flags.DEFINE_string( + "eval_data", None, "Analogy questions. " + "https://word2vec.googlecode.com/svn/trunk/questions-words.txt.") +flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.") +flags.DEFINE_integer( + "epochs_to_train", 15, + "Number of epochs to train. Each epoch processes the training data once " + "completely.") +flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.") +flags.DEFINE_integer("num_neg_samples", 25, + "Negative samples per training example.") +flags.DEFINE_integer("batch_size", 500, + "Numbers of training examples each step processes " + "(no minibatching).") +flags.DEFINE_integer("concurrent_steps", 12, + "The number of concurrent training steps.") +flags.DEFINE_integer("window_size", 5, + "The number of words to predict to the left and right " + "of the target word.") +flags.DEFINE_integer("min_count", 5, + "The minimum number of word occurrences for it to be " + "included in the vocabulary.") +flags.DEFINE_float("subsample", 1e-3, + "Subsample threshold for word occurrence. Words that appear " + "with higher frequency will be randomly down-sampled. Set " + "to 0 to disable.") +flags.DEFINE_boolean( + "interactive", False, + "If true, enters an IPython interactive session to play with the trained " + "model. E.g., try model.analogy('france', 'paris', 'russia') and " + "model.nearby(['proton', 'elephant', 'maxwell']") + +FLAGS = flags.FLAGS + + +class Options(object): + """Options used by our word2vec model.""" + + def __init__(self): + # Model options. + + # Embedding dimension. + self.emb_dim = FLAGS.embedding_size + + # Training options. + + # The training text file. + self.train_data = FLAGS.train_data + + # Number of negative samples per example. + self.num_samples = FLAGS.num_neg_samples + + # The initial learning rate. + self.learning_rate = FLAGS.learning_rate + + # Number of epochs to train. After these many epochs, the learning + # rate decays linearly to zero and the training stops. + self.epochs_to_train = FLAGS.epochs_to_train + + # Concurrent training steps. + self.concurrent_steps = FLAGS.concurrent_steps + + # Number of examples for one training step. + self.batch_size = FLAGS.batch_size + + # The number of words to predict to the left and right of the target word. + self.window_size = FLAGS.window_size + + # The minimum number of word occurrences for it to be included in the + # vocabulary. + self.min_count = FLAGS.min_count + + # Subsampling threshold for word occurrence. + self.subsample = FLAGS.subsample + + # Where to write out summaries. + self.save_path = FLAGS.save_path + + # Eval options. + + # The text file for eval. + self.eval_data = FLAGS.eval_data + + +class Word2Vec(object): + """Word2Vec model (Skipgram).""" + + def __init__(self, options, session): + self._options = options + self._session = session + self._word2id = {} + self._id2word = [] + self.build_graph() + self.build_eval_graph() + self.save_vocab() + self._read_analogies() + + def _read_analogies(self): + """Reads through the analogy question file. + + Returns: + questions: a [n, 4] numpy array containing the analogy question's + word ids. + questions_skipped: questions skipped due to unknown words. + """ + questions = [] + questions_skipped = 0 + with open(self._options.eval_data) as analogy_f: + for line in analogy_f: + if line.startswith(":"): # Skip comments. + continue + words = line.strip().lower().split(" ") + ids = [self._word2id.get(w.strip()) for w in words] + if None in ids or len(ids) != 4: + questions_skipped += 1 + else: + questions.append(np.array(ids)) + print "Eval analogy file: ", self._options.eval_data + print "Questions: ", len(questions) + print "Skipped: ", questions_skipped + self._analogy_questions = np.array(questions, dtype=np.int32) + + def build_graph(self): + """Build the model graph.""" + opts = self._options + + # The training data. A text file. + (words, counts, words_per_epoch, current_epoch, total_words_processed, + examples, labels) = word2vec.skipgram(filename=opts.train_data, + batch_size=opts.batch_size, + window_size=opts.window_size, + min_count=opts.min_count, + subsample=opts.subsample) + (opts.vocab_words, opts.vocab_counts, + opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch]) + opts.vocab_size = len(opts.vocab_words) + print "Data file: ", opts.train_data + print "Vocab size: ", opts.vocab_size - 1, " + UNK" + print "Words per epoch: ", opts.words_per_epoch + + self._id2word = opts.vocab_words + for i, w in enumerate(self._id2word): + self._word2id[w] = i + + # Declare all variables we need. + # Input words embedding: [vocab_size, emb_dim] + w_in = tf.Variable( + tf.random_uniform( + [opts.vocab_size, + opts.emb_dim], -0.5 / opts.emb_dim, 0.5 / opts.emb_dim), + name="w_in") + + # Global step: scalar, i.e., shape []. + w_out = tf.Variable(tf.zeros([opts.vocab_size, opts.emb_dim]), name="w_out") + + # Global step: [] + global_step = tf.Variable(0, name="global_step") + + # Linear learning rate decay. + words_to_train = float(opts.words_per_epoch * opts.epochs_to_train) + lr = opts.learning_rate * tf.maximum( + 0.0001, + 1.0 - tf.cast(total_words_processed, tf.float32) / words_to_train) + + # Training nodes. + inc = global_step.assign_add(1) + with tf.control_dependencies([inc]): + train = word2vec.neg_train(w_in, + w_out, + examples, + labels, + lr, + vocab_count=opts.vocab_counts.tolist(), + num_negative_samples=opts.num_samples) + + self._w_in = w_in + self._examples = examples + self._labels = labels + self._lr = lr + self._train = train + self.step = global_step + self._epoch = current_epoch + self._words = total_words_processed + + def save_vocab(self): + """Save the vocabulary to a file so the model can be reloaded.""" + opts = self._options + with open(opts.save_path + "/vocab.txt", "w") as f: + for i in xrange(opts.vocab_size): + f.write(opts.vocab_words[i] + " " + str(opts.vocab_counts[i]) + "\n") + + def build_eval_graph(self): + """Build the evaluation graph.""" + # Eval graph + opts = self._options + + # Each analogy task is to predict the 4th word (d) given three + # words: a, b, c. E.g., a=italy, b=rome, c=france, we should + # predict d=paris. + + # The eval feeds three vectors of word ids for a, b, c, each of + # which is of size N, where N is the number of analogies we want to + # evaluate in one batch. + analogy_a = tf.placeholder(dtype=tf.int32) # [N] + analogy_b = tf.placeholder(dtype=tf.int32) # [N] + analogy_c = tf.placeholder(dtype=tf.int32) # [N] + + # Normalized word embeddings of shape [vocab_size, emb_dim]. + nemb = tf.nn.l2_normalize(self._w_in, 1) + + # Each row of a_emb, b_emb, c_emb is a word's embedding vector. + # They all have the shape [N, emb_dim] + a_emb = tf.gather(nemb, analogy_a) # a's embs + b_emb = tf.gather(nemb, analogy_b) # b's embs + c_emb = tf.gather(nemb, analogy_c) # c's embs + + # We expect that d's embedding vectors on the unit hyper-sphere is + # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim]. + target = c_emb + (b_emb - a_emb) + + # Compute cosine distance between each pair of target and vocab. + # dist has shape [N, vocab_size]. + dist = tf.matmul(target, nemb, transpose_b=True) + + # For each question (row in dist), find the top 4 words. + _, pred_idx = tf.nn.top_k(dist, 4) + + # Nodes for computing neighbors for a given word according to + # their cosine distance. + nearby_word = tf.placeholder(dtype=tf.int32) # word id + nearby_emb = tf.gather(nemb, nearby_word) + nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True) + nearby_val, nearby_idx = tf.nn.top_k(nearby_dist, + min(1000, opts.vocab_size)) + + # Nodes in the construct graph which are used by training and + # evaluation to run/feed/fetch. + self._analogy_a = analogy_a + self._analogy_b = analogy_b + self._analogy_c = analogy_c + self._analogy_pred_idx = pred_idx + self._nearby_word = nearby_word + self._nearby_val = nearby_val + self._nearby_idx = nearby_idx + + # Properly initialize all variables. + tf.initialize_all_variables().run() + + self.saver = tf.train.Saver() + + def _train_thread_body(self): + initial_epoch, = self._session.run([self._epoch]) + while True: + _, epoch = self._session.run([self._train, self._epoch]) + if epoch != initial_epoch: + break + + def train(self): + """Train the model.""" + opts = self._options + + initial_epoch, initial_words = self._session.run([self._epoch, self._words]) + + workers = [] + for _ in xrange(opts.concurrent_steps): + t = threading.Thread(target=self._train_thread_body) + t.start() + workers.append(t) + + last_words, last_time = initial_words, time.time() + while True: + time.sleep(5) # Reports our progress once a while. + (epoch, step, words, + lr) = self._session.run([self._epoch, self.step, self._words, self._lr]) + now = time.time() + last_words, last_time, rate = words, now, (words - last_words) / ( + now - last_time) + print "Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step, + lr, rate), + sys.stdout.flush() + if epoch != initial_epoch: + break + + for t in workers: + t.join() + + def _predict(self, analogy): + """Predict the top 4 answers for analogy questions.""" + idx, = self._session.run([self._analogy_pred_idx], { + self._analogy_a: analogy[:, 0], + self._analogy_b: analogy[:, 1], + self._analogy_c: analogy[:, 2] + }) + return idx + + def eval(self): + """Evaluate analogy questions and reports accuracy.""" + + # How many questions we get right at precision@1. + correct = 0 + + total = self._analogy_questions.shape[0] + start = 0 + while start < total: + limit = start + 2500 + sub = self._analogy_questions[start:limit, :] + idx = self._predict(sub) + start = limit + for question in xrange(sub.shape[0]): + for j in xrange(4): + if idx[question, j] == sub[question, 3]: + # Bingo! We predicted correctly. E.g., [italy, rome, france, paris]. + correct += 1 + break + elif idx[question, j] in sub[question, :3]: + # We need to skip words already in the question. + continue + else: + # The correct label is not the precision@1 + break + print + print "Eval %4d/%d accuracy = %4.1f%%" % (correct, total, + correct * 100.0 / total) + + def analogy(self, w0, w1, w2): + """Predict word w3 as in w0:w1 vs w2:w3.""" + wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]]) + idx = self._predict(wid) + for c in [self._id2word[i] for i in idx[0, :]]: + if c not in [w0, w1, w2]: + return c + return "unknown" + + def nearby(self, words, num=20): + """Prints out nearby words given a list of words.""" + ids = np.array([self._word2id.get(x, 0) for x in words]) + vals, idx = self._session.run( + [self._nearby_val, self._nearby_idx], {self._nearby_word: ids}) + for i in xrange(len(words)): + print "\n%s\n=====================================" % (words[i]) + for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]): + print "%-20s %6.4f" % (self._id2word[neighbor], distance) + + +def _start_shell(local_ns=None): + # An interactive shell is useful for debugging/development. + import IPython + user_ns = {} + if local_ns: + user_ns.update(local_ns) + user_ns.update(globals()) + IPython.start_ipython(argv=[], user_ns=user_ns) + + +def main(_): + """Train a word2vec model.""" + opts = Options() + with tf.Graph().as_default(), tf.Session() as session: + model = Word2Vec(opts, session) + for _ in xrange(opts.epochs_to_train): + model.train() # Process one epoch + model.eval() # Eval analogies. + # Perform a final save. + model.saver.save(session, opts.save_path + "model", global_step=model.step) + if FLAGS.interactive: + # E.g., + # [0]: model.Analogy('france', 'paris', 'russia') + # [1]: model.Nearby(['proton', 'elephant', 'maxwell']) + _start_shell(locals()) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/models/image/alexnet/BUILD b/tensorflow/models/image/alexnet/BUILD new file mode 100644 index 0000000000..e1b9cd6965 --- /dev/null +++ b/tensorflow/models/image/alexnet/BUILD @@ -0,0 +1,28 @@ +# Description: +# Benchmark for AlexNet. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "alexnet_benchmark", + srcs = [ + "alexnet_benchmark.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/image/alexnet/__init__.py b/tensorflow/models/image/alexnet/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/models/image/alexnet/__init__.py diff --git a/tensorflow/models/image/alexnet/alexnet_benchmark.py b/tensorflow/models/image/alexnet/alexnet_benchmark.py new file mode 100644 index 0000000000..130948c4bf --- /dev/null +++ b/tensorflow/models/image/alexnet/alexnet_benchmark.py @@ -0,0 +1,215 @@ +"""Timing benchmark for AlexNet inference. + +To run, use: + bazel run -c opt --config=cuda \ + third_party/tensorflow/models/image/alexnet:alexnet_benchmark + +Across 100 steps on batch size = 128. + +Forward pass: +Run on Tesla K40c: 145 +/- 1.5 ms / batch +Run on Titan X: 70 +/- 0.1 ms / batch + +Forward-backward pass: +Run on Tesla K40c: 480 +/- 48 ms / batch +Run on Titan X: 244 +/- 30 ms / batch +""" +from datetime import datetime +import math +import time + +import tensorflow.python.platform +import tensorflow as tf + + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_integer('batch_size', 128, + """Batch size.""") +tf.app.flags.DEFINE_integer('num_batches', 100, + """Number of batches to run.""") + + +def print_activations(t): + print t.op.name, ' ', t.get_shape().as_list() + + +def inference(images): + """Build the AlexNet model. + + Args: + images: Images Tensor + + Returns: + pool5: the last Tensor in the convolutional component of AlexNet. + parameters: a list of Tensors corresponding to the weights and biases of the + AlexNet model. + """ + parameters = [] + # conv1 + with tf.name_scope('conv1') as scope: + kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype=tf.float32, + stddev=1e-1), name='weights') + conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding='VALID') + biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32), + trainable=True, name='biases') + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + conv1 = tf.nn.relu(bias, name=scope) + print_activations(conv1) + parameters += [kernel, biases] + + # lrn1 + # TODO(shlens, jiayq): Add a GPU version of local response normalization. + + # pool1 + pool1 = tf.nn.max_pool(conv1, + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding='VALID', + name='pool1') + print_activations(pool1) + + # conv2 + with tf.name_scope('conv2') as scope: + kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype=tf.float32, + stddev=1e-1), name='weights') + conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding='SAME') + biases = tf.Variable(tf.constant(0.0, shape=[192], dtype=tf.float32), + trainable=True, name='biases') + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + conv2 = tf.nn.relu(bias, name=scope) + parameters += [kernel, biases] + print_activations(conv2) + + # pool2 + pool2 = tf.nn.max_pool(conv2, + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding='VALID', + name='pool2') + print_activations(pool2) + + # conv3 + with tf.name_scope('conv3') as scope: + kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], + dtype=tf.float32, + stddev=1e-1), name='weights') + conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding='SAME') + biases = tf.Variable(tf.constant(0.0, shape=[384], dtype=tf.float32), + trainable=True, name='biases') + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + conv3 = tf.nn.relu(bias, name=scope) + parameters += [kernel, biases] + print_activations(conv3) + + # conv4 + with tf.name_scope('conv4') as scope: + kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], + dtype=tf.float32, + stddev=1e-1), name='weights') + conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME') + biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32), + trainable=True, name='biases') + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + conv4 = tf.nn.relu(bias, name=scope) + parameters += [kernel, biases] + print_activations(conv4) + + # conv5 + with tf.name_scope('conv5') as scope: + kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], + dtype=tf.float32, + stddev=1e-1), name='weights') + conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME') + biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32), + trainable=True, name='biases') + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + conv5 = tf.nn.relu(bias, name=scope) + parameters += [kernel, biases] + print_activations(conv5) + + # pool5 + pool5 = tf.nn.max_pool(conv5, + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding='VALID', + name='pool5') + print_activations(pool5) + + return pool5, parameters + + +def time_tensorflow_run(session, target, info_string): + """Run the computation to obtain the target tensor and print timing stats. + + Args: + session: the TensorFlow session to run the computation under. + target: the targe Tensor that is passed to the session's run() function. + info_string: a string summarizing this run, to be printed with the stats. + + Returns: + None + """ + num_steps_burn_in = 10 + total_duration = 0.0 + total_duration_squared = 0.0 + for i in xrange(FLAGS.num_batches + num_steps_burn_in): + start_time = time.time() + _ = session.run(target) + duration = time.time() - start_time + if i > num_steps_burn_in: + if not i % 10: + print ('%s: step %d, duration = %.3f' % + (datetime.now(), i - num_steps_burn_in, duration)) + total_duration += duration + total_duration_squared += duration * duration + mn = total_duration / FLAGS.num_batches + vr = total_duration_squared / FLAGS.num_batches - mn * mn + sd = math.sqrt(vr) + print ('%s: %s across %d steps, %.3f +/- %.3f sec / batch' % + (datetime.now(), info_string, FLAGS.num_batches, mn, sd)) + + + +def run_benchmark(): + """Run the benchmark on AlexNet.""" + with tf.Graph().as_default(): + # Generate some dummy images. + image_size = 224 + # Note that our padding definition is slightly different the cuda-convnet. + # In order to force the model to start with the same activations sizes, + # we add 3 to the image_size and employ VALID padding above. + images = tf.Variable(tf.random_normal([FLAGS.batch_size, + image_size + 3, + image_size + 3, 3], + dtype=tf.float32, + stddev=1e-1)) + + # Build a Graph that computes the logits predictions from the + # inference model. + pool5, parameters = inference(images) + + # Build an initialization operation. + init = tf.initialize_all_variables() + + # Start running operations on the Graph. + sess = tf.Session('') + sess.run(init) + + # Run the forward benchmark. + time_tensorflow_run(sess, pool5, "Forward") + + # Add a simple objective so we can calculate the backward pass. + objective = tf.nn.l2_loss(pool5) + # Compute the gradient with respect to all the parameters. + grad = tf.gradients(objective, parameters) + # Run the backward benchmark. + time_tensorflow_run(sess, grad, "Forward-backward") + + +def main(_): + run_benchmark() + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/models/image/cifar10/BUILD b/tensorflow/models/image/cifar10/BUILD new file mode 100644 index 0000000000..adf9aaffd4 --- /dev/null +++ b/tensorflow/models/image/cifar10/BUILD @@ -0,0 +1,79 @@ +# Description: +# Example TensorFlow models for CIFAR-10 + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "cifar10_input", + srcs = ["cifar10_input.py"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "cifar10_input_test", + srcs = ["cifar10_input_test.py"], + deps = [ + ":cifar10_input", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "cifar10", + srcs = ["cifar10.py"], + deps = [ + ":cifar10_input", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar10_eval", + srcs = [ + "cifar10_eval.py", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cifar10", + ], +) + +py_binary( + name = "cifar10_train", + srcs = [ + "cifar10_train.py", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cifar10", + ], +) + +py_binary( + name = "cifar10_multi_gpu_train", + srcs = [ + "cifar10_multi_gpu_train.py", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cifar10", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/image/cifar10/README.md b/tensorflow/models/image/cifar10/README.md new file mode 100644 index 0000000000..67877aedc0 --- /dev/null +++ b/tensorflow/models/image/cifar10/README.md @@ -0,0 +1,10 @@ +CIFAR-10 is a common benchmark in machine learning for image recognition. + +http://www.cs.toronto.edu/~kriz/cifar.html + +Code in this directory demonstrates how to use TensorFlow to train and evaluate a convolutional neural network (CNN) on both CPU and GPU. We also demonstrate how to train a CNN over multiple GPUs. + +Detailed instructions on how to get started available at: + +http://tensorflow.org/tutorials/deep_cnn/ + diff --git a/tensorflow/models/image/cifar10/__init__.py b/tensorflow/models/image/cifar10/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/models/image/cifar10/__init__.py diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py new file mode 100644 index 0000000000..7870080820 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -0,0 +1,480 @@ +"""Builds the CIFAR-10 network. + +Summary of available functions: + + # Compute input images and labels for training. If you would like to run + # evaluations, use input() instead. + inputs, labels = distorted_inputs() + + # Compute inference on the model inputs to make a prediction. + predictions = inference(inputs) + + # Compute the total loss of the prediction with respect to the labels. + loss = loss(predictions, labels) + + # Create a graph to run one step of training with respect to the loss. + train_op = train(loss, global_step) +""" +# pylint: disable=missing-docstring +import gzip +import os +import re +import sys +import tarfile +import urllib + +import tensorflow.python.platform +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10_input +from tensorflow.python.platform import gfile + +FLAGS = tf.app.flags.FLAGS + +# Basic model parameters. +tf.app.flags.DEFINE_integer('batch_size', 128, + """Number of images to process in a batch.""") +tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data', + """Path to the CIFAR-10 data directory.""") + +# Process images of this size. Note that this differs from the original CIFAR +# image size of 32 x 32. If one alters this number, then the entire model +# architecture will change and any model would need to be retrained. +IMAGE_SIZE = 24 + +# Global constants describing the CIFAR-10 data set. +NUM_CLASSES = 10 +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 + +# Constants describing the training process. +MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. +NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays. +LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor. +INITIAL_LEARNING_RATE = 0.1 # Initial learning rate. + +# If a model is trained with multiple GPU's prefix all Op names with tower_name +# to differentiate the operations. Note that this prefix is removed from the +# names of the summaries when visualizing a model. +TOWER_NAME = 'tower' + +DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' + + +def _activation_summary(x): + """Helper to create summaries for activations. + + Creates a summary that provides a histogram of activations. + Creates a summary that measure the sparsity of activations. + + Args: + x: Tensor + Returns: + nothing + """ + # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training + # session. This helps the clarity of presentation on tensorboard. + tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name) + tf.histogram_summary(tensor_name + '/activations', x) + tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) + + +def _variable_on_cpu(name, shape, initializer): + """Helper to create a Variable stored on CPU memory. + + Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + + Returns: + Variable Tensor + """ + with tf.device('/cpu:0'): + var = tf.get_variable(name, shape, initializer=initializer) + return var + + +def _variable_with_weight_decay(name, shape, stddev, wd): + """Helper to create an initialized Variable with weight decay. + + Note that the Variable is initialized with a truncated normal distribution. + A weight decay is added only if one is specified. + + Args: + name: name of the variable + shape: list of ints + stddev: standard deviation of a truncated Gaussian + wd: add L2Loss weight decay multiplied by this float. If None, weight + decay is not added for this Variable. + + Returns: + Variable Tensor + """ + var = _variable_on_cpu(name, shape, + tf.truncated_normal_initializer(stddev=stddev)) + if wd: + weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss') + tf.add_to_collection('losses', weight_decay) + return var + + +def _generate_image_and_label_batch(image, label, min_queue_examples): + """Construct a queued batch of images and labels. + + Args: + image: 3-D Tensor of [IMAGE_SIZE, IMAGE_SIZE, 3] of type.float32. + label: 1-D Tensor of type.int32 + min_queue_examples: int32, minimum number of samples to retain + in the queue that provides of batches of examples. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + # Create a queue that shuffles the examples, and then + # read 'FLAGS.batch_size' images + labels from the example queue. + num_preprocess_threads = 16 + images, label_batch = tf.train.shuffle_batch( + [image, label], + batch_size=FLAGS.batch_size, + num_threads=num_preprocess_threads, + capacity=min_queue_examples + 3 * FLAGS.batch_size, + min_after_dequeue=min_queue_examples) + + # Display the training images in the visualizer. + tf.image_summary('images', images) + + return images, tf.reshape(label_batch, [FLAGS.batch_size]) + + +def distorted_inputs(): + """Construct distorted input for CIFAR training using the Reader ops. + + Raises: + ValueError: if no data_dir + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', + 'data_batch_%d.bin' % i) + for i in xrange(1, 5)] + for f in filenames: + if not gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = cifar10_input.read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for training the network. Note the many random + # distortions applied to the image. + + # Randomly crop a [height, width] section of the image. + distorted_image = tf.image.random_crop(reshaped_image, [height, width]) + + # Randomly flip the image horizontally. + distorted_image = tf.image.random_flip_left_right(distorted_image) + + # Because these operations are not commutative, consider randomizing + # randomize the order their operation. + distorted_image = tf.image.random_brightness(distorted_image, + max_delta=63) + distorted_image = tf.image.random_contrast(distorted_image, + lower=0.2, upper=1.8) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_whitening(distorted_image) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * + min_fraction_of_examples_in_queue) + print ('Filling queue with %d CIFAR images before starting to train. ' + 'This will take a few minutes.' % min_queue_examples) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples) + + +def inputs(eval_data): + """Construct input for CIFAR evaluation using the Reader ops. + + Args: + eval_data: bool, indicating if one should use the train or eval data set. + + Raises: + ValueError: if no data_dir + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + if not FLAGS.data_dir: + raise ValueError('Please supply a data_dir') + + if not eval_data: + filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', + 'data_batch_%d.bin' % i) + for i in xrange(1, 5)] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN + else: + filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', + 'test_batch.bin')] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL + + for f in filenames: + if not gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = cifar10_input.read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for evaluation. + # Crop the central [height, width] of the image. + resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, + width, height) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_whitening(resized_image) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(num_examples_per_epoch * + min_fraction_of_examples_in_queue) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples) + + +def inference(images): + """Build the CIFAR-10 model. + + Args: + images: Images returned from distorted_inputs() or inputs(). + + Returns: + Logits. + """ + # We instantiate all variables using tf.get_variable() instead of + # tf.Variable() in order to share variables across multiple GPU training runs. + # If we only ran this model on a single GPU, we could simplify this function + # by replacing all instances of tf.get_variable() with tf.Variable(). + # + # conv1 + with tf.variable_scope('conv1') as scope: + kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], + stddev=1e-4, wd=0.0) + conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0)) + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list()) + conv1 = tf.nn.relu(bias, name=scope.name) + _activation_summary(conv1) + + # pool1 + pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], + padding='SAME', name='pool1') + # norm1 + norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm1') + + # conv2 + with tf.variable_scope('conv2') as scope: + kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64], + stddev=1e-4, wd=0.0) + conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1)) + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list()) + conv2 = tf.nn.relu(bias, name=scope.name) + _activation_summary(conv2) + + # norm2 + norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm2') + # pool2 + pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], padding='SAME', name='pool2') + + # local3 + with tf.variable_scope('local3') as scope: + # Move everything into depth so we can perform a single matrix multiply. + dim = 1 + for d in pool2.get_shape()[1:].as_list(): + dim *= d + reshape = tf.reshape(pool2, [FLAGS.batch_size, dim]) + + weights = _variable_with_weight_decay('weights', shape=[dim, 384], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1)) + local3 = tf.nn.relu_layer(reshape, weights, biases, name=scope.name) + _activation_summary(local3) + + # local4 + with tf.variable_scope('local4') as scope: + weights = _variable_with_weight_decay('weights', shape=[384, 192], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1)) + local4 = tf.nn.relu_layer(local3, weights, biases, name=scope.name) + _activation_summary(local4) + + # softmax, i.e. softmax(WX + b) + with tf.variable_scope('softmax_linear') as scope: + weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES], + stddev=1/192.0, wd=0.0) + biases = _variable_on_cpu('biases', [NUM_CLASSES], + tf.constant_initializer(0.0)) + softmax_linear = tf.nn.xw_plus_b(local4, weights, biases, name=scope.name) + _activation_summary(softmax_linear) + + return softmax_linear + + +def loss(logits, labels): + """Add L2Loss to all the trainable variables. + + Add summary for for "Loss" and "Loss/avg". + Args: + logits: Logits from inference(). + labels: Labels from distorted_inputs or inputs(). 1-D tensor + of shape [batch_size] + + Returns: + Loss tensor of type float. + """ + # Reshape the labels into a dense Tensor of + # shape [batch_size, NUM_CLASSES]. + sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1]) + indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1]) + concated = tf.concat(1, [indices, sparse_labels]) + dense_labels = tf.sparse_to_dense(concated, + [FLAGS.batch_size, NUM_CLASSES], + 1.0, 0.0) + + # Calculate the average cross entropy loss across the batch. + cross_entropy = tf.nn.softmax_cross_entropy_with_logits( + logits, dense_labels, name='cross_entropy_per_example') + cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') + tf.add_to_collection('losses', cross_entropy_mean) + + # The total loss is defined as the cross entropy loss plus all of the weight + # decay terms (L2 loss). + return tf.add_n(tf.get_collection('losses'), name='total_loss') + + +def _add_loss_summaries(total_loss): + """Add summaries for losses in CIFAR-10 model. + + Generates moving average for all losses and associated summaries for + visualizing the performance of the network. + + Args: + total_loss: Total loss from loss(). + Returns: + loss_averages_op: op for generating moving averages of losses. + """ + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + losses = tf.get_collection('losses') + loss_averages_op = loss_averages.apply(losses + [total_loss]) + + # Attach a scalar summmary to all individual losses and the total loss; do the + # same for the averaged version of the losses. + for l in losses + [total_loss]: + # Name each loss as '(raw)' and name the moving average version of the loss + # as the original loss name. + tf.scalar_summary(l.op.name +' (raw)', l) + tf.scalar_summary(l.op.name, loss_averages.average(l)) + + return loss_averages_op + + +def train(total_loss, global_step): + """Train CIFAR-10 model. + + Create an optimizer and apply to all trainable variables. Add moving + average for all trainable variables. + + Args: + total_loss: Total loss from loss(). + global_step: Integer Variable counting the number of training steps + processed. + Returns: + train_op: op for training. + """ + # Variables that affect learning rate. + num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size + decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY) + + # Decay the learning rate exponentially based on the number of steps. + lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, + global_step, + decay_steps, + LEARNING_RATE_DECAY_FACTOR, + staircase=True) + tf.scalar_summary('learning_rate', lr) + + # Generate moving averages of all losses and associated summaries. + loss_averages_op = _add_loss_summaries(total_loss) + + # Compute gradients. + with tf.control_dependencies([loss_averages_op]): + opt = tf.train.GradientDescentOptimizer(lr) + grads = opt.compute_gradients(total_loss) + + # Apply gradients. + apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) + + # Add histograms for trainable variables. + for var in tf.trainable_variables(): + tf.histogram_summary(var.op.name, var) + + # Add histograms for gradients. + for grad, var in grads: + if grad: + tf.histogram_summary(var.op.name + '/gradients', grad) + + # Track the moving averages of all trainable variables. + variable_averages = tf.train.ExponentialMovingAverage( + MOVING_AVERAGE_DECAY, global_step) + variables_averages_op = variable_averages.apply(tf.trainable_variables()) + + with tf.control_dependencies([apply_gradient_op, variables_averages_op]): + train_op = tf.no_op(name='train') + + return train_op + + +def maybe_download_and_extract(): + """Download and extract the tarball from Alex's website.""" + dest_directory = FLAGS.data_dir + if not os.path.exists(dest_directory): + os.makedirs(dest_directory) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(dest_directory, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) + print + statinfo = os.stat(filepath) + print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.' + tarfile.open(filepath, 'r:gz').extractall(dest_directory) diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py new file mode 100644 index 0000000000..73c224191d --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_eval.py @@ -0,0 +1,148 @@ +"""Evaluation for CIFAR-10. + +Accuracy: +cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs +of data) as judged by cifar10_eval.py. + +Speed: +On a single Tesla K40, cifar10_train.py processes a single batch of 128 images +in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86% +accuracy after 100K steps in 8 hours of training time. + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from datetime import datetime +import math +import time + +import tensorflow.python.platform +from tensorflow.python.platform import gfile +import numpy as np +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10 + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval', + """Directory where to write event logs.""") +tf.app.flags.DEFINE_string('eval_data', 'test', + """Either 'test' or 'train_eval'.""") +tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train', + """Directory where to read model checkpoints.""") +tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5, + """How often to run the eval.""") +tf.app.flags.DEFINE_integer('num_examples', 10000, + """Number of examples to run.""") +tf.app.flags.DEFINE_boolean('run_once', False, + """Whether to run eval only once.""") + + +def eval_once(saver, summary_writer, top_k_op, summary_op): + """Run Eval once. + + Args: + saver: Saver. + summary_writer: Summary writer. + top_k_op: Top K op. + summary_op: Summary op. + """ + with tf.Session() as sess: + ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + # Restores from checkpoint + saver.restore(sess, ckpt.model_checkpoint_path) + # Assuming model_checkpoint_path looks something like: + # /my-favorite-path/cifar10_train/model.ckpt-0, + # extract global_step from it. + global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] + else: + print 'No checkpoint file found' + return + + # Start the queue runners. + coord = tf.train.Coordinator() + try: + threads = [] + for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): + threads.extend(qr.create_threads(sess, coord=coord, daemon=True, + start=True)) + + num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size)) + true_count = 0 # Counts the number of correct predictions. + total_sample_count = num_iter * FLAGS.batch_size + step = 0 + while step < num_iter and not coord.should_stop(): + predictions = sess.run([top_k_op]) + true_count += np.sum(predictions) + step += 1 + + # Compute precision @ 1. + precision = float(true_count) / float(total_sample_count) + print '%s: precision @ 1 = %.3f' % (datetime.now(), precision) + + summary = tf.Summary() + summary.ParseFromString(sess.run(summary_op)) + summary.value.add(tag='Precision @ 1', simple_value=precision) + summary_writer.add_summary(summary, global_step) + except Exception, e: # pylint: disable=broad-except + coord.request_stop(e) + + coord.request_stop() + coord.join(threads, stop_grace_period_secs=10) + + +def evaluate(): + """Eval CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(): + # Get images and labels for CIFAR-10. + eval_data = FLAGS.eval_data == 'test' + images, labels = cifar10.inputs(eval_data=eval_data) + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate predictions. + top_k_op = tf.nn.in_top_k(logits, labels, 1) + + # Restore the moving average version of the learned variables for eval. + variable_averages = tf.train.ExponentialMovingAverage( + cifar10.MOVING_AVERAGE_DECAY) + variables_to_restore = {} + for v in tf.all_variables(): + if v in tf.trainable_variables(): + restore_name = variable_averages.average_name(v) + else: + restore_name = v.op.name + variables_to_restore[restore_name] = v + saver = tf.train.Saver(variables_to_restore) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.merge_all_summaries() + + graph_def = tf.get_default_graph().as_graph_def() + summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, + graph_def=graph_def) + + while True: + eval_once(saver, summary_writer, top_k_op, summary_op) + if FLAGS.run_once: + break + time.sleep(FLAGS.eval_interval_secs) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if gfile.Exists(FLAGS.eval_dir): + gfile.DeleteRecursively(FLAGS.eval_dir) + gfile.MakeDirs(FLAGS.eval_dir) + evaluate() + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py new file mode 100644 index 0000000000..686f1bf987 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_input.py @@ -0,0 +1,65 @@ +"""Routine for decoding the CIFAR-10 binary file format.""" + +import tensorflow.python.platform +import tensorflow as tf + + +def read_cifar10(filename_queue): + """Reads and parses examples from CIFAR10 data files. + + Recommendation: if you want N-way read parallelism, call this function + N times. This will give you N independent Readers reading different + files & positions within those files, which will give better mixing of + examples. + + Args: + filename_queue: A queue of strings with the filenames to read from. + + Returns: + An object representing a single example, with the following fields: + height: number of rows in the result (32) + width: number of columns in the result (32) + depth: number of color channels in the result (3) + key: a scalar string Tensor describing the filename & record number + for this example. + label: an int32 Tensor with the label in the range 0..9. + uint8image: a [height, width, depth] uint8 Tensor with the image data + """ + + class CIFAR10Record(object): + pass + result = CIFAR10Record() + + # Dimensions of the images in the CIFAR-10 dataset. + # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the + # input format. + label_bytes = 1 # 2 for CIFAR-100 + result.height = 32 + result.width = 32 + result.depth = 3 + image_bytes = result.height * result.width * result.depth + # Every record consists of a label followed by the image, with a + # fixed number of bytes for each. + record_bytes = label_bytes + image_bytes + + # Read a record, getting filenames from the filename_queue. No + # header or footer in the CIFAR-10 format, so we leave header_bytes + # and footer_bytes at their default of 0. + reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) + result.key, value = reader.read(filename_queue) + + # Convert from a string to a vector of uint8 that is record_bytes long. + record_bytes = tf.decode_raw(value, tf.uint8) + + # The first bytes represent the label, which we convert from uint8->int32. + result.label = tf.cast( + tf.slice(record_bytes, [0], [label_bytes]), tf.int32) + + # The remaining bytes after the label represent the image, which we reshape + # from [depth * height * width] to [depth, height, width]. + depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), + [result.depth, result.height, result.width]) + # Convert from [depth, height, width] to [height, width, depth]. + result.uint8image = tf.transpose(depth_major, [1, 2, 0]) + + return result diff --git a/tensorflow/models/image/cifar10/cifar10_input_test.py b/tensorflow/models/image/cifar10/cifar10_input_test.py new file mode 100644 index 0000000000..d43f5aedcf --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_input_test.py @@ -0,0 +1,49 @@ +"""Tests for cifar10 input.""" + +import os + +import tensorflow.python.platform + +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10_input + + +class CIFAR10InputTest(tf.test.TestCase): + + def _record(self, label, red, green, blue): + image_size = 32 * 32 + record = "%s%s%s%s" % (chr(label), chr(red) * image_size, + chr(green) * image_size, chr(blue) * image_size) + expected = [[[red, green, blue]] * 32] * 32 + return record, expected + + def testSimple(self): + labels = [9, 3, 0] + records = [self._record(labels[0], 0, 128, 255), + self._record(labels[1], 255, 0, 1), + self._record(labels[2], 254, 255, 0)] + contents = "".join([record for record, _ in records]) + expected = [expected for _, expected in records] + filename = os.path.join(self.get_temp_dir(), "cifar") + open(filename, "w").write(contents) + + with self.test_session() as sess: + q = tf.FIFOQueue(99, [tf.string], shapes=()) + q.enqueue([filename]).run() + q.close().run() + result = cifar10_input.read_cifar10(q) + + for i in range(3): + key, label, uint8image = sess.run([ + result.key, result.label, result.uint8image]) + self.assertEqual("%s:%d" % (filename, i), key) + self.assertEqual(labels[i], label) + self.assertAllEqual(expected[i], uint8image) + + with self.assertRaises(tf.errors.OutOfRangeError): + sess.run([result.key, result.uint8image]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py new file mode 100644 index 0000000000..54bc41f444 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py @@ -0,0 +1,265 @@ +"""A binary to train CIFAR-10 using multiple GPU's with synchronous updates. + +Accuracy: +cifar10_multi_gpu_train.py achieves ~86% accuracy after 100K steps (256 +epochs of data) as judged by cifar10_eval.py. + +Speed: With batch_size 128. + +System | Step Time (sec/batch) | Accuracy +-------------------------------------------------------------------- +1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) +1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) +2 Tesla K20m | 0.13-0.20 | ~84% at 30K steps (2.5 hours) +3 Tesla K20m | 0.13-0.18 | ~84% at 30K steps +4 Tesla K20m | ~0.10 | ~84% at 30K steps + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from datetime import datetime +import os.path +import re +import time + +# pylint: disable=unused-import,g-bad-import-order +import tensorflow.python.platform +from tensorflow.python.platform import gfile +import numpy as np +import tensorflow as tf +from tensorflow.models.image.cifar10 import cifar10 +# pylint: disable=unused-import,g-bad-import-order + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', + """Directory where to write event logs """ + """and checkpoint.""") +tf.app.flags.DEFINE_integer('max_steps', 1000000, + """Number of batches to run.""") +tf.app.flags.DEFINE_integer('num_gpus', 1, + """How many GPUs to use.""") +tf.app.flags.DEFINE_boolean('log_device_placement', False, + """Whether to log device placement.""") + + +def tower_loss(scope): + """Calculate the total loss on a single tower running the CIFAR model. + + Args: + scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0' + + Returns: + Tensor of shape [] containing the total loss for a batch of data + """ + # Get images and labels for CIFAR-10. + images, labels = cifar10.distorted_inputs() + + # Build inference Graph. + logits = cifar10.inference(images) + + # Build the portion of the Graph calculating the losses. Note that we will + # assemble the total_loss using a custom function below. + _ = cifar10.loss(logits, labels) + + # Assemble all of the losses for the current tower only. + losses = tf.get_collection('losses', scope) + + # Calculate the total loss for the current tower. + total_loss = tf.add_n(losses, name='total_loss') + + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + loss_averages_op = loss_averages.apply(losses + [total_loss]) + + # Attach a scalar summmary to all individual losses and the total loss; do the + # same for the averaged version of the losses. + for l in losses + [total_loss]: + # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training + # session. This helps the clarity of presentation on tensorboard. + loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name) + # Name each loss as '(raw)' and name the moving average version of the loss + # as the original loss name. + tf.scalar_summary(loss_name +' (raw)', l) + tf.scalar_summary(loss_name, loss_averages.average(l)) + + with tf.control_dependencies([loss_averages_op]): + total_loss = tf.identity(total_loss) + return total_loss + + +def average_gradients(tower_grads): + """Calculate the average gradient for each shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over individual gradients. The inner list is over the gradient + calculation for each tower. + Returns: + List of pairs of (gradient, variable) where the gradient has been averaged + across all towers. + """ + average_grads = [] + for grad_and_vars in zip(*tower_grads): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + grads = [] + for g, _ in grad_and_vars: + # Add 0 dimension to the gradients to represent the tower. + expanded_g = tf.expand_dims(g, 0) + + # Append on a 'tower' dimension which we will average over below. + grads.append(expanded_g) + + # Average over the 'tower' dimension. + grad = tf.concat(0, grads) + grad = tf.reduce_mean(grad, 0) + + # Keep in mind that the Variables are redundant because they are shared + # across towers. So .. we will just return the first tower's pointer to + # the Variable. + v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) + return average_grads + + +def train(): + """Train CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(), tf.device('/cpu:0'): + # Create a variable to count the number of train() calls. This equals the + # number of batches processed * FLAGS.num_gpus. + global_step = tf.get_variable( + 'global_step', [], + initializer=tf.constant_initializer(0), trainable=False) + + # Calculate the learning rate schedule. + num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / + FLAGS.batch_size) + decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY) + + # Decay the learning rate exponentially based on the number of steps. + lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE, + global_step, + decay_steps, + cifar10.LEARNING_RATE_DECAY_FACTOR, + staircase=True) + + # Create an optimizer that performs gradient descent. + opt = tf.train.GradientDescentOptimizer(lr) + + # Calculate the gradients for each model tower. + tower_grads = [] + for i in xrange(FLAGS.num_gpus): + with tf.device('/gpu:%d' % i): + with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope: + # Calculate the loss for one tower of the CIFAR model. This function + # constructs the entire CIFAR model but shares the variables across + # all towers. + loss = tower_loss(scope) + + # Reuse variables for the next tower. + tf.get_variable_scope().reuse_variables() + + # Retain the summaries from the final tower. + summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) + + # Calculate the gradients for the batch of data on this CIFAR tower. + grads = opt.compute_gradients(loss) + + # Keep track of the gradients across all towers. + tower_grads.append(grads) + + # We must calculate the mean of each gradient. Note that this is the + # synchronization point across all towers. + grads = average_gradients(tower_grads) + + # Add a summary to track the learning rate. + summaries.append(tf.scalar_summary('learning_rate', lr)) + + # Add histograms for gradients. + for grad, var in grads: + if grad: + summaries.append( + tf.histogram_summary(var.op.name + '/gradients', grad)) + + # Apply the gradients to adjust the shared variables. + apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) + + # Add histograms for trainable variables. + for var in tf.trainable_variables(): + summaries.append(tf.histogram_summary(var.op.name, var)) + + # Track the moving averages of all trainable variables. + variable_averages = tf.train.ExponentialMovingAverage( + cifar10.MOVING_AVERAGE_DECAY, global_step) + variables_averages_op = variable_averages.apply(tf.trainable_variables()) + + # Group all updates to into a single train op. + train_op = tf.group(apply_gradient_op, variables_averages_op) + + # Create a saver. + saver = tf.train.Saver(tf.all_variables()) + + # Build the summary operation from the last tower summaries. + summary_op = tf.merge_summary(summaries) + + # Build an initialization operation to run below. + init = tf.initialize_all_variables() + + # Start running operations on the Graph. allow_soft_placement must be set to + # True to build towers on GPU, as some of the ops do not have GPU + # implementations. + sess = tf.Session(config=tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=FLAGS.log_device_placement)) + sess.run(init) + + # Start the queue runners. + tf.train.start_queue_runners(sess=sess) + + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, + graph_def=sess.graph_def) + + for step in xrange(FLAGS.max_steps): + start_time = time.time() + _, loss_value = sess.run([train_op, loss]) + duration = time.time() - start_time + + assert not np.isnan(loss_value), 'Model diverged with loss = NaN' + + if step % 10 == 0: + num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus + examples_per_sec = num_examples_per_step / float(duration) + sec_per_batch = float(duration) / FLAGS.num_gpus + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print (format_str % (datetime.now(), step, loss_value, + examples_per_sec, sec_per_batch)) + + if step % 100 == 0: + summary_str = sess.run(summary_op) + summary_writer.add_summary(summary_str, step) + + # Save the model checkpoint periodically. + if step % 1000 == 0 or (step + 1) == FLAGS.max_steps: + checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') + saver.save(sess, checkpoint_path, global_step=step) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if gfile.Exists(FLAGS.train_dir): + gfile.DeleteRecursively(FLAGS.train_dir) + gfile.MakeDirs(FLAGS.train_dir) + train() + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py new file mode 100644 index 0000000000..bcb6eeae58 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_train.py @@ -0,0 +1,119 @@ +"""A binary to train CIFAR-10 using a single GPU. + +Accuracy: +cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of +data) as judged by cifar10_eval.py. + +Speed: With batch_size 128. + +System | Step Time (sec/batch) | Accuracy +------------------------------------------------------------------ +1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) +1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from datetime import datetime +import os.path +import time + +import tensorflow.python.platform +from tensorflow.python.platform import gfile + +import numpy as np + +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10 + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', + """Directory where to write event logs """ + """and checkpoint.""") +tf.app.flags.DEFINE_integer('max_steps', 1000000, + """Number of batches to run.""") +tf.app.flags.DEFINE_boolean('log_device_placement', False, + """Whether to log device placement.""") + + +def train(): + """Train CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(): + global_step = tf.Variable(0, trainable=False) + + # Get images and labels for CIFAR-10. + images, labels = cifar10.distorted_inputs() + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate loss. + loss = cifar10.loss(logits, labels) + + # Build a Graph that trains the model with one batch of examples and + # updates the model parameters. + train_op = cifar10.train(loss, global_step) + + # Create a saver. + saver = tf.train.Saver(tf.all_variables()) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.merge_all_summaries() + + # Build an initialization operation to run below. + init = tf.initialize_all_variables() + + # Start running operations on the Graph. + sess = tf.Session(config=tf.ConfigProto( + log_device_placement=FLAGS.log_device_placement)) + sess.run(init) + + # Start the queue runners. + tf.train.start_queue_runners(sess=sess) + + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, + graph_def=sess.graph_def) + + for step in xrange(FLAGS.max_steps): + start_time = time.time() + _, loss_value = sess.run([train_op, loss]) + duration = time.time() - start_time + + assert not np.isnan(loss_value), 'Model diverged with loss = NaN' + + if step % 10 == 0: + num_examples_per_step = FLAGS.batch_size + examples_per_sec = num_examples_per_step / float(duration) + sec_per_batch = float(duration) + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print (format_str % (datetime.now(), step, loss_value, + examples_per_sec, sec_per_batch)) + + if step % 100 == 0: + summary_str = sess.run(summary_op) + summary_writer.add_summary(summary_str, step) + + # Save the model checkpoint periodically. + if step % 1000 == 0 or (step + 1) == FLAGS.max_steps: + checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') + saver.save(sess, checkpoint_path, global_step=step) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if gfile.Exists(FLAGS.train_dir): + gfile.DeleteRecursively(FLAGS.train_dir) + gfile.MakeDirs(FLAGS.train_dir) + train() + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/models/image/mnist/BUILD b/tensorflow/models/image/mnist/BUILD new file mode 100644 index 0000000000..76b31d0feb --- /dev/null +++ b/tensorflow/models/image/mnist/BUILD @@ -0,0 +1,44 @@ +# Description: +# Example TensorFlow models for MNIST that achieves high accuracy + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "convolutional", + srcs = [ + "convolutional.py", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "convolutional_test", + size = "medium", + srcs = [ + "convolutional.py", + ], + args = [ + "--self_test=True", + ], + main = "convolutional.py", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/image/mnist/__init__.py b/tensorflow/models/image/mnist/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/models/image/mnist/__init__.py diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py new file mode 100644 index 0000000000..8fb0e4dfb4 --- /dev/null +++ b/tensorflow/models/image/mnist/convolutional.py @@ -0,0 +1,270 @@ +"""Simple, end-to-end, LeNet-5-like convolutional MNIST model example. + +This should achieve a test error of 0.8%. Please keep this model as simple and +linear as possible, it is meant as a tutorial for simple convolutional models. +Run with --self_test on the command line to exectute a short self-test. +""" +import gzip +import os +import sys +import urllib + +import tensorflow.python.platform + +import numpy +import tensorflow as tf + +SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' +WORK_DIRECTORY = 'data' +IMAGE_SIZE = 28 +NUM_CHANNELS = 1 +PIXEL_DEPTH = 255 +NUM_LABELS = 10 +VALIDATION_SIZE = 5000 # Size of the validation set. +SEED = 66478 # Set to None for random seed. +BATCH_SIZE = 64 +NUM_EPOCHS = 10 + + +tf.app.flags.DEFINE_boolean("self_test", False, "True if running a self test.") +FLAGS = tf.app.flags.FLAGS + + +def maybe_download(filename): + """Download the data from Yann's website, unless it's already here.""" + if not os.path.exists(WORK_DIRECTORY): + os.mkdir(WORK_DIRECTORY) + filepath = os.path.join(WORK_DIRECTORY, filename) + if not os.path.exists(filepath): + filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath) + statinfo = os.stat(filepath) + print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.' + return filepath + + +def extract_data(filename, num_images): + """Extract the images into a 4D tensor [image index, y, x, channels]. + + Values are rescaled from [0, 255] down to [-0.5, 0.5]. + """ + print 'Extracting', filename + with gzip.open(filename) as bytestream: + bytestream.read(16) + buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images) + data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32) + data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH + data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1) + return data + + +def extract_labels(filename, num_images): + """Extract the labels into a 1-hot matrix [image index, label index].""" + print 'Extracting', filename + with gzip.open(filename) as bytestream: + bytestream.read(8) + buf = bytestream.read(1 * num_images) + labels = numpy.frombuffer(buf, dtype=numpy.uint8) + # Convert to dense 1-hot representation. + return (numpy.arange(NUM_LABELS) == labels[:, None]).astype(numpy.float32) + + +def fake_data(num_images): + """Generate a fake dataset that matches the dimensions of MNIST.""" + data = numpy.ndarray( + shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS), + dtype=numpy.float32) + labels = numpy.zeros(shape=(num_images, NUM_LABELS), dtype=numpy.float32) + for image in xrange(num_images): + label = image % 2 + data[image, :, :, 0] = label - 0.5 + labels[image, label] = 1.0 + return data, labels + + +def error_rate(predictions, labels): + """Return the error rate based on dense predictions and 1-hot labels.""" + return 100.0 - ( + 100.0 * + numpy.sum(numpy.argmax(predictions, 1) == numpy.argmax(labels, 1)) / + predictions.shape[0]) + + +def main(argv=None): # pylint: disable=unused-argument + if FLAGS.self_test: + print 'Running self-test.' + train_data, train_labels = fake_data(256) + validation_data, validation_labels = fake_data(16) + test_data, test_labels = fake_data(256) + num_epochs = 1 + else: + # Get the data. + train_data_filename = maybe_download('train-images-idx3-ubyte.gz') + train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz') + test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz') + test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz') + + # Extract it into numpy arrays. + train_data = extract_data(train_data_filename, 60000) + train_labels = extract_labels(train_labels_filename, 60000) + test_data = extract_data(test_data_filename, 10000) + test_labels = extract_labels(test_labels_filename, 10000) + + # Generate a validation set. + validation_data = train_data[:VALIDATION_SIZE, :, :, :] + validation_labels = train_labels[:VALIDATION_SIZE] + train_data = train_data[VALIDATION_SIZE:, :, :, :] + train_labels = train_labels[VALIDATION_SIZE:] + num_epochs = NUM_EPOCHS + train_size = train_labels.shape[0] + + # This is where training samples and labels are fed to the graph. + # These placeholder nodes will be fed a batch of training data at each + # training step using the {feed_dict} argument to the Run() call below. + train_data_node = tf.placeholder( + tf.float32, + shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) + train_labels_node = tf.placeholder(tf.float32, + shape=(BATCH_SIZE, NUM_LABELS)) + # For the validation and test data, we'll just hold the entire dataset in + # one constant node. + validation_data_node = tf.constant(validation_data) + test_data_node = tf.constant(test_data) + + # The variables below hold all the trainable weights. They are passed an + # initial value which will be assigned when when we call: + # {tf.initialize_all_variables().run()} + conv1_weights = tf.Variable( + tf.truncated_normal([5, 5, NUM_CHANNELS, 32], # 5x5 filter, depth 32. + stddev=0.1, + seed=SEED)) + conv1_biases = tf.Variable(tf.zeros([32])) + conv2_weights = tf.Variable( + tf.truncated_normal([5, 5, 32, 64], + stddev=0.1, + seed=SEED)) + conv2_biases = tf.Variable(tf.constant(0.1, shape=[64])) + fc1_weights = tf.Variable( # fully connected, depth 512. + tf.truncated_normal([IMAGE_SIZE / 4 * IMAGE_SIZE / 4 * 64, 512], + stddev=0.1, + seed=SEED)) + fc1_biases = tf.Variable(tf.constant(0.1, shape=[512])) + fc2_weights = tf.Variable( + tf.truncated_normal([512, NUM_LABELS], + stddev=0.1, + seed=SEED)) + fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS])) + + # We will replicate the model structure for the training subgraph, as well + # as the evaluation subgraphs, while sharing the trainable parameters. + def model(data, train=False): + """The Model definition.""" + # 2D convolution, with 'SAME' padding (i.e. the output feature map has + # the same size as the input). Note that {strides} is a 4D array whose + # shape matches the data layout: [image index, y, x, depth]. + conv = tf.nn.conv2d(data, + conv1_weights, + strides=[1, 1, 1, 1], + padding='SAME') + # Bias and rectified linear non-linearity. + relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases)) + # Max pooling. The kernel size spec {ksize} also follows the layout of + # the data. Here we have a pooling window of 2, and a stride of 2. + pool = tf.nn.max_pool(relu, + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding='SAME') + conv = tf.nn.conv2d(pool, + conv2_weights, + strides=[1, 1, 1, 1], + padding='SAME') + relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases)) + pool = tf.nn.max_pool(relu, + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding='SAME') + # Reshape the feature map cuboid into a 2D matrix to feed it to the + # fully connected layers. + pool_shape = pool.get_shape().as_list() + reshape = tf.reshape( + pool, + [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]]) + # Fully connected layer. Note that the '+' operation automatically + # broadcasts the biases. + hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases) + # Add a 50% dropout during training only. Dropout also scales + # activations such that no rescaling is needed at evaluation time. + if train: + hidden = tf.nn.dropout(hidden, 0.5, seed=SEED) + return tf.matmul(hidden, fc2_weights) + fc2_biases + + # Training computation: logits + cross-entropy loss. + logits = model(train_data_node, True) + loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( + logits, train_labels_node)) + + # L2 regularization for the fully connected parameters. + regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) + + tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases)) + # Add the regularization term to the loss. + loss += 5e-4 * regularizers + + # Optimizer: set up a variable that's incremented once per batch and + # controls the learning rate decay. + batch = tf.Variable(0) + # Decay once per epoch, using an exponential schedule starting at 0.01. + learning_rate = tf.train.exponential_decay( + 0.01, # Base learning rate. + batch * BATCH_SIZE, # Current index into the dataset. + train_size, # Decay step. + 0.95, # Decay rate. + staircase=True) + # Use simple momentum for the optimization. + optimizer = tf.train.MomentumOptimizer(learning_rate, + 0.9).minimize(loss, + global_step=batch) + + # Predictions for the minibatch, validation set and test set. + train_prediction = tf.nn.softmax(logits) + # We'll compute them only once in a while by calling their {eval()} method. + validation_prediction = tf.nn.softmax(model(validation_data_node)) + test_prediction = tf.nn.softmax(model(test_data_node)) + + # Create a local session to run this computation. + with tf.Session() as s: + # Run all the initializers to prepare the trainable parameters. + tf.initialize_all_variables().run() + print 'Initialized!' + # Loop through training steps. + for step in xrange(int(num_epochs * train_size / BATCH_SIZE)): + # Compute the offset of the current minibatch in the data. + # Note that we could use better randomization across epochs. + offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE) + batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :] + batch_labels = train_labels[offset:(offset + BATCH_SIZE)] + # This dictionary maps the batch data (as a numpy array) to the + # node in the graph is should be fed to. + feed_dict = {train_data_node: batch_data, + train_labels_node: batch_labels} + # Run the graph and fetch some of the nodes. + _, l, lr, predictions = s.run( + [optimizer, loss, learning_rate, train_prediction], + feed_dict=feed_dict) + if step % 100 == 0: + print 'Epoch %.2f' % (float(step) * BATCH_SIZE / train_size) + print 'Minibatch loss: %.3f, learning rate: %.6f' % (l, lr) + print 'Minibatch error: %.1f%%' % error_rate(predictions, + batch_labels) + print 'Validation error: %.1f%%' % error_rate( + validation_prediction.eval(), validation_labels) + sys.stdout.flush() + # Finally print the result! + test_error = error_rate(test_prediction.eval(), test_labels) + print 'Test error: %.1f%%' % test_error + if FLAGS.self_test: + print 'test_error', test_error + assert test_error == 0.0, 'expected 0.0 test_error, got %.2f' % ( + test_error,) + + +if __name__ == '__main__': + tf.app.run() |