path: root/tensorflow/g3doc/tutorials
author    Vijay Vasudevan <vrv@google.com>  2015-12-06 14:53:28 -0800
committer Vijay Vasudevan <vrv@google.com>  2015-12-06 14:53:28 -0800
commit  f9d3e9d03c69bfac77a2fe1ad80f7c5aa517e0f0 (patch)
tree    52302a06eae969c8f4e1d7af6749a85fe0ac4eb1 /tensorflow/g3doc/tutorials
parent  40d0d2904e8e00d3c4bf43fa62130eeebceef147 (diff)
TensorFlow: upstream latest changes to git.
Change 109537918: TensorFlow pip setup: wheel >= 0.26 for python3 pip install.
Change 109505848: Fix distortion default value to 1.0 in fixed_unigram_candidate_sampler. This means we default to the actual provided unigram distribution, instead of to the uniform (as it is currently).
Change 109470494: Bugfix in gradients calculation when the ys rely on each other.
Change 109467619: Fix CIFAR-10 model to train on all the training data instead of just 80% of it. Fixes #396.
Change 109467557: Replaced checkpoint file with binary GraphDef.
Change 109467433: Updates to C++ tutorial section.
Change 109465269: TensorFlow: update documentation for tutorials to not assume use of bazel (when possible).
Change 109462916: A tutorial for image recognition to coincide with the release of the latest Inception image classification model.
Change 109462342: Clear control dependencies in variable_scope.get_variable() when creating ops for the initializer. Add tests of various error conditions.
Change 109461981: Various performance improvements in low-level node execution code paths. Speeds up ptb_word_lm on my desktop with a Titan X from 3638 words per second to 3751 words per second (3.1% speedup). Changes include:
  o Avoided many strcmp operations per node execution and extra touches of cache lines in executor.cc, by making all the various IsMerge, IsSwitch, IsSend, etc. operations instead be based on an internal enum value that is pre-computed at Node construction time, rather than doing string comparisons against node->type_string(). We were doing about 6 such comparisons per executed node.
  o Removed mutex_lock in executor.cc in ExecutorState::Process. The lock was not needed and the comment about the iterations array being potentially resized is not true (the iterations arrays are created with a fixed size). Checked with yuanbyu to confirm this.
  o Added new two-argument port::Tracing::ScopedAnnotation constructor that takes two StringPiece arguments, and only concatenates them lazily if tracing is enabled. Also changed the code in platform/tracing.{h,cc} so that the ScopedAnnotation constructor and the TraceMe constructor can be inlined.
  o In BaseGPUDevice::Compute, used the two-argument ScopedAnnotation constructor to avoid doing StrCat(op_kernel->name(), ":", op_kernel->type_string()) on every node execution on a GPU.
  o Introduced a new TensorReference class that just holds a reference to an underlying TensorBuffer, and requires an explicit Unref().
  o Changed the EventMgr interface to take a vector of TensorReference objects for EventMgr::ThenDeleteTensors, rather than a vector of Tensor objects.
  o Used TensorReference in a few places in gpu_util.cc.
  o Minor: switched to using InlinedVectors in a few places to get better cache locality.
Change 109456692: Updated the label_image example to use the latest Inception model.
Change 109456545: Provides classify_image which performs image recognition on a 1000 object label set.
  $ ./classify_image
  giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.88493)
  indri, indris, Indri indri, Indri brevicaudatus (score = 0.00878)
  lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00317)
  custard apple (score = 0.00149)
  earthstar (score = 0.00127)
Change 109455002: TensorFlow: make the helper libraries for various models available in the pip package so that when users type: python translate.py ... the absolute import works. This change is supposed to help make our tutorials run without the *need* to use bazel.
Change 109450041: TensorFlow: remove cifar and convolutional binary copies from pip install. Adds embedding and some other models to the list.
Change 109448520: Move the description of a failing invariant from a comment into the dcheck-fail message text.
Change 109447577: TensorBoard has release tagging (tensorboard/TAG). Also track TensorBoard changes (tensorboard/CHANGES).
Change 109444161: Added ParseSingleSequenceExample + python wrappers + unit tests.
Change 109440864: Update all the TensorFlow Dockerfiles, and simplify GPU containers. This change updates all four of our Dockerfiles to match the targets discussed in https://github.com/tensorflow/tensorflow/issues/149. The most notable change here is moving the GPU images to use the NVidia containers which include cudnn and other build-time dependencies, dramatically simplifying both the build and run steps. A description of which tags exist and get pushed where will be in a follow-up.
Change 109432591: Some pylint and pydoc changes in saver.
Change 109430127: Remove unused hydrogen components.
Change 109419354: The RNN api, although moved into python/ops/, remains undocumented. It may still change at any time.

Base CL: 109538006
Diffstat (limited to 'tensorflow/g3doc/tutorials')
-rw-r--r--  tensorflow/g3doc/tutorials/image_recognition/index.md  | 461
-rw-r--r--  tensorflow/g3doc/tutorials/index.md                    |  10
-rw-r--r--  tensorflow/g3doc/tutorials/recurrent/index.md          |  24
-rw-r--r--  tensorflow/g3doc/tutorials/seq2seq/index.md            |  17
4 files changed, 486 insertions, 26 deletions
diff --git a/tensorflow/g3doc/tutorials/image_recognition/index.md b/tensorflow/g3doc/tutorials/image_recognition/index.md
new file mode 100644
index 0000000000..159a845fd5
--- /dev/null
+++ b/tensorflow/g3doc/tutorials/image_recognition/index.md
@@ -0,0 +1,461 @@
+# Image Recognition
+
+Our brains make vision seem easy. It doesn't take any effort for humans to
+tell apart a lion and a jaguar, read a sign, or recognize a human's face.
+But these are actually hard problems to solve with a computer: they only
+seem easy because our brains are incredibly good at understanding images.
+
+In the last few years, we've made tremendous progress on solving these difficult
+problems with computers. We've found that a kind of model called a deep
+[convolutional neural network](http://colah.github.io/posts/2014-07-Conv-Nets-Modular/)
+can achieve remarkable performance on hard visual recognition tasks --
+matching or exceeding human performance on some problems.
+
+Researchers at Google have developed a succession of models, repeatedly
+breaking records and setting new state-of-the-art results in computer vision:
+[QuocNet], [AlexNet], [Inception (GoogLeNet)], [BN-Inception-v2] and now
+[Inception-v3]. We've published papers describing all these models, but
+they're still hard to reproduce. We're now taking things a step further by
+releasing our latest model, Inception-v3.
+
+[QuocNet]: http://static.googleusercontent.com/media/research.google.com/en//archive/unsupervised_icml2012.pdf
+[AlexNet]: http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
+[Inception (GoogLeNet)]: http://arxiv.org/abs/1409.4842
+[BN-Inception-v2]: http://arxiv.org/abs/1502.03167
+[Inception-v3]: http://arxiv.org/abs/1512.00567
+
+Inception-v3 is trained for the [ImageNet] Large Visual Recognition Challenge
+using the data from 2012. This is a standard task in computer vision,
+where models try to classify entire
+images into [1000 classes], like "Zebra", "Dalmatian", and "Dishwasher".
+For example, here are the results from [AlexNet] classifying some images:
+
+<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../../images/AlexClassification.png">
+</div>
+
+To compare models, we examine how often the model fails to predict the
+correct answer as one of its top 5 guesses -- termed the "top-5 error rate".
+[AlexNet] achieved a top-5 error rate of 15.3% on the 2012
+validation data set; [BN-Inception-v2] achieved 6.66%;
+[Inception-v3] reaches 3.46%.
+
+> How well do humans do on the ImageNet Challenge? There's a [blog post] by
+Andrej Karpathy, who attempted to measure his own performance. He reached a
+5.1% top-5 error rate.
+
+[ImageNet]: http://image-net.org/
+[1000 classes]: http://image-net.org/challenges/LSVRC/2014/browse-synsets
+[blog post]: http://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/
+
+This tutorial will teach you how to use [Inception-v3]. You'll learn how to
+classify images into [1000 classes] in Python or C++, and how to run the
+model on mobile devices. You'll also learn how to extract higher-level
+features from this model, which may be reused for other vision tasks.
+
+We're excited to see what the community will do with this model.
+
+
+## Usage with Python API
+
+`classify_image.py` downloads the trained model from `tensorflow.org`
+when the program is run for the first time. You'll need about 200MB of free
+space available on your hard disk.
+
+The following instructions assume you have installed TensorFlow from a pip
+package and that your terminal's current directory is the TensorFlow root
+directory.
+
+    cd tensorflow/models/image/imagenet
+    python classify_image.py
+
+The above command will classify a supplied image of a panda bear.
+
+<div style="width:15%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/cropped_panda.jpg">
+</div>
+
+If the model runs correctly, the script will produce the following output:
+
+    giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.88493)
+    indri, indris, Indri indri, Indri brevicaudatus (score = 0.00878)
+    lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00317)
+    custard apple (score = 0.00149)
+    earthstar (score = 0.00127)
+
+If you wish to supply other JPEG images, you may do so by passing the
+`--image_file` argument.
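+
+For example, to classify a different image (the path `/tmp/my_image.jpg` here
+is just a hypothetical placeholder):
+
+    python classify_image.py --image_file=/tmp/my_image.jpg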
+
+> If you download the model data to a different directory, you
+will need to point `--model_dir` to the directory used.
+
+## Usage with the C++ API
+
+You can run the same [Inception-v3] model in C++. For this use the model is
+packaged in a slightly more compact file, because we don't need to keep data
+that's only used for training. You can download the archive containing the
+GraphDef that defines the model like this (running from the root directory of
+the TensorFlow repository):
+
+```bash
+wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip -O tensorflow/examples/label_image/data/inception_dec_2015.zip
+
+unzip tensorflow/examples/label_image/data/inception_dec_2015.zip -d tensorflow/examples/label_image/data/
+```
+
+Next, we need to compile the C++ binary that includes the code to load and run the graph.
+If you've followed [the instructions to download the source installation of
+TensorFlow](http://www.tensorflow.org/versions/master/get_started/os_setup.html#source)
+for your platform, you should be able to build the example by
+running this command from your shell terminal:
+
+```bash
+bazel build tensorflow/examples/label_image/...
+```
+
+That should create a binary executable that you can then run like this:
+
+```bash
+bazel-bin/tensorflow/examples/label_image/label_image
+```
+
+This uses the default example image that ships with the framework, and should
+output something similar to this:
+
+```
+I tensorflow/examples/label_image/main.cc:200] military uniform (866): 0.647296
+I tensorflow/examples/label_image/main.cc:200] suit (794): 0.0477196
+I tensorflow/examples/label_image/main.cc:200] academic gown (896): 0.0232411
+I tensorflow/examples/label_image/main.cc:200] bow tie (817): 0.0157356
+I tensorflow/examples/label_image/main.cc:200] bolo tie (940): 0.0145024
+```
+In this case, we're using the default image of
+[Admiral Grace Hopper](https://en.wikipedia.org/wiki/Grace_Hopper), and you can
+see the network correctly identifies that she's wearing a military uniform,
+with a high score of 0.6.
+
+
+<div style="width:45%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/grace_hopper.jpg">
+</div>
+
+Next, try it out on your own images by supplying the `--image=` argument, e.g.
+
+```bash
+bazel-bin/tensorflow/examples/label_image/label_image --image=my_image.png
+```
+
+If you look inside the [`tensorflow/examples/label_image/main.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc)
+file, you can find out
+how it works. We hope this code will help you integrate TensorFlow into
+your own applications, so we will walk step by step through the main functions:
+
+The command-line flags control where the files are loaded from and set
+properties of the input images.
+The model expects to get square 299x299 RGB images, so those are the
+`input_width` and `input_height` flags. We also need to scale the pixel values
+from integers between 0 and 255 to the floating-point values that the graph
+operates on. We control the scaling with the `input_mean` and `input_std`
+flags: we first subtract `input_mean` from each pixel value, then divide it by
+`input_std`.
+
+These values probably look somewhat magical, but they are just defined by the
+original model author based on what they wanted to use as input images for
+training. If you have a graph that you've trained yourself, you'll just need
+to adjust the values to match whatever you used during your training process.
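+
+As a concrete sketch of the arithmetic (the values `input_mean = 128` and
+`input_std = 128` below are assumptions for illustration; check the flag
+defaults in `main.cc` for what this model actually uses), the scaling maps the
+0-255 pixel range roughly onto -1 to +1:
+
+```C++
+// Standalone illustration of the per-pixel scaling; not part of main.cc.
+// Assumes input_mean = 128 and input_std = 128 purely for demonstration.
+#include <cstdio>
+#include <initializer_list>
+
+int main() {
+  const float input_mean = 128.0f;  // assumed value
+  const float input_std = 128.0f;   // assumed value
+  for (int pixel : {0, 128, 255}) {
+    const float normalized = (pixel - input_mean) / input_std;
+    std::printf("%3d -> %+.3f\n", pixel, normalized);  // 0 -> -1.000, 255 -> +0.992
+  }
+  return 0;
+}
+```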
+
+You can see how they're applied to an image in the
+[`ReadTensorFromImageFile()`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc#L88)
+function.
+
+```C++
+// Given an image file name, read in the data, try to decode it as an image,
+// resize it to the requested size, and then scale the values as desired.
+Status ReadTensorFromImageFile(string file_name, const int input_height,
+ const int input_width, const float input_mean,
+ const float input_std,
+ std::vector<Tensor>* out_tensors) {
+ tensorflow::GraphDefBuilder b;
+```
+We start by creating a `GraphDefBuilder`, which is an object we can use to
+specify a model to run or load.
+
+```C++
+ string input_name = "file_reader";
+ string output_name = "normalized";
+ tensorflow::Node* file_reader =
+ tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
+ b.opts().WithName(input_name));
+```
+We then start creating nodes for the small model we want to run, which will
+load, resize, and scale the pixel values to produce the result the main model
+expects as its input. The first node we create is just a `Const` op that holds
+a tensor with the file name of the image we want to load. That's then passed
+as the first input to the `ReadFile` op. You might notice we're passing
+`b.opts()` as the last argument to all the op creation functions. The argument
+ensures that the node is added to the model definition held in the
+`GraphDefBuilder`. We also name the `ReadFile` operator by making the
+`WithName()` call on `b.opts()`. This gives a name to the node, which isn't
+strictly necessary, since an automatic name will be assigned if you don't do
+it, but it does make debugging a bit easier.
+
+```C++
+ // Now try to figure out what kind of file it is and decode it.
+ const int wanted_channels = 3;
+ tensorflow::Node* image_reader;
+ if (tensorflow::StringPiece(file_name).ends_with(".png")) {
+ image_reader = tensorflow::ops::DecodePng(
+ file_reader,
+ b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
+ } else {
+ // Assume if it's not a PNG then it must be a JPEG.
+ image_reader = tensorflow::ops::DecodeJpeg(
+ file_reader,
+ b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
+ }
+ // Now cast the image data to float so we can do normal math on it.
+ tensorflow::Node* float_caster = tensorflow::ops::Cast(
+ image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
+ // The convention for image ops in TensorFlow is that all images are expected
+ // to be in batches, so that they're four-dimensional arrays with indices of
+ // [batch, height, width, channel]. Because we only have a single image, we
+ // have to add a batch dimension of 1 to the start with ExpandDims().
+ tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
+ float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
+ // Bilinearly resize the image to fit the required dimensions.
+ tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
+ dims_expander, tensorflow::ops::Const({input_height, input_width},
+ b.opts().WithName("size")),
+ b.opts());
+ // Subtract the mean and divide by the scale.
+ tensorflow::ops::Div(
+ tensorflow::ops::Sub(
+ resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
+ tensorflow::ops::Const({input_std}, b.opts()),
+ b.opts().WithName(output_name));
+```
+We then keep adding more nodes, to decode the file data as an image, to cast the
+integers into floating point values, to resize it, and then finally to run the
+subtraction and division operations on the pixel values.
+
+```C++
+ // This runs the GraphDef network definition that we've just constructed, and
+ // returns the results in the output tensor.
+ tensorflow::GraphDef graph;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
+```
+At the end of this we have
+a model definition stored in the `b` variable, which we turn into a full graph
+definition with the `ToGraphDef()` function.
+
+```C++
+ std::unique_ptr<tensorflow::Session> session(
+ tensorflow::NewSession(tensorflow::SessionOptions()));
+ TF_RETURN_IF_ERROR(session->Create(graph));
+ TF_RETURN_IF_ERROR(session->Run({}, {output_name}, {}, out_tensors));
+ return Status::OK();
+```
+Then we create a [`Session`](http://www.tensorflow.org/versions/master/api_docs/cc/ClassSession.html#class-tensorflow-session)
+object, which is the interface to actually running the graph, and run it,
+specifying which node we want to get the output from, and where to put the
+output data.
+
+This gives us a vector of `Tensor` objects, which in this case we know will be
+only a single object long. You can think of a `Tensor` as a multi-dimensional
+array in this context; it holds a 299-pixel-high, 299-pixel-wide, 3-channel
+image as float values. If you have your own image-processing framework in your
+product already, you should be able to use that instead, as long as you apply
+the same transformations before you feed images into the main graph.
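+
+As a rough sketch of how you could inspect the result (assuming the
+`[batch, height, width, channel]` shape described above; `resized_tensors` is
+the output vector that `ReadTensorFromImageFile()` fills in):
+
+```C++
+// Hypothetical inspection of the returned image tensor, for illustration only.
+const tensorflow::Tensor& image = resized_tensors[0];
+// tensor<float, 4>() gives an Eigen view indexed as (batch, y, x, channel).
+auto pixels = image.tensor<float, 4>();
+// Read the red channel of the top-left pixel of the first (and only) image.
+const float top_left_red = pixels(0, 0, 0, 0);
+LOG(INFO) << "Top-left red value: " << top_left_red;
+```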
+
+This is a simple example of creating a small TensorFlow graph dynamically in C++,
+but for the pre-trained Inception model we want to load a much larger definition from
+a file. You can see how we do that in the `LoadGraph()` function.
+
+```C++
+// Reads a model graph definition from disk, and creates a session object you
+// can use to run it.
+Status LoadGraph(string graph_file_name,
+ std::unique_ptr<tensorflow::Session>* session) {
+ tensorflow::GraphDef graph_def;
+ Status load_graph_status =
+ ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
+ if (!load_graph_status.ok()) {
+ return tensorflow::errors::NotFound("Failed to load compute graph at '",
+ graph_file_name, "'");
+ }
+```
+If you've looked through the image loading code, a lot of the terms should seem familiar. Rather than
+using a `GraphDefBuilder` to produce a `GraphDef` object, we load a protobuf file that
+directly contains the `GraphDef`.
+
+```C++
+ session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
+ Status session_create_status = (*session)->Create(graph_def);
+ if (!session_create_status.ok()) {
+ return session_create_status;
+ }
+ return Status::OK();
+}
+```
+Then we create a `Session` object from that `GraphDef` and
+pass it back to the caller so that they can run it at a later time.
+
+The `GetTopLabels()` function is a lot like the image loading, except that in
+this case we want to take the results of running the main graph and turn them
+into a sorted list of the highest-scoring labels. Just like the image loader,
+it creates a `GraphDefBuilder`, adds a couple of nodes to it, and then runs the
+short graph to get a pair of output tensors. In this case they represent the
+sorted scores and index positions of the highest results.
+
+```C++
+// Analyzes the output of the Inception graph to retrieve the highest scores and
+// their positions in the tensor, which correspond to categories.
+Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
+ Tensor* indices, Tensor* scores) {
+ tensorflow::GraphDefBuilder b;
+ string output_name = "top_k";
+ tensorflow::ops::TopK(tensorflow::ops::Const(outputs[0], b.opts()),
+ how_many_labels, b.opts().WithName(output_name));
+ // This runs the GraphDef network definition that we've just constructed, and
+ // returns the results in the output tensors.
+ tensorflow::GraphDef graph;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
+ std::unique_ptr<tensorflow::Session> session(
+ tensorflow::NewSession(tensorflow::SessionOptions()));
+ TF_RETURN_IF_ERROR(session->Create(graph));
+ // The TopK node returns two outputs, the scores and their original indices,
+ // so we have to append :0 and :1 to specify them both.
+ std::vector<Tensor> out_tensors;
+ TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
+ {}, &out_tensors));
+ *scores = out_tensors[0];
+ *indices = out_tensors[1];
+  return Status::OK();
+}
+```
+The `PrintTopLabels()` function takes those sorted results, and prints them out in a
+friendly way. The `CheckTopLabel()` function is very similar, but just makes sure that
+the top label is the one we expect, for debugging purposes.
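+
+The actual implementations live in
+[`main.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc);
+as a rough sketch of the shape of such a printing function (assuming a
+plain-text labels file with one human-readable label per line, and the
+`GetTopLabels()` helper shown above):
+
+```C++
+// Sketch only -- not the exact code from main.cc. Prints the top-scoring
+// labels from the Inception output, one per line, with their scores.
+Status PrintTopLabels(const std::vector<Tensor>& outputs,
+                      string labels_file_name) {
+  string labels_data;
+  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(
+      tensorflow::Env::Default(), labels_file_name, &labels_data));
+  std::vector<string> labels = tensorflow::str_util::Split(labels_data, '\n');
+  const int how_many_labels = 5;
+  Tensor indices;
+  Tensor scores;
+  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
+  auto scores_flat = scores.flat<float>();
+  auto indices_flat = indices.flat<tensorflow::int32>();
+  for (int pos = 0; pos < how_many_labels; ++pos) {
+    const int label_index = indices_flat(pos);
+    LOG(INFO) << labels[label_index] << " (" << label_index
+              << "): " << scores_flat(pos);
+  }
+  return Status::OK();
+}
+```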
+
+At the end, [`main()`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc#L252)
+ties together all of these calls.
+
+```C++
+int main(int argc, char* argv[]) {
+ // We need to call this to set up global state for TensorFlow.
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
+ if (!s.ok()) {
+ LOG(ERROR) << "Error parsing command line flags: " << s.ToString();
+ return -1;
+ }
+
+ // First we load and initialize the model.
+ std::unique_ptr<tensorflow::Session> session;
+ string graph_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_graph);
+ Status load_graph_status = LoadGraph(graph_path, &session);
+ if (!load_graph_status.ok()) {
+ LOG(ERROR) << load_graph_status;
+ return -1;
+ }
+```
+We load the main graph.
+
+```C++
+ // Get the image from disk as a float array of numbers, resized and normalized
+ // to the specifications the main graph expects.
+ std::vector<Tensor> resized_tensors;
+ string image_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_image);
+ Status read_tensor_status = ReadTensorFromImageFile(
+ image_path, FLAGS_input_height, FLAGS_input_width, FLAGS_input_mean,
+ FLAGS_input_std, &resized_tensors);
+ if (!read_tensor_status.ok()) {
+ LOG(ERROR) << read_tensor_status;
+ return -1;
+ }
+ const Tensor& resized_tensor = resized_tensors[0];
+```
+Load, resize, and process the input image.
+
+```C++
+ // Actually run the image through the model.
+ std::vector<Tensor> outputs;
+ Status run_status = session->Run({{FLAGS_input_layer, resized_tensor}},
+ {FLAGS_output_layer}, {}, &outputs);
+ if (!run_status.ok()) {
+ LOG(ERROR) << "Running model failed: " << run_status;
+ return -1;
+ }
+```
+Here we run the loaded graph with the image as an input.
+
+```C++
+ // This is for automated testing to make sure we get the expected result with
+ // the default settings. We know that label 866 (military uniform) should be
+ // the top label for the Admiral Hopper image.
+ if (FLAGS_self_test) {
+ bool expected_matches;
+ Status check_status = CheckTopLabel(outputs, 866, &expected_matches);
+ if (!check_status.ok()) {
+ LOG(ERROR) << "Running check failed: " << check_status;
+ return -1;
+ }
+ if (!expected_matches) {
+ LOG(ERROR) << "Self-test failed!";
+ return -1;
+ }
+ }
+```
+For testing purposes we can check to make sure we get the output we expect here.
+
+```C++
+ // Do something interesting with the results we've generated.
+ Status print_status = PrintTopLabels(outputs, FLAGS_labels);
+```
+Finally we print the labels we found.
+
+```C++
+ if (!print_status.ok()) {
+ LOG(ERROR) << "Running print failed: " << print_status;
+ return -1;
+ }
+```
+
+The error handling here uses TensorFlow's `Status`
+object, which is very convenient because it lets you know whether any error
+has occurred with the `ok()` checker, and can then be printed out to give a
+readable error message.
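+
+As a minimal sketch of the pattern (the `Status` type and `ok()` are real;
+`DoSomething()` is a hypothetical function used only for illustration):
+
+```C++
+// Minimal sketch of the Status-checking pattern used throughout this example.
+tensorflow::Status status = DoSomething();  // DoSomething() is hypothetical
+if (!status.ok()) {
+  // Streaming a Status object prints a readable error message.
+  LOG(ERROR) << "DoSomething failed: " << status;
+  return -1;
+}
+```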
+
+In this case we are demonstrating object recognition, but you should be able
+to use very similar code on other models you've found or trained yourself,
+across all sorts of domains. We hope this small example gives you some ideas
+on how to use TensorFlow within your own products.
+
+> **EXERCISE**: Transfer learning is the idea that, if you know how to solve
+a task well, you should be able to transfer some of that understanding to
+solving related problems. One way to perform transfer learning is to remove
+the final classification layer of the network and extract the
+[next-to-last layer of the CNN](http://arxiv.org/abs/1310.1531), in this case
+a 2048-dimensional vector. You can specify this by setting
+`--output_layer=pool_3` in the [C++ API example](#usage-with-the-c-api) and
+then changing the output tensor handling. Try extracting this feature on a set
+of images and see if you can predict new categories not in ImageNet.
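+
+For example, the feature-extraction run might look like this (the flag name
+comes from `main.cc`; as noted above, you'd also need to change how the output
+tensor is handled, since the `pool_3` output is a feature vector rather than
+per-label scores):
+
+```bash
+bazel-bin/tensorflow/examples/label_image/label_image --image=my_image.png \
+  --output_layer=pool_3
+```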
+
+
+## Resources for Learning More
+
+To learn about neural networks in general, Michael Nielsen's
+[free online book](http://neuralnetworksanddeeplearning.com/chap1.html)
+is an excellent resource. For convolutional neural networks in particular,
+Chris Olah has some
+[nice blog posts](http://colah.github.io/posts/2014-07-Conv-Nets-Modular/),
+and Michael Nielsen's book has a
+[great chapter](http://neuralnetworksanddeeplearning.com/chap6.html)
+covering them.
+
+To find out more about implementing convolutional neural networks, you can jump to
+the TensorFlow [deep convolutional networks tutorial](http://www.tensorflow.org/tutorials/deep_cnn/index.html),
+or start a bit more gently with our
+[ML beginner](http://www.tensorflow.org/tutorials/mnist/beginners/index.html)
+or [ML expert](http://www.tensorflow.org/tutorials/mnist/pros/index.html)
+MNIST starter tutorials. Finally, if you want to get up to speed on research
+in this area, you can read the recent papers referenced in this tutorial.
+
diff --git a/tensorflow/g3doc/tutorials/index.md b/tensorflow/g3doc/tutorials/index.md
index ee22806f69..a3d925efeb 100644
--- a/tensorflow/g3doc/tutorials/index.md
+++ b/tensorflow/g3doc/tutorials/index.md
@@ -90,10 +90,14 @@ stuff.
[View Tutorial](../tutorials/mnist/download/index.md)
-## Visual Object Recognition
+## Image Recognition
-We will be releasing our state-of-the-art Inception object recognition model,
-complete and already trained.
+How to run object recognition using a convolutional neural network
+trained on the ImageNet Challenge data and label set.
+
+[View Tutorial](../tutorials/image_recognition/index.md)
+
+We will be releasing code for training a state-of-the-art Inception model.
COMING SOON
diff --git a/tensorflow/g3doc/tutorials/recurrent/index.md b/tensorflow/g3doc/tutorials/recurrent/index.md
index dfba27a03f..f88a266542 100644
--- a/tensorflow/g3doc/tutorials/recurrent/index.md
+++ b/tensorflow/g3doc/tutorials/recurrent/index.md
@@ -168,25 +168,17 @@ for i in range(len(num_steps)):
final_state = state
```
-## Compile and Run the Code
+## Run the Code
-First, the library needs to be built. To compile it on CPU:
+We assume you have already installed TensorFlow via the pip package, have
+cloned the tensorflow git repository, and are in the root of the git tree.
+(If building from source, build the `tensorflow/models/rnn/ptb:ptb_word_lm`
+target using bazel.)
+Next:
```
-bazel build -c opt tensorflow/models/rnn/ptb:ptb_word_lm
-```
-
-And if you have a fast GPU, run the following:
-
-```
-bazel build -c opt --config=cuda tensorflow/models/rnn/ptb:ptb_word_lm
-```
-
-Now we can run the model:
-
-```
-bazel-bin/tensorflow/models/rnn/ptb/ptb_word_lm \
- --data_path=/tmp/simple-examples/data/ --model small
+cd tensorflow/models/rnn/ptb
+python ptb_word_lm.py --data_path=/tmp/simple-examples/data/ --model small
```
There are 3 supported model configurations in the tutorial code: "small",
diff --git a/tensorflow/g3doc/tutorials/seq2seq/index.md b/tensorflow/g3doc/tutorials/seq2seq/index.md
index 63542dc6ba..a37068b8b4 100644
--- a/tensorflow/g3doc/tutorials/seq2seq/index.md
+++ b/tensorflow/g3doc/tutorials/seq2seq/index.md
@@ -9,16 +9,19 @@ a neural network to translate from English to French? It turns out that
the answer is *yes*.
This tutorial will show you how to build and train such a system end-to-end.
-You can start by running this binary.
+We assume you have already installed TensorFlow via the pip package, have
+cloned the tensorflow git repository, and are in the root of the git tree.
+
+You can then start by running the translate program:
```
-bazel run -c opt <...>/models/rnn/translate:translate
- --data_dir [your_data_directory]
+cd tensorflow/models/rnn/translate
+python translate.py --data_dir [your_data_directory]
```
It will download English-to-French translation data from the
[WMT'15 Website](http://www.statmt.org/wmt15/translation-task.html)
-prepare it for training and train. It takes about 20GB of disk space,
+prepare it for training, and train. It takes about 20GB of disk space,
and a while to download and prepare (see [later](#run_it) for details),
so you can start and leave it running while reading this tutorial.
@@ -240,7 +243,7 @@ Both data-sets will be downloaded to `data_dir` and training will start,
saving checkpoints in `train_dir`, when this command is run.
```
-bazel run -c opt <...>/models/rnn/translate:translate
+python translate.py
--data_dir [your_data_directory] --train_dir [checkpoints_directory]
--en_vocab_size=40000 --fr_vocab_size=40000
```
@@ -259,7 +262,7 @@ results, but it might take too long or use too much memory for your GPU.
You can request to train a smaller model as in the following example.
```
-bazel run -c opt <...>/models/rnn/translate:translate
+python translate.py
--data_dir [your_data_directory] --train_dir [checkpoints_directory]
--size=256 --num_layers=2 --steps_per_checkpoint=50
```
@@ -296,7 +299,7 @@ point the model can be used for translating English sentences to French
using the `--decode` option.
```
-bazel run -c opt <...>/models/rnn/translate:translate --decode
+python translate.py --decode
--data_dir [your_data_directory] --train_dir [checkpoints_directory]
Reading model parameters from /tmp/translate.ckpt-340000