diff options
author | Illia Polosukhin <ilblackdragon@gmail.com> | 2016-04-20 17:36:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-20 18:41:27 -0700 |
commit | 2f190ff993d45bf98aeb0931b90705c8f6fbb18e (patch) | |
tree | 672b3eaef1ebac3f8c857dd39152094960f67f29 /tensorflow/g3doc/contrib | |
parent | 0eab496110d7efe03d9dcac01b2922b591aa5833 (diff) |
tf.learn: Clean up documentation, remove contrib/learn/g3doc and move it to g3doc/contrib/learn. Add api docs for learn (need more work).
Change: 120399615
Diffstat (limited to 'tensorflow/g3doc/contrib')
-rw-r--r-- | tensorflow/g3doc/contrib/learn/get_started/index.md | 112 | ||||
-rw-r--r-- | tensorflow/g3doc/contrib/learn/index.md | 30 |
2 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/g3doc/contrib/learn/get_started/index.md b/tensorflow/g3doc/contrib/learn/get_started/index.md new file mode 100644 index 0000000000..f34c3456cf --- /dev/null +++ b/tensorflow/g3doc/contrib/learn/get_started/index.md @@ -0,0 +1,112 @@ +# Introduction + +Below are few simple examples of the API to get you started with TensorFlow Learn. +For more examples, please see [examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/skflow). + +## General tips + +- It's useful to re-scale dataset before passing to estimator to 0 mean and unit standard deviation. Stochastic Gradient Descent doesn't always do the right thing when variable are very different scale. + +- Categorical variables should be managed before passing input to the estimator. + +## Linear Classifier + +Simple linear classification: + + from tensorflow.contrib import learn + from sklearn import datasets, metrics + + iris = datasets.load_iris() + classifier = learn.TensorFlowLinearClassifier(n_classes=3) + classifier.fit(iris.data, iris.target) + score = metrics.accuracy_score(iris.target, classifier.predict(iris.data)) + print("Accuracy: %f" % score) + +## Linear Regressor + +Simple linear regression: + + from tensorflow.contrib import learn + from sklearn import datasets, metrics, preprocessing + + boston = datasets.load_boston() + X = preprocessing.StandardScaler().fit_transform(boston.data) + regressor = learn.TensorFlowLinearRegressor() + regressor.fit(X, boston.target) + score = metrics.mean_squared_error(regressor.predict(X), boston.target) + print ("MSE: %f" % score) + +## Deep Neural Network + +Example of 3 layer network with 10, 20 and 10 hidden units respectively: + + from tensorflow.contrib import learn + from sklearn import datasets, metrics + + iris = datasets.load_iris() + classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3) + classifier.fit(iris.data, iris.target) + score = metrics.accuracy_score(iris.target, classifier.predict(iris.data)) + print("Accuracy: %f" % score) + +## Custom model + +Example of how to pass a custom model to the TensorFlowEstimator: + + from tensorflow.contrib import learn + from sklearn import datasets, metrics + + iris = datasets.load_iris() + + def my_model(X, y): + """This is DNN with 10, 20, 10 hidden layers, and dropout of 0.5 probability.""" + layers = learn.ops.dnn(X, [10, 20, 10], keep_prob=0.5) + return learn.models.logistic_regression(layers, y) + + classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3) + classifier.fit(iris.data, iris.target) + score = metrics.accuracy_score(iris.target, classifier.predict(iris.data)) + print("Accuracy: %f" % score) + +## Saving / Restoring models + +Each estimator has a ``save`` method which takes folder path where all model information will be saved. For restoring you can just call ``learn.TensorFlowEstimator.restore(path)`` and it will return object of your class. + +Some example code: + + from tensorflow.contrib import learn + + classifier = learn.TensorFlowLinearRegression() + classifier.fit(...) + classifier.save('/tmp/tf_examples/my_model_1/') + + new_classifier = TensorFlowEstimator.restore('/tmp/tf_examples/my_model_2') + new_classifier.predict(...) + +## Summaries + +To get nice visualizations and summaries you can use ``logdir`` parameter on ``fit``. It will start writing summaries for ``loss`` and histograms for variables in your model. You can also add custom summaries in your custom model function by calling ``tf.summary`` and passing Tensors to report. + + classifier = learn.TensorFlowLinearRegression() + classifier.fit(X, y, logdir='/tmp/tf_examples/my_model_1/') + +Then run next command in command line: + + tensorboard --logdir=/tmp/tf_examples/my_model_1 + +and follow reported url. + +Graph visualization: Text classification RNN Graph image + +Loss visualization: Text classification RNN Loss image + + +## More examples + +See [examples folder](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/skflow) for: + +- Easy way to handle categorical variables - words are just an example of categorical variable. +- Text Classification - see examples for RNN, CNN on word and characters. +- Language modeling and text sequence to sequence. +- Images (CNNs) - see example for digit recognition. +- More & deeper - different examples showing DNNs and CNNs diff --git a/tensorflow/g3doc/contrib/learn/index.md b/tensorflow/g3doc/contrib/learn/index.md new file mode 100644 index 0000000000..7d77dccca1 --- /dev/null +++ b/tensorflow/g3doc/contrib/learn/index.md @@ -0,0 +1,30 @@ +# TensorFlow Learn + +This is an API for building learning models with TensorFlow. +This library covers variety of needs from linear models to *Deep Learning* +applications like text and image understanding. + +## Get Started + +[View Introduction](get_started/index.md) + +## Tutorials + +- [Introduction to Scikit Flow and why you want to start learning + TensorFlow](https://medium.com/@ilblackdragon/tensorflow-tutorial-part-1-c559c63c0cb1) +- [DNNs, custom model and Digit recognition + examples](https://medium.com/@ilblackdragon/tensorflow-tutorial-part-2-9ffe47049c92>) +- [Categorical variables: One hot vs Distributed + representation](https://medium.com/@ilblackdragon/tensorflow-tutorial-part-3-c5fc0662bc08>) +- More coming soon. + +## Community + +- Twitter [#skflow](https://twitter.com/search?q=skflow&src=typd>). +- StackOverflow with +[skflow tag](http://stackoverflow.com/questions/tagged/skflow>) +for questions and struggles. +- Github [issues](https://github.com/tensorflow/tensorflow/issues>) +for technical discussions and feature requests. +- [Gitter channel](https://gitter.im/tensorflow/skflow>) +for non-trivial discussions. |