aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--AUTHORS1
-rw-r--r--README.md1
-rw-r--r--RELEASE.md2
-rw-r--r--tensorflow/contrib/learn/python/learn/README.md24
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_test.py29
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py10
-rw-r--r--tensorflow/contrib/makefile/README.md10
-rw-r--r--tensorflow/core/client/tensor_c_api.cc2
-rw-r--r--tensorflow/core/public/tensor_c_api.h2
-rw-r--r--tensorflow/examples/skflow/digits.py4
-rw-r--r--tensorflow/examples/skflow/dnn_autoencoder_iris.py2
-rw-r--r--tensorflow/examples/skflow/hdf5_classification.py2
-rw-r--r--tensorflow/examples/skflow/iris_custom_model.py2
-rw-r--r--tensorflow/examples/skflow/iris_run_config.py2
-rw-r--r--tensorflow/examples/skflow/iris_val_based_early_stopping.py2
-rw-r--r--tensorflow/examples/skflow/iris_with_pipeline.py2
-rw-r--r--tensorflow/examples/skflow/language_model.py2
-rw-r--r--tensorflow/examples/skflow/mnist_rnn.py2
-rw-r--r--tensorflow/examples/skflow/mnist_weights.py2
-rw-r--r--tensorflow/examples/skflow/multioutput_regression.py2
-rw-r--r--tensorflow/examples/skflow/multiple_gpu.py2
-rw-r--r--tensorflow/examples/skflow/neural_translation.py2
-rw-r--r--tensorflow/examples/skflow/neural_translation_word.py2
-rw-r--r--tensorflow/examples/skflow/out_of_core_data_classification.py2
-rw-r--r--tensorflow/examples/skflow/text_classification.py2
-rw-r--r--tensorflow/examples/skflow/text_classification_builtin_rnn_model.py2
-rw-r--r--tensorflow/examples/skflow/text_classification_character_cnn.py2
-rw-r--r--tensorflow/examples/skflow/text_classification_character_rnn.py2
-rw-r--r--tensorflow/examples/skflow/text_classification_cnn.py2
-rw-r--r--tensorflow/examples/skflow/text_classification_save_restore.py2
-rw-r--r--tensorflow/examples/udacity/README.md17
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md18
-rw-r--r--tensorflow/g3doc/how_tos/quantization/index.md2
-rw-r--r--tensorflow/g3doc/resources/index.md3
-rw-r--r--tensorflow/g3doc/tutorials/mnist/pros/index.md4
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py4
-rw-r--r--tensorflow/python/ops/math_ops.py19
39 files changed, 146 insertions, 58 deletions
diff --git a/AUTHORS b/AUTHORS
index e3289a50bc..a46ae7e616 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -7,3 +7,4 @@
# The email address is not required for organizations.
Google Inc.
+Yuan Tang terrytangyuan@gmail.com
diff --git a/README.md b/README.md
index e640f54774..578b985d64 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,7 @@ Hello, TensorFlow!
* [TensorFlow website](http://tensorflow.org)
* [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf)
+* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity] (https://www.udacity.com/course/deep-learning--ud730)
The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/versions/master/resources#community) for an incomplete list.
diff --git a/RELEASE.md b/RELEASE.md
index 3843d543e9..60d77764c0 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -51,7 +51,7 @@
This release contains contributions from many people at Google, as well as:
-Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
+Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jeremy Barnes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.
diff --git a/tensorflow/contrib/learn/python/learn/README.md b/tensorflow/contrib/learn/python/learn/README.md
index f474eb4e54..2016f53a8a 100644
--- a/tensorflow/contrib/learn/python/learn/README.md
+++ b/tensorflow/contrib/learn/python/learn/README.md
@@ -59,8 +59,8 @@ Simple linear classification:
from sklearn import datasets, metrics
iris = datasets.load_iris()
-classifier = learn.TensorFlowLinearClassifier(n_classes=3)
-classifier.fit(iris.data, iris.target)
+classifier = learn.LinearClassifier(n_classes=3)
+classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
print("Accuracy: %f" % score)
```
@@ -74,8 +74,8 @@ from sklearn import datasets, metrics, preprocessing
boston = datasets.load_boston()
x = preprocessing.StandardScaler().fit_transform(boston.data)
-regressor = learn.TensorFlowLinearRegressor()
-regressor.fit(x, boston.target)
+regressor = learn.LinearRegressor()
+regressor.fit(x, boston.target, steps=200, batch_size=32)
score = metrics.mean_squared_error(regressor.predict(x), boston.target)
print ("MSE: %f" % score)
```
@@ -88,15 +88,15 @@ Example of 3 layer network with 10, 20 and 10 hidden units respectively:
from sklearn import datasets, metrics
iris = datasets.load_iris()
-classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
-classifier.fit(iris.data, iris.target)
+classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
+classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
print("Accuracy: %f" % score)
```
## Custom model
-Example of how to pass a custom model to the TensorFlowEstimator:
+Example of how to pass a custom model to the Estimator:
```python
from sklearn import datasets, metrics
@@ -108,7 +108,7 @@ def my_model(x, y):
layers = learn.ops.dnn(x, [10, 20, 10], dropout=0.5)
return learn.models.logistic_regression(layers, y)
-classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3)
+classifier = learn.Estimator(model_fn=my_model, n_classes=3)
classifier.fit(iris.data, iris.target)
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
print("Accuracy: %f" % score)
@@ -116,16 +116,16 @@ print("Accuracy: %f" % score)
## Saving / Restoring models
-Each estimator has a ``save`` method which takes folder path where all model information will be saved. For restoring you can just call ``learn.TensorFlowEstimator.restore(path)`` and it will return object of your class.
+Each estimator has a ``save`` method which takes folder path where all model information will be saved. For restoring you can just call ``learn.Estimator.restore(path)`` and it will return object of your class.
Some example code:
```python
-classifier = learn.TensorFlowLinearRegression()
+classifier = learn.LinearRegressor()
classifier.fit(...)
classifier.save('/tmp/tf_examples/my_model_1/')
-new_classifier = TensorFlowEstimator.restore('/tmp/tf_examples/my_model_2')
+new_classifier = Estimator.restore('/tmp/tf_examples/my_model_2')
new_classifier.predict(...)
```
@@ -134,7 +134,7 @@ new_classifier.predict(...)
To get nice visualizations and summaries you can use ``logdir`` parameter on ``fit``. It will start writing summaries for ``loss`` and histograms for variables in your model. You can also add custom summaries in your custom model function by calling ``tf.summary`` and passing Tensors to report.
```python
-classifier = learn.TensorFlowLinearRegression()
+classifier = learn.LinearRegressor()
classifier.fit(x, y, logdir='/tmp/tf_examples/my_model_1/')
```
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index fdb598efc5..63d103ed35 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -152,6 +152,11 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
gradient_clip_norm=gradient_clip_norm,
enable_centered_bias=enable_centered_bias,
config=config)
+ self.feature_columns = feature_columns
+ self.optimizer = optimizer
+ self.activation_fn = activation_fn
+ self.dropout = dropout
+ self.hidden_units = hidden_units
self._feature_columns_inferred = False
# TODO(b/29580537): Remove feature_columns inference.
@@ -299,6 +304,11 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
gradient_clip_norm=gradient_clip_norm,
enable_centered_bias=enable_centered_bias,
config=config)
+ self.feature_columns = feature_columns
+ self.optimizer = optimizer
+ self.activation_fn = activation_fn
+ self.dropout = dropout
+ self.hidden_units = hidden_units
self._feature_columns_inferred = False
# TODO(b/29580537): Remove feature_columns inference.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
index ea09f71785..6304d06f55 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
@@ -21,6 +21,13 @@ from __future__ import print_function
import tensorflow as tf
+# pylint: disable=g-import-not-at-top
+try:
+ from sklearn.cross_validation import cross_val_score
+ HAS_SKLEARN = True
+except ImportError:
+ HAS_SKLEARN = False
+
def _iris_input_fn():
iris = tf.contrib.learn.datasets.load_iris()
@@ -59,6 +66,28 @@ class DNNClassifierTest(tf.test.TestCase):
classifier.fit(input_fn=_iris_input_fn, steps=1000)
self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
+ def testSklearnCompatibility(self):
+ """Tests compatibility with sklearn"""
+ if not HAS_SKLEARN:
+ return
+ iris = tf.contrib.learn.datasets.load_iris()
+ kwargs = {
+ "n_classes": 3,
+ "optimizer" : "Adam",
+ "hidden_units" : [3, 4]
+ }
+
+ classifier = tf.contrib.learn.DNNClassifier(**kwargs)
+
+ scores = cross_val_score(
+ classifier,
+ iris.data[1:5],
+ iris.target[1:5],
+ scoring="accuracy",
+ fit_params={"steps": 2}
+ )
+ self.assertAllClose(scores, [1, 1, 1])
+
class DNNRegressorTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py
index be8305e3da..2266caeb2f 100644
--- a/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py
@@ -20,16 +20,24 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-import pandas as pd
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.dataframe import tensorflow_dataframe as df
+# pylint: disable=g-import-not-at-top
+try:
+ import pandas as pd
+ HAS_PANDAS = True
+except ImportError:
+ HAS_PANDAS = False
+
class SumTestCase(tf.test.TestCase):
"""Test class for `Sum` transform."""
def testSum(self):
+ if not HAS_PANDAS:
+ return
num_rows = 100
pandas_df = pd.DataFrame({"a": np.arange(num_rows),
diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md
index ebaacdfcd9..dff9373c10 100644
--- a/tensorflow/contrib/makefile/README.md
+++ b/tensorflow/contrib/makefile/README.md
@@ -61,7 +61,7 @@ On Ubuntu, you can do this:
```bash
sudo apt-get install autoconf automake libtool curl make g++ unzip
pushd .
-cd tensforflow/contrib/makefile/downloads/protobuf
+cd tensorflow/contrib/makefile/downloads/protobuf
./autogen.sh
./configure
make
@@ -104,7 +104,7 @@ tensorflow/contrib/makefile/gen/bin/benchmark \
## Android
First, you will need to download and unzip the
-[Native Development Kit (NDK)](http://developers.google.com/ndk). You will not
+[Native Development Kit (NDK)](https://developer.android.com/ndk/). You will not
need to install the standalone toolchain, however.
Assign your NDK location to $NDK_ROOT:
@@ -153,7 +153,7 @@ For more details, see the [benchmark documentation](../../tools/benchmark).
## iOS
_Note: To use this library in an iOS application, see related instructions in
-the [iOS examples](../ios_examples/] directory._
+the [iOS examples](../ios_examples/) directory._
Install XCode 7.3 or more recent. If you have not already, you will need to
install the command-line tools using `xcode-select`:
@@ -189,7 +189,7 @@ benchmark program. Although successfully compiling the benchmark program is a
sign of success, the program is not a complete iOS app.
To see TensorFlow running on iOS, the example Xcode project in
-[tensorflow/contrib/ios_example](../ios_example) shows how to use the static
+[tensorflow/contrib/ios_examples](../ios_examples) shows how to use the static
library in a simple app.
### Building by hand
@@ -227,7 +227,7 @@ benchmark program. Although successfully compiling the benchmark program is a
sign of success, the program is not a complete iOS app.
To see TensorFlow running on iOS, the example Xcode project in
-[tensorflow/contrib/ios_example](../ios_example) shows how to use the static
+[tensorflow/contrib/ios_examples](../ios_examples) shows how to use the static
library in a simple app.
#### Universal binaries
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
index e8aee7d3b1..dccf66f2ae 100644
--- a/tensorflow/core/client/tensor_c_api.cc
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -115,7 +115,7 @@ struct TF_Tensor {
TensorBuffer* buffer;
};
-TF_Tensor* TF_NewTensor(TF_DataType dtype, tensorflow::int64* dims,
+TF_Tensor* TF_NewTensor(TF_DataType dtype, const tensorflow::int64* dims,
int num_dims, void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg) {
diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h
index 1de5a86503..d01f3a14cc 100644
--- a/tensorflow/core/public/tensor_c_api.h
+++ b/tensorflow/core/public/tensor_c_api.h
@@ -187,7 +187,7 @@ typedef struct TF_Tensor TF_Tensor;
// (*deallocator)(data, len, deallocator_arg)
// Clients must provide a custom deallocator function so they can pass in
// memory managed by something like numpy.
-extern TF_Tensor* TF_NewTensor(TF_DataType, long long* dims, int num_dims,
+extern TF_Tensor* TF_NewTensor(TF_DataType, const long long* dims, int num_dims,
void* data, size_t len,
void (*deallocator)(void* data, size_t len,
void* arg),
diff --git a/tensorflow/examples/skflow/digits.py b/tensorflow/examples/skflow/digits.py
index b3c684b7df..6c9aec52da 100644
--- a/tensorflow/examples/skflow/digits.py
+++ b/tensorflow/examples/skflow/digits.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,6 +54,6 @@ val_monitor = monitors.ValidationMonitor(X_val, y_val, every_n_steps=50)
classifier = learn.TensorFlowEstimator(model_fn=conv_model, n_classes=10,
steps=1000, learning_rate=0.05,
batch_size=128)
-classifier.fit(X_train, y_train, val_monitor)
+classifier.fit(X_train, y_train, monitors=[val_monitor])
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
print('Test Accuracy: {0:f}'.format(score))
diff --git a/tensorflow/examples/skflow/dnn_autoencoder_iris.py b/tensorflow/examples/skflow/dnn_autoencoder_iris.py
index c4383ae608..284bd9e58a 100644
--- a/tensorflow/examples/skflow/dnn_autoencoder_iris.py
+++ b/tensorflow/examples/skflow/dnn_autoencoder_iris.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/hdf5_classification.py b/tensorflow/examples/skflow/hdf5_classification.py
index edcce6fe6f..50e7d73b95 100644
--- a/tensorflow/examples/skflow/hdf5_classification.py
+++ b/tensorflow/examples/skflow/hdf5_classification.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/iris_custom_model.py b/tensorflow/examples/skflow/iris_custom_model.py
index afce504b74..8e2ab2ec88 100644
--- a/tensorflow/examples/skflow/iris_custom_model.py
+++ b/tensorflow/examples/skflow/iris_custom_model.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/iris_run_config.py b/tensorflow/examples/skflow/iris_run_config.py
index de9b44d460..6ca563e9a3 100644
--- a/tensorflow/examples/skflow/iris_run_config.py
+++ b/tensorflow/examples/skflow/iris_run_config.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
index 70dd8053aa..c80a0ccca1 100644
--- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py
+++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/iris_with_pipeline.py b/tensorflow/examples/skflow/iris_with_pipeline.py
index ee5f9aed81..c548387f38 100644
--- a/tensorflow/examples/skflow/iris_with_pipeline.py
+++ b/tensorflow/examples/skflow/iris_with_pipeline.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/language_model.py b/tensorflow/examples/skflow/language_model.py
index dcd65bf9f6..7ee709fd91 100644
--- a/tensorflow/examples/skflow/language_model.py
+++ b/tensorflow/examples/skflow/language_model.py
@@ -1,6 +1,6 @@
# encoding: utf-8
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/mnist_rnn.py b/tensorflow/examples/skflow/mnist_rnn.py
index a6a594fad5..ddd6d7910f 100644
--- a/tensorflow/examples/skflow/mnist_rnn.py
+++ b/tensorflow/examples/skflow/mnist_rnn.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/mnist_weights.py b/tensorflow/examples/skflow/mnist_weights.py
index 9ad019f9a4..37d527c42c 100644
--- a/tensorflow/examples/skflow/mnist_weights.py
+++ b/tensorflow/examples/skflow/mnist_weights.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#t Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/multioutput_regression.py b/tensorflow/examples/skflow/multioutput_regression.py
index c0ddf1cf30..ef76a6ce27 100644
--- a/tensorflow/examples/skflow/multioutput_regression.py
+++ b/tensorflow/examples/skflow/multioutput_regression.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/multiple_gpu.py b/tensorflow/examples/skflow/multiple_gpu.py
index 1168184a38..50e4b8252e 100644
--- a/tensorflow/examples/skflow/multiple_gpu.py
+++ b/tensorflow/examples/skflow/multiple_gpu.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/neural_translation.py b/tensorflow/examples/skflow/neural_translation.py
index 7832767145..ded54608ba 100644
--- a/tensorflow/examples/skflow/neural_translation.py
+++ b/tensorflow/examples/skflow/neural_translation.py
@@ -1,6 +1,6 @@
# encoding: utf-8
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/neural_translation_word.py b/tensorflow/examples/skflow/neural_translation_word.py
index 90c73f0ba5..185835c139 100644
--- a/tensorflow/examples/skflow/neural_translation_word.py
+++ b/tensorflow/examples/skflow/neural_translation_word.py
@@ -1,6 +1,6 @@
# encoding: utf-8
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/out_of_core_data_classification.py b/tensorflow/examples/skflow/out_of_core_data_classification.py
index 5f612db3d7..5ed6033cc0 100644
--- a/tensorflow/examples/skflow/out_of_core_data_classification.py
+++ b/tensorflow/examples/skflow/out_of_core_data_classification.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/text_classification.py b/tensorflow/examples/skflow/text_classification.py
index fe19e273d3..3d34617016 100644
--- a/tensorflow/examples/skflow/text_classification.py
+++ b/tensorflow/examples/skflow/text_classification.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
index fef5a2d9b3..afaa0bfff7 100644
--- a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
+++ b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/text_classification_character_cnn.py b/tensorflow/examples/skflow/text_classification_character_cnn.py
index 998ed30807..be627f316e 100644
--- a/tensorflow/examples/skflow/text_classification_character_cnn.py
+++ b/tensorflow/examples/skflow/text_classification_character_cnn.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/text_classification_character_rnn.py b/tensorflow/examples/skflow/text_classification_character_rnn.py
index a3de8aa42b..864f678d4e 100644
--- a/tensorflow/examples/skflow/text_classification_character_rnn.py
+++ b/tensorflow/examples/skflow/text_classification_character_rnn.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/text_classification_cnn.py b/tensorflow/examples/skflow/text_classification_cnn.py
index 0cbed33ef1..46238d2f03 100644
--- a/tensorflow/examples/skflow/text_classification_cnn.py
+++ b/tensorflow/examples/skflow/text_classification_cnn.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/text_classification_save_restore.py b/tensorflow/examples/skflow/text_classification_save_restore.py
index 9cabc32205..2b2831eb52 100644
--- a/tensorflow/examples/skflow/text_classification_save_restore.py
+++ b/tensorflow/examples/skflow/text_classification_save_restore.py
@@ -1,4 +1,4 @@
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md
index 98edc71e59..4743ab557b 100644
--- a/tensorflow/examples/udacity/README.md
+++ b/tensorflow/examples/udacity/README.md
@@ -6,7 +6,7 @@ Course information can be found at https://www.udacity.com/course/deep-learning-
Running the Docker container from the Google Cloud repository
-------------------------------------------------------------
- docker run -p 8888:8888 -it --rm b.gcr.io/tensorflow-udacity/assignments:0.5.0
+ docker run -p 8888:8888 -it b.gcr.io/tensorflow-udacity/assignments:0.5.0
Accessing the Notebooks
-----------------------
@@ -19,6 +19,21 @@ On mac, find the virtual machine's IP using:
Then go to: http://IP:8888 (likely http://192.168.99.100:8888)
+Saving Your Progress
+--------------------
+
+Because of the `--rm` flag above, stopping the docker container removes it, so any changes you've made will disappear. One way around this is to remove the `--rm` flag, and name the container for easy restarting:
+```sh
+# you only need to "run" the container the first time:
+docker run -p 8888:8888 -it --name tensorflow-udacity b.gcr.io/tensorflow-udacity/assignments:0.5.0
+# …do various things…
+# when you're done, control-C to kill jupyter and stop the container
+# when you're ready to do more things, you can now just "start" the container:
+docker start -ai tensorflow-udacity
+# …do more things…
+# …repeat…
+```
+
FAQ
---
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 923535144b..e1cece4faa 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -63,7 +63,7 @@ Then, select the correct binary to install:
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
@@ -73,14 +73,14 @@ $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/tensorflow-
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
@@ -153,7 +153,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
@@ -163,14 +163,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
@@ -277,7 +277,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
@@ -287,14 +287,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
diff --git a/tensorflow/g3doc/how_tos/quantization/index.md b/tensorflow/g3doc/how_tos/quantization/index.md
index 0431a7ad61..61461822a0 100644
--- a/tensorflow/g3doc/how_tos/quantization/index.md
+++ b/tensorflow/g3doc/how_tos/quantization/index.md
@@ -6,7 +6,7 @@ were the top priorities. Using floating point arithmetic was the easiest way to
preserve accuracy, and GPUs were well-equipped to accelerate those calculations,
so it's natural that not much attention was paid to other numerical formats.
-These days, we actually have a lot of models being being deployed in commercial
+These days, we actually have a lot of models being deployed in commercial
applications. The computation demands of training grow with the number of
researchers, but the cycles needed for inference expand in proportion to users.
That means pure inference efficiency has become a burning issue for a lot of
diff --git a/tensorflow/g3doc/resources/index.md b/tensorflow/g3doc/resources/index.md
index 249ec50327..2c5d06946c 100644
--- a/tensorflow/g3doc/resources/index.md
+++ b/tensorflow/g3doc/resources/index.md
@@ -33,8 +33,9 @@ something amazing with TensorFlow, we'd like to hear about it!
The TensorFlow community has created many great projects around TensorFlow, including:
+* [@jtoy's awesome "Awesome TensorFlow" list of awesome things](https://github.com/jtoy/awesome-tensorflow)
* [TensorFlow tutorials](https://github.com/pkmital/tensorflow_tutorials)
-* [Scikit Flow - Simplified Interface for TensorFlow](https://github.com/tensorflow/skflow)
+* [Scikit Flow - Simplified Interface for TensorFlow](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/learn/python/learn)
* [Caffe to TensorFlow model converter](https://github.com/ethereon/caffe-tensorflow)
### Development
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index 12de1df66c..324a29c02e 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -232,7 +232,7 @@ print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
## Build a Multilayer Convolutional Network
-Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
+Getting 92% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
section, we'll fix that, jumping from a very simple model to something
moderately sophisticated: a small convolutional neural network. This will get us
to around 99.2% accuracy -- not state of the art, but respectable.
@@ -243,7 +243,7 @@ To create this model, we're going to need to create a lot of weights and biases.
One should generally initialize weights with a small amount of noise for
symmetry breaking, and to prevent 0 gradients. Since we're using ReLU neurons,
it is also good practice to initialize them with a slightly positive initial
-bias to avoid "dead neurons." Instead of doing this repeatedly while we build
+bias to avoid "dead neurons". Instead of doing this repeatedly while we build
the model, let's create two handy functions to do it for us.
```python
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 5f0549463f..20658fa632 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -205,7 +205,7 @@ class BaseSession(SessionInterface):
Use with the `with` keyword to specify that calls to
[`Operation.run()`](../../api_docs/python/framework.md#Operation.run) or
- [`Tensor.run()`](../../api_docs/python/framework.md#Tensor.run) should be
+ [`Tensor.eval()`](../../api_docs/python/framework.md#Tensor.eval) should be
executed in this session.
```python
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 7f1be574bb..093da97469 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -215,6 +215,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
def testFloatTanhEdge(self):
x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
@@ -254,6 +255,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(x, np.sign, tf.sign)
+ self._compareBothSparse(x, np.sign, tf.erf)
def testDoubleBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
@@ -292,6 +294,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
def testHalfBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
@@ -325,6 +328,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), tf.erf, tol=1e-3)
def testInt32Basic(self):
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 0bcf45db76..07d93160ad 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -348,6 +348,25 @@ def sqrt(x, name=None):
return gen_math_ops.sqrt(x, name=name)
+def erf(x, name=None):
+ """Computes the Gauss error function of `x` element-wise.
+
+ Args:
+ x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
+ """
+ with ops.op_scope([x], name, "Erf") as name:
+ if isinstance(x, ops.SparseTensor):
+ x_erf = gen_math_ops.erf(x.values, name=name)
+ return ops.SparseTensor(indices=x.indices, values=x_erf, shape=x.shape)
+ else:
+ return gen_math_ops.erf(x, name=name)
+
+
def complex_abs(x, name=None):
r"""Computes the complex absolute value of a tensor.