author    Derek Murray <mrry@google.com>  2016-05-23 11:39:39 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-05-23 12:42:36 -0700
commit    892ca4ddc12852a7b4633fd08f163941356cb4e6 (patch)
tree      be913f46bb9323685c5a807a89fca6dc52a25504
parent    76d90938f95a14a5723c253ec8529e93939a25e2 (diff)
Merge changes from github.
Change: 123026122
-rw-r--r-- tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc | 3
-rw-r--r-- tensorflow/contrib/learn/python/learn/README.md | 2
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/__init__.py | 3
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/_sklearn.py | 53
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/autoencoder.py | 116
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/base.py | 18
-rw-r--r-- tensorflow/contrib/learn/python/learn/io/data_feeder.py | 25
-rw-r--r-- tensorflow/contrib/learn/python/learn/models.py | 31
-rw-r--r-- tensorflow/contrib/learn/python/learn/ops/__init__.py | 1
-rw-r--r-- tensorflow/contrib/learn/python/learn/ops/autoencoder_ops.py | 56
-rw-r--r-- tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py | 4
-rw-r--r-- tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py | 20
-rw-r--r-- tensorflow/contrib/learn/python/learn/tests/test_estimators.py | 5
-rw-r--r-- tensorflow/contrib/learn/python/learn/tests/test_grid_search.py | 16
-rw-r--r-- tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py | 10
-rw-r--r-- tensorflow/contrib/learn/python/learn/tests/test_saver.py | 1
-rw-r--r-- tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py | 28
-rw-r--r-- tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py | 30
-rw-r--r-- tensorflow/core/BUILD | 25
-rw-r--r-- tensorflow/core/common_runtime/gpu/gpu_device.cc | 5
-rw-r--r-- tensorflow/core/distributed_runtime/master_session.cc | 2
-rw-r--r-- tensorflow/core/kernels/conv_grad_ops.cc | 28
-rw-r--r-- tensorflow/core/kernels/conv_ops.cc | 1
-rw-r--r-- tensorflow/core/kernels/cudnn_pooling_gpu.h | 2
-rw-r--r-- tensorflow/core/kernels/cwise_op_abs.cc | 3
-rw-r--r-- tensorflow/core/kernels/cwise_op_add.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_complex.cc | 18
-rw-r--r-- tensorflow/core/kernels/cwise_op_conj.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_cos.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_div.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_equal_to.cc | 5
-rw-r--r-- tensorflow/core/kernels/cwise_op_exp.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc | 2
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc | 2
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_real.cu.cc | 2
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_select.cu.cc | 1
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc | 27
-rw-r--r-- tensorflow/core/kernels/cwise_op_imag.cc | 18
-rw-r--r-- tensorflow/core/kernels/cwise_op_inverse.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_log.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_mul.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_neg.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_not_equal_to.cc | 5
-rw-r--r-- tensorflow/core/kernels/cwise_op_pow.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_real.cc | 19
-rw-r--r-- tensorflow/core/kernels/cwise_op_rsqrt.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_select.cc | 1
-rw-r--r-- tensorflow/core/kernels/cwise_op_sigmoid.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_sign.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_sin.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_sqrt.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_square.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_sub.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_tanh.cc | 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_zeta.cc | 21
-rw-r--r-- tensorflow/core/kernels/cwise_ops.h | 6
-rw-r--r-- tensorflow/core/kernels/cwise_ops_common.h | 7
-rw-r--r-- tensorflow/core/kernels/cwise_ops_gpu_common.cu.h | 4
-rw-r--r-- tensorflow/core/kernels/diag_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/edit_distance_op.cc | 10
-rw-r--r-- tensorflow/core/kernels/matmul_op.cc | 3
-rw-r--r-- tensorflow/core/kernels/queue_base.cc | 2
-rw-r--r-- tensorflow/core/kernels/reduction_ops_gpu.cu.cc | 1
-rw-r--r-- tensorflow/core/kernels/reduction_ops_sum.cc | 17
-rw-r--r-- tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc | 1
-rw-r--r-- tensorflow/core/kernels/transpose_functor_cpu.cc | 4
-rw-r--r-- tensorflow/core/kernels/xent_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/xent_op_gpu.cu.cc | 3
-rw-r--r-- tensorflow/core/ops/array_ops.cc | 43
-rw-r--r-- tensorflow/core/ops/math_grad.cc | 4
-rw-r--r-- tensorflow/core/ops/math_ops.cc | 88
-rw-r--r-- tensorflow/core/ops/ops.pbtxt | 2
-rw-r--r-- tensorflow/core/protobuf/worker.proto | 2
-rw-r--r-- tensorflow/examples/skflow/README.md | 19
-rw-r--r-- tensorflow/examples/skflow/boston.py | 2
-rw-r--r-- tensorflow/examples/skflow/digits.py | 10
-rw-r--r-- tensorflow/examples/skflow/dnn_autoencoder_iris.py | 35
-rw-r--r-- tensorflow/examples/skflow/hdf5_classification.py | 49
-rw-r--r-- tensorflow/examples/skflow/iris.py | 6
-rw-r--r-- tensorflow/examples/skflow/iris_custom_model.py | 8
-rw-r--r-- tensorflow/examples/skflow/iris_run_config.py | 6
-rw-r--r-- tensorflow/examples/skflow/iris_save_restore.py | 6
-rw-r--r-- tensorflow/examples/skflow/iris_val_based_early_stopping.py | 10
-rw-r--r-- tensorflow/examples/skflow/iris_with_pipeline.py | 4
-rw-r--r-- tensorflow/examples/skflow/language_model.py | 22
-rw-r--r-- tensorflow/examples/skflow/mnist.py | 16
-rw-r--r-- tensorflow/examples/skflow/mnist_weights.py | 16
-rw-r--r-- tensorflow/examples/skflow/multioutput_regression.py | 10
-rw-r--r-- tensorflow/examples/skflow/multiple_gpu.py | 8
-rw-r--r-- tensorflow/examples/skflow/neural_translation.py | 16
-rw-r--r-- tensorflow/examples/skflow/neural_translation_word.py | 18
-rw-r--r-- tensorflow/examples/skflow/out_of_core_data_classification.py | 4
-rw-r--r-- tensorflow/examples/skflow/resnet.py | 20
-rw-r--r-- tensorflow/examples/skflow/text_classification.py | 18
-rw-r--r-- tensorflow/examples/skflow/text_classification_builtin_rnn_model.py | 12
-rw-r--r-- tensorflow/examples/skflow/text_classification_character_cnn.py | 16
-rw-r--r-- tensorflow/examples/skflow/text_classification_character_rnn.py | 14
-rw-r--r-- tensorflow/examples/skflow/text_classification_cnn.py | 16
-rw-r--r-- tensorflow/examples/skflow/text_classification_save_restore.py | 20
-rw-r--r-- tensorflow/examples/udacity/1_notmnist.ipynb | 22
-rw-r--r-- tensorflow/examples/udacity/5_word2vec.ipynb | 5
-rw-r--r-- tensorflow/g3doc/api_docs/python/constant_op.md | 4
-rw-r--r-- tensorflow/g3doc/api_docs/python/control_flow_ops.md | 4
-rw-r--r-- tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md | 2
-rw-r--r-- tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md | 2
-rw-r--r-- tensorflow/g3doc/api_docs/python/math_ops.md | 4
-rw-r--r-- tensorflow/g3doc/api_docs/python/nn.md | 6
-rw-r--r-- tensorflow/g3doc/api_docs/python/train.md | 4
-rw-r--r-- tensorflow/g3doc/get_started/os_setup.md | 61
-rw-r--r-- tensorflow/g3doc/how_tos/adding_an_op/index.md | 2
-rw-r--r-- tensorflow/g3doc/how_tos/distributed/index.md | 8
-rw-r--r-- tensorflow/g3doc/how_tos/threading_and_queues/index.md | 12
-rw-r--r-- tensorflow/g3doc/resources/dims_types.md | 1
-rwxr-xr-x tensorflow/g3doc/tutorials/mandelbrot/index.md | 7
-rw-r--r-- tensorflow/python/__init__.py | 4
-rw-r--r-- tensorflow/python/kernel_tests/conv_ops_test.py | 27
-rw-r--r-- tensorflow/python/kernel_tests/cwise_ops_test.py | 130
-rw-r--r-- tensorflow/python/kernel_tests/diag_op_test.py | 45
-rw-r--r-- tensorflow/python/kernel_tests/gradient_checker.py | 2
-rw-r--r-- tensorflow/python/kernel_tests/matmul_op_test.py | 28
-rw-r--r-- tensorflow/python/kernel_tests/padding_fifo_queue_test.py | 2
-rw-r--r-- tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py | 2
-rw-r--r-- tensorflow/python/kernel_tests/tensor_array_ops_test.py | 12
-rw-r--r-- tensorflow/python/kernel_tests/xent_op_test.py | 5
-rw-r--r-- tensorflow/python/ops/array_ops.py | 2
-rw-r--r-- tensorflow/python/ops/math_grad.py | 41
-rw-r--r-- tensorflow/python/ops/math_grad_test.py | 4
-rw-r--r-- tensorflow/python/ops/math_ops.py | 115
-rw-r--r-- tensorflow/python/ops/nn_ops.py | 20
-rw-r--r-- tensorflow/python/ops/rnn_cell.py | 36
-rw-r--r-- tensorflow/python/training/moving_averages.py | 4
-rw-r--r-- tensorflow/tools/docker/README.md | 6
-rw-r--r-- tensorflow/tools/docker/jupyter_notebook_config.py | 7
-rw-r--r-- third_party/gpus/cuda/platform.bzl | 2
136 files changed, 1402 insertions, 509 deletions
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index bfad5515e6..629072ed7e 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -22,8 +22,6 @@
#include <sys/wait.h>
#include <unistd.h>
-#include <string>
-#include <tuple>
#include <vector>
#include "tensorflow/core/lib/io/path.h"
@@ -40,6 +38,7 @@ namespace {
const char kFfmpegExecutable[] = "ffmpeg";
const int32 kDefaultProbeSize = 5000000; // 5MB
+
std::vector<string> FfmpegCommandLine(const string& input_filename,
const string& output_filename,
const string& input_format_id,
diff --git a/tensorflow/contrib/learn/python/learn/README.md b/tensorflow/contrib/learn/python/learn/README.md
index f557999828..2ab165f284 100644
--- a/tensorflow/contrib/learn/python/learn/README.md
+++ b/tensorflow/contrib/learn/python/learn/README.md
@@ -105,7 +105,7 @@ iris = datasets.load_iris()
def my_model(X, y):
"""This is DNN with 10, 20, 10 hidden layers, and dropout of 0.5 probability."""
- layers = learn.ops.dnn(X, [10, 20, 10], keep_prob=0.5)
+ layers = learn.ops.dnn(X, [10, 20, 10], dropout=0.5)
return learn.models.logistic_regression(layers, y)
classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3)
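The hunk above tracks an API rename in `learn.ops.dnn`: the keyword argument is now `dropout` rather than `keep_prob`. A minimal sketch of the updated usage, assuming the iris setup used elsewhere in this README:

```python
from sklearn import datasets
from tensorflow.contrib import learn

iris = datasets.load_iris()

def my_model(X, y):
    """DNN with 10, 20, 10 hidden units; `dropout` replaces `keep_prob`."""
    layers = learn.ops.dnn(X, [10, 20, 10], dropout=0.5)
    return learn.models.logistic_regression(layers, y)

classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3)
classifier.fit(iris.data, iris.target)
```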
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index 1ad6491a95..e714c15f2e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -16,7 +16,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
+from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
+from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier
diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
index 491eba45a4..dcd1d81056 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
@@ -16,8 +16,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import os
+
import numpy as np
+def _pprint(d):
+ return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
+
class _BaseEstimator(object):
"""This is a cross-import when sklearn is not available.
@@ -43,6 +49,9 @@ class _BaseEstimator(object):
for key in param_names:
value = getattr(self, key, None)
+ if isinstance(value, collections.Callable):
+ continue
+
# XXX: should we rather test if instance of estimator?
if deep and hasattr(value, 'get_params'):
deep_items = value.get_params().items()
@@ -90,9 +99,7 @@ class _BaseEstimator(object):
def __repr__(self):
class_name = self.__class__.__name__
return '%s(%s)' % (class_name,
- _pprint(
- self.get_params(deep=False),
- offset=len(class_name),),)
+ _pprint(self.get_params(deep=False)),)
class _ClassifierMixin():
@@ -104,6 +111,8 @@ class _RegressorMixin():
"""Mixin class for all regression estimators."""
pass
+class _TransformerMixin():
+ """Mixin class for all transformer estimators."""
class _NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.
@@ -152,6 +161,7 @@ def _train_test_split(*args, **options):
train_size = 1 - test_size
train_size = train_size * args[0].shape[0]
+ np.random.seed(random_state)
indices = np.random.permutation(args[0].shape[0])
train_idx, test_idx = indices[:train_size], indices[:train_size]
result = []
@@ -159,30 +169,29 @@ def _train_test_split(*args, **options):
result += [x.take(train_idx, axis=0), x.take(test_idx, axis=0)]
return tuple(result)
-# Try to import sklearn, if fail - use _BaseEstimator.
-try:
- from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
-except ImportError:
+
+# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
+TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
+if TRY_IMPORT_SKLEARN:
+ from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
+ from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
+ from sklearn.cross_validation import train_test_split
+ try:
+ from sklearn.exceptions import NotFittedError
+ except ImportError:
+ try:
+ from sklearn.utils.validation import NotFittedError
+ except ImportError:
+ NotFittedError = _NotFittedError
+else:
+ # Naive implementations of sklearn classes and functions.
BaseEstimator = _BaseEstimator
ClassifierMixin = _ClassifierMixin
RegressorMixin = _RegressorMixin
-
-# Try to import exception for not fitted error.
-try:
- from sklearn.exceptions import NotFittedError
-except ImportError:
+ TransformerMixin = _TransformerMixin
NotFittedError = _NotFittedError
-
-# Try to import metrics
-try:
- from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
-except ImportError:
accuracy_score = _accuracy_score
log_loss = None
mean_squared_error = _mean_squared_error
-
-# Try to import train_test_split
-try:
- from sklearn.cross_validation import train_test_split
-except ImportError:
train_test_split = _train_test_split
+
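With this change the sklearn-backed classes are no longer imported opportunistically; they are opt-in via the `TENSORFLOW_SKLEARN` environment variable, with the naive fallbacks used otherwise. A minimal sketch of toggling the flag (it must be set before `_sklearn` is first imported):

```python
import os

# Opt in to the real sklearn implementations; leave unset to get the
# naive _BaseEstimator/_ClassifierMixin/... fallbacks defined above.
os.environ['TENSORFLOW_SKLEARN'] = '1'

from tensorflow.contrib.learn.python.learn.estimators import _sklearn

# Resolves to sklearn.base.BaseEstimator when the flag is set,
# otherwise to this module's own _BaseEstimator.
print(_sklearn.BaseEstimator)
```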
diff --git a/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py b/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py
new file mode 100644
index 0000000000..690bac8f19
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py
@@ -0,0 +1,116 @@
+"""Deep Autoencoder estimators."""
+# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import nn
+from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
+from tensorflow.contrib.learn.python.learn import models
+
+
+class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
+ """TensorFlow Autoencoder Regressor model.
+
+ Parameters:
+ hidden_units: List of hidden units per layer.
+ batch_size: Mini batch size.
+ activation: activation function used to map inner latent layer onto
+ reconstruction layer.
+ add_noise: a function that adds noise to tensor_in,
+ e.g. def add_noise(x):
+ return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
+ steps: Number of steps to run over data.
+ optimizer: Optimizer name (or class), for example "SGD", "Adam",
+ "Adagrad".
+ learning_rate: If this is constant float value, no decay function is used.
+ Instead, a customized decay function can be passed that accepts
+ global_step as parameter and returns a Tensor.
+ e.g. exponential decay function:
+ def exp_decay(global_step):
+ return tf.train.exponential_decay(
+ learning_rate=0.1, global_step,
+ decay_steps=2, decay_rate=0.001)
+ continue_training: when continue_training is True, once initialized
+ model will be continually trained on every call of fit.
+ config: RunConfig object that controls the configurations of the session,
+ e.g. num_cores, gpu_memory_fraction, etc.
+ verbose: Controls the verbosity, possible values:
+ 0: the algorithm and debug information is muted.
+ 1: trainer prints the progress.
+ 2: log device placement is printed.
+ dropout: When not None, the probability we will drop out a given
+ coordinate.
+ """
+ def __init__(self, hidden_units, n_classes=0, batch_size=32,
+ steps=200, optimizer="Adagrad", learning_rate=0.1,
+ clip_gradients=5.0, activation=nn.relu, add_noise=None,
+ continue_training=False, config=None,
+ verbose=1, dropout=None):
+ self.hidden_units = hidden_units
+ self.dropout = dropout
+ self.activation = activation
+ self.add_noise = add_noise
+ super(TensorFlowDNNAutoencoder, self).__init__(
+ model_fn=self._model_fn,
+ n_classes=n_classes,
+ batch_size=batch_size, steps=steps, optimizer=optimizer,
+ learning_rate=learning_rate, clip_gradients=clip_gradients,
+ continue_training=continue_training,
+ config=config, verbose=verbose)
+
+ def _model_fn(self, X, y):
+ encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
+ self.hidden_units,
+ models.linear_regression,
+ activation=self.activation,
+ add_noise=self.add_noise,
+ dropout=self.dropout)(X)
+ self.encoder = encoder
+ self.decoder = decoder
+ return autoencoder_estimator
+
+ def generate(self, hidden=None):
+ """Generate new data using trained construction layer"""
+ if hidden is None:
+ last_layer = len(self.hidden_units) - 1
+ bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer)
+ import numpy as np
+ hidden = np.random.normal(size=bias.shape)
+ hidden = np.reshape(hidden, (1, len(hidden)))
+ return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
+
+ @property
+ def weights_(self):
+ """Returns weights of the autoencoder's weight layers."""
+ weights = []
+ for layer in range(len(self.hidden_units)):
+ weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer))
+ for layer in range(len(self.hidden_units)):
+ weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer))
+ weights.append(self.get_tensor_value('linear_regression/weights:0'))
+ return weights
+
+ @property
+ def bias_(self):
+ """Returns bias of the autoencoder's bias layers."""
+ biases = []
+ for layer in range(len(self.hidden_units)):
+ biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer))
+ for layer in range(len(self.hidden_units)):
+ biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer))
+ biases.append(self.get_tensor_value('linear_regression/bias:0'))
+ return biases
+
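A usage sketch for the new estimator, assuming the iris dataset; it mirrors the (commented-out) autoencoder test added to test_nonlinear.py later in this commit:

```python
from sklearn import datasets
from tensorflow.contrib import learn

iris = datasets.load_iris()

# Train an autoencoder with 10- and 20-unit hidden layers and return the
# transformed data in one call.
autoencoder = learn.TensorFlowDNNAutoencoder(hidden_units=[10, 20])
transformed = autoencoder.fit_transform(iris.data)

# Sample a hidden vector and decode it into a new synthetic data point.
generated = autoencoder.generate()
```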
diff --git a/tensorflow/contrib/learn/python/learn/estimators/base.py b/tensorflow/contrib/learn/python/learn/estimators/base.py
index 9459bdbcc9..a94ee9e717 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/base.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/base.py
@@ -360,7 +360,7 @@ class TensorFlowEstimator(estimator.Estimator):
reconfigured.
Returns:
- Estiamator, object of the subclass of TensorFlowEstimator.
+ Estimator, object of the subclass of TensorFlowEstimator.
"""
model_def_filename = os.path.join(path, 'model.def')
if not os.path.exists(model_def_filename):
@@ -379,6 +379,7 @@ class TensorFlowEstimator(estimator.Estimator):
new_value = locals()[key]
if new_value is not None:
model_def[key] = new_value
+
class_name = model_def.pop('class_name')
if class_name == 'TensorFlowEstimator':
custom_estimator = TensorFlowEstimator(model_fn=None, **model_def)
@@ -392,3 +393,18 @@ class TensorFlowEstimator(estimator.Estimator):
estimator = getattr(estimators, class_name)(**model_def)
estimator._restore(path)
return estimator
+
+
+class TensorFlowBaseTransformer(TensorFlowEstimator, _sklearn.TransformerMixin):
+ """TensorFlow Base Transformer class."""
+ def transform(self, X):
+ """Transform X using trained transformer."""
+ return(super(TensorFlowBaseTransformer, self).predict(X, axis=1, batch_size=None))
+
+ def fit(self, X, y=None, monitor=None, logdir=None):
+ """Fit a transformer."""
+ return(super(TensorFlowBaseTransformer, self).fit(X, y, monitors=None, logdir=None))
+
+ def fit_transform(self, X, y=None, monitor=None, logdir=None):
+ """Fit transformer and transform X using trained transformer."""
+ return(self.fit(X, y, monitor=None, logdir=None).transform(X))
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
index 866a860e0e..babb216d9d 100644
--- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
@@ -88,9 +88,8 @@ def setup_train_data_feeder(
X, y = _data_type_filter(X, y)
if HAS_DASK:
import dask.dataframe as dd
- allowed_classes = (dd.Series, dd.DataFrame)
- if (isinstance(X, allowed_classes) and
- (y is None or isinstance(y, allowed_classes))):
+ if (isinstance(X, (dd.Series, dd.DataFrame)) and
+ (y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
data_feeder_cls = DaskDataFeeder
else:
data_feeder_cls = DataFeeder
@@ -156,7 +155,7 @@ def setup_processor_data_feeder(X):
def check_array(array, dtype):
- """Checks array on dtype and convers it if different.
+ """Checks array on dtype and converts it if different.
Args:
array: Input array.
@@ -165,7 +164,10 @@ def check_array(array, dtype):
Returns:
Original array or converted.
"""
- array = np.array(array, dtype=dtype, order=None, copy=False)
+ # skip check if array is instance of other classes, e.g. h5py.Dataset
+ # to avoid copying array and loading whole data into memory
+ if isinstance(array, (np.ndarray, list)):
+ array = np.array(array, dtype=dtype, order=None, copy=False)
return array
@@ -459,10 +461,8 @@ class DaskDataFeeder(object):
"""Data feeder for TF trainer that reads data from dask.Series and dask.DataFrame.
Numpy arrays can be serialized to disk and it's possible to do random seeks
- into them.
- DaskDataFeeder will remove requirement to have full dataset in the memory
- and still do
- random seeks for sampling of batches.
+ into them. DaskDataFeeder will remove requirement to have full dataset in the
+ memory and still do random seeks for sampling of batches.
Parameters:
X: iterator that returns for each element, returns features.
@@ -483,10 +483,9 @@ class DaskDataFeeder(object):
input_dtype: dtype of input.
output_dtype: dtype of output.
"""
-
def __init__(self, X, y, n_classes, batch_size, shuffle=True, random_state=None):
import dask.dataframe as dd
- # TODO: check X and y dtypes in dask_io like pandas
+ # TODO(terrytangyuan): check X and y dtypes in dask_io like pandas
self.X = X
self.y = y
# save column names
@@ -497,6 +496,8 @@ class DaskDataFeeder(object):
# deal with cases where two DFs have overlapped default numeric colnames
self.y_columns = len(self.X_columns) + 1
self.y = self.y.rename(columns={y.columns[0]: self.y_columns})
+
+ # TODO(terrytangyuan): deal with unsupervised cases
# combine into a data frame
self.df = dd.multi.concat([self.X, self.y], axis=1)
self.n_classes = n_classes
@@ -535,7 +536,6 @@ class DaskDataFeeder(object):
A function that when called samples a random subset of batch size
from X and y.
"""
-
def _feed_dict_fn():
# TODO: option for with/without replacement (dev version of dask)
sample = self.df.random_split(
@@ -555,5 +555,4 @@ class DaskDataFeeder(object):
encoded_out[np.arange(out.size), out] = 1
return {input_placeholder.name: inp,
output_placeholder.name: encoded_out}
-
return _feed_dict_fn
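The `check_array` change above is what makes out-of-memory sources such as `h5py.Dataset` usable: wrapping them in `np.array(...)` would pull the whole dataset into memory, while passing them through lets the feeder slice batches lazily. A small sketch of the difference (file name hypothetical):

```python
import h5py
import numpy as np

h5f = h5py.File('features.h5', 'r')  # hypothetical HDF5 file
X = h5f['X']                         # an h5py.Dataset; data stays on disk

# Old behavior inside check_array: forces a full in-memory copy.
X_in_memory = np.array(X, dtype=np.float32)

# New behavior: X is neither ndarray nor list, so it is returned as-is,
# and rows are only read from disk when a batch indexes into it.
batch = X[:32]
```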
diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py
index 2c83a95ab8..8cabd390fc 100644
--- a/tensorflow/contrib/learn/python/learn/models.py
+++ b/tensorflow/contrib/learn/python/learn/models.py
@@ -18,6 +18,7 @@ from __future__ import print_function
from tensorflow.contrib.learn.python.learn.ops import dnn_ops
from tensorflow.contrib.learn.python.learn.ops import losses_ops
+from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_
@@ -187,6 +188,36 @@ def get_dnn_model(hidden_units, target_predictor_fn, dropout=None):
return dnn_estimator
+def get_autoencoder_model(hidden_units, target_predictor_fn,
+ activation, add_noise=None, dropout=None):
+ """Returns a function that creates a Autoencoder TensorFlow subgraph with given
+ params.
+
+ Args:
+ hidden_units: List of values of hidden units for layers.
+ target_predictor_fn: Function that will predict target from input
+ features. This can be logistic regression,
+ linear regression or any other model,
+ that takes X, y and returns predictions and loss tensors.
+ activation: activation function used to map inner latent layer onto
+ reconstruction layer.
+ add_noise: a function that adds noise to tensor_in,
+ e.g. def add_noise(x):
+ return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
+ dropout: When not none, causes dropout regularization to be used,
+ with the specified probability of removing a given coordinate.
+
+ Returns:
+ A function that creates the subgraph.
+ """
+ def dnn_autoencoder_estimator(X):
+ """Autoencoder estimator with target predictor function on top."""
+ encoder, decoder = autoencoder_ops.dnn_autoencoder(
+ X, hidden_units, activation,
+ add_noise=add_noise, dropout=dropout)
+ return encoder, decoder, target_predictor_fn(X, decoder)
+ return dnn_autoencoder_estimator
+
## This will be in Tensorflow 0.7.
## TODO(ilblackdragon): Clean this up when it's released
diff --git a/tensorflow/contrib/learn/python/learn/ops/__init__.py b/tensorflow/contrib/learn/python/learn/ops/__init__.py
index b8d00ac27c..0bda7110aa 100644
--- a/tensorflow/contrib/learn/python/learn/ops/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/ops/__init__.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib.learn.python.learn.ops.array_ops import *
from tensorflow.contrib.learn.python.learn.ops.conv_ops import *
from tensorflow.contrib.learn.python.learn.ops.dnn_ops import *
+from tensorflow.contrib.learn.python.learn.ops.autoencoder_ops import *
from tensorflow.contrib.learn.python.learn.ops.dropout_ops import *
from tensorflow.contrib.learn.python.learn.ops.embeddings_ops import *
from tensorflow.contrib.learn.python.learn.ops.losses_ops import *
diff --git a/tensorflow/contrib/learn/python/learn/ops/autoencoder_ops.py b/tensorflow/contrib/learn/python/learn/ops/autoencoder_ops.py
new file mode 100644
index 0000000000..e337b3b124
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/ops/autoencoder_ops.py
@@ -0,0 +1,56 @@
+"""TensorFlow ops for autoencoder."""
+# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.contrib.learn.python.learn.ops import dnn_ops
+
+
+def dnn_autoencoder(tensor_in, hidden_units,
+ activation=nn.relu, add_noise=None,
+ dropout=None, scope=None):
+ """Creates fully connected autoencoder subgraph.
+
+ Args:
+ tensor_in: tensor or placeholder for input features.
+ hidden_units: list of counts of hidden units in each layer.
+ activation: activation function used to map inner latent layer onto
+ reconstruction layer.
+ add_noise: a function that adds noise to tensor_in,
+ e.g. def add_noise(x):
+ return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
+ dropout: if not None, will add a dropout layer with given
+ probability.
+ scope: the variable scope for this op.
+
+ Returns:
+ Tensors for encoder and decoder.
+ """
+ with vs.variable_op_scope([tensor_in], scope, "autoencoder"):
+ if add_noise is not None:
+ tensor_in = add_noise(tensor_in)
+ with vs.variable_scope('encoder'):
+ # build DNN encoder
+ encoder = dnn_ops.dnn(tensor_in, hidden_units,
+ activation=activation, dropout=dropout)
+ with vs.variable_scope('decoder'):
+ # reverse hidden_units and build the DNN decoder
+ decoder = dnn_ops.dnn(encoder, hidden_units[::-1],
+ activation=activation, dropout=dropout)
+ return encoder, decoder
+
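A minimal sketch of building a denoising autoencoder subgraph with the new op. Since `add_noise` is applied to a Tensor inside the graph, this sketch corrupts the input with `tf.random_normal` rather than the numpy expression shown in the docstring:

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops

def add_noise(x):
    # Gaussian corruption of the input tensor.
    return x + tf.random_normal(tf.shape(x), stddev=0.1)

X = tf.placeholder(tf.float32, [None, 4])
encoder, decoder = autoencoder_ops.dnn_autoencoder(
    X, [10, 20], activation=tf.nn.relu, add_noise=add_noise, dropout=0.5)
```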
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py b/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py
index 3963f5eb4e..cb459e1bb4 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py
@@ -45,12 +45,12 @@ class CustomDecayTest(tf.test.TestCase):
classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
n_classes=3,
- steps=800,
+ steps=500,
learning_rate=exp_decay)
classifier.fit(X_train, y_train)
score = accuracy_score(y_test, classifier.predict(X_test))
- self.assertGreater(score, 0.7, "Failed with score = {0}".format(score))
+ self.assertGreater(score, 0.65, "Failed with score = {0}".format(score))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py
index ea361a0d68..77f9f2d6f2 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py
@@ -126,6 +126,26 @@ class DataFeederTest(tf.test.TestCase):
[0.60000002, 0.2]])
self.assertAllClose(feed_dict[out.name], [[0., 0., 1.], [0., 1., 0.]])
+ def test_hdf5_data_feeder(self):
+ try:
+ import h5py
+ X = np.matrix([[1, 2], [3, 4]])
+ y = np.array([1, 2])
+ h5f = h5py.File('test_hdf5.h5', 'w')
+ h5f.create_dataset('X', data=X)
+ h5f.create_dataset('y', data=y)
+ h5f.close()
+ h5f = h5py.File('test_hdf5.h5', 'r')
+ X = h5f['X']
+ y = h5f['y']
+ df = data_feeder.DataFeeder(X, y, n_classes=0, batch_size=3)
+ inp, out = df.input_builder()
+ feed_dict_fn = df.get_feed_dict_fn()
+ feed_dict = feed_dict_fn()
+ self.assertAllClose(feed_dict[inp.name], [[3, 4], [1, 2]])
+ self.assertAllClose(feed_dict[out.name], [2, 1])
+ except ImportError:
+ print("Skipped test for hdf5 since it's not installed.")
class SetupPredictDataFeederTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_estimators.py b/tensorflow/contrib/learn/python/learn/tests/test_estimators.py
index 0c93c5aa50..069a146593 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_estimators.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_estimators.py
@@ -35,7 +35,6 @@ class CustomOptimizer(tf.test.TestCase):
iris.target,
test_size=0.2,
random_state=42)
-
# setup exponential decay function
def exp_decay(global_step):
return tf.train.exponential_decay(learning_rate=0.1,
@@ -48,13 +47,13 @@ class CustomOptimizer(tf.test.TestCase):
classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
n_classes=3,
- steps=800,
+ steps=400,
learning_rate=exp_decay,
optimizer=custom_optimizer)
classifier.fit(X_train, y_train)
score = accuracy_score(y_test, classifier.predict(X_test))
- self.assertGreater(score, 0.7, "Failed with score = {0}".format(score))
+ self.assertGreater(score, 0.65, "Failed with score = {0}".format(score))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py b/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py
index ad744eb2be..1008c53c99 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py
@@ -15,15 +15,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import random
-try:
- from sklearn import datasets
- from sklearn.grid_search import GridSearchCV
- from sklearn.metrics import accuracy_score, mean_squared_error
- HAS_SKLEARN = True
-except ImportError:
- HAS_SKLEARN = False
+HAS_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
+if HAS_SKLEARN:
+ try:
+ from sklearn import datasets
+ from sklearn.grid_search import GridSearchCV
+ from sklearn.metrics import accuracy_score, mean_squared_error
+ except ImportError:
+ HAS_SKLEARN = False
import tensorflow as tf
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py b/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py
index 5d50d2bbbb..7b8a27eebc 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py
@@ -167,6 +167,16 @@ class NonLinearTest(tf.test.TestCase):
predictions = classifier.predict(test_data)
self.assertAllClose(predictions, np.array([1, 0]))
+ # def testDNNAutoencoder(self):
+ # import numpy as np
+ # iris = datasets.load_iris()
+ # autoencoder = learn.TensorFlowDNNAutoencoder(hidden_units=[10, 20])
+ # transformed = autoencoder.fit_transform(iris.data[1:2])
+ # expected = np.array([[ -3.57627869e-07, 1.17000043e+00, 1.01902664e+00, 1.19209290e-07,
+ # 0.00000000e+00, 1.19209290e-07, -5.96046448e-08, -2.38418579e-07,
+ # 9.74681854e-01, 1.19209290e-07]])
+ # self.assertAllClose(transformed, expected)
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_saver.py b/tensorflow/contrib/learn/python/learn/tests/test_saver.py
index 28e3197b4c..8592e9c0bc 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_saver.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_saver.py
@@ -83,6 +83,5 @@ class SaverTest(tf.test.TestCase):
with self.assertRaises(NotImplementedError):
learn.TensorFlowEstimator.restore(path)
-
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py b/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py
index fc71d5c690..1ab27aaa90 100644
--- a/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py
+++ b/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py
@@ -23,15 +23,15 @@ import tensorflow as tf
class ConfusionMatrixTest(tf.test.TestCase):
- def _testConfMatrix(self, predictions, targets, truth):
+ def _testConfMatrix(self, predictions, labels, truth):
with self.test_session():
- ans = tf.contrib.metrics.confusion_matrix(predictions, targets)
+ ans = tf.contrib.metrics.confusion_matrix(predictions, labels)
tf_ans = ans.eval()
self.assertAllClose(tf_ans, truth, atol=1e-10)
def _testBasic(self, dtype):
predictions = np.arange(5, dtype=dtype)
- targets = np.arange(5, dtype=dtype)
+ labels = np.arange(5, dtype=dtype)
truth = np.asarray(
[[1, 0, 0, 0, 0],
@@ -43,7 +43,7 @@ class ConfusionMatrixTest(tf.test.TestCase):
self._testConfMatrix(
predictions=predictions,
- targets=targets,
+ labels=labels,
truth=truth)
def testInt32Basic(self, dtype=np.int32):
@@ -54,7 +54,7 @@ class ConfusionMatrixTest(tf.test.TestCase):
def _testDiffentLabelsInPredictionAndTarget(self, dtype):
predictions = np.asarray([1, 2, 3], dtype=dtype)
- targets = np.asarray([4, 5, 6], dtype=dtype)
+ labels = np.asarray([4, 5, 6], dtype=dtype)
truth = np.asarray(
[[0, 0, 0, 0, 0, 0, 0],
@@ -68,7 +68,7 @@ class ConfusionMatrixTest(tf.test.TestCase):
self._testConfMatrix(
predictions=predictions,
- targets=targets,
+ labels=labels,
truth=truth)
def testInt32DifferentLabels(self, dtype=np.int32):
@@ -79,7 +79,7 @@ class ConfusionMatrixTest(tf.test.TestCase):
def _testMultipleLabels(self, dtype):
predictions = np.asarray([1, 1, 2, 3, 5, 6, 1, 2, 3, 4], dtype=dtype)
- targets = np.asarray([1, 1, 2, 3, 5, 1, 3, 6, 3, 1], dtype=dtype)
+ labels = np.asarray([1, 1, 2, 3, 5, 1, 3, 6, 3, 1], dtype=dtype)
truth = np.asarray(
[[0, 0, 0, 0, 0, 0, 0],
@@ -93,7 +93,7 @@ class ConfusionMatrixTest(tf.test.TestCase):
self._testConfMatrix(
predictions=predictions,
- targets=targets,
+ labels=labels,
truth=truth)
def testInt32MultipleLabels(self, dtype=np.int32):
@@ -104,24 +104,24 @@ class ConfusionMatrixTest(tf.test.TestCase):
def testInvalidRank(self):
predictions = np.asarray([[1, 2, 3]])
- targets = np.asarray([1, 2, 3])
+ labels = np.asarray([1, 2, 3])
self.assertRaisesRegexp(
ValueError, "are not compatible",
- tf.contrib.metrics.confusion_matrix, predictions, targets)
+ tf.contrib.metrics.confusion_matrix, predictions, labels)
predictions = np.asarray([1, 2, 3])
- targets = np.asarray([[1, 2, 3]])
+ labels = np.asarray([[1, 2, 3]])
self.assertRaisesRegexp(
ValueError, "are not compatible",
- tf.contrib.metrics.confusion_matrix, predictions, targets)
+ tf.contrib.metrics.confusion_matrix, predictions, labels)
def testInputDifferentSize(self):
predictions = np.asarray([1, 2, 3])
- targets = np.asarray([1, 2])
+ labels = np.asarray([1, 2])
self.assertRaisesRegexp(
ValueError,
"are not compatible",
- tf.contrib.metrics.confusion_matrix, predictions, targets)
+ tf.contrib.metrics.confusion_matrix, predictions, labels)
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py
index f8d28788ca..0f9baf2ee7 100644
--- a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py
@@ -25,14 +25,17 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
-def confusion_matrix(predictions, targets, num_classes=None, name=None):
- """Computes the confusion matrix from predictions and targets
+"""Confusion matrix related metrics."""
+
+
+def confusion_matrix(predictions, labels, num_classes=None, name=None):
+ """Computes the confusion matrix from predictions and labels
Calculate the Confusion Matrix for a pair of prediction and
- target 1-D int arrays.
+ label 1-D int arrays.
Considering a prediction array such as: `[1, 2, 3]`
- And a target array such as: `[2, 2, 3]`
+ And a label array such as: `[2, 2, 3]`
The confusion matrix returned would be the following one:
[[0, 0, 0]
@@ -41,18 +44,18 @@ def confusion_matrix(predictions, targets, num_classes=None, name=None):
[0, 0, 1]]
Where the matrix rows represent the prediction labels and the columns
- represents the target labels. The confusion matrix is always a 2-D array
+ represents the real labels. The confusion matrix is always a 2-D array
of shape [n, n], where n is the number of valid labels for a given
- classification task. Both prediction and target must be 1-D arrays of
+ classification task. Both prediction and labels must be 1-D arrays of
the same shape in order for this function to work.
Args:
predictions: A 1-D array represeting the predictions for a given
classification.
- targets: A 1-D represeting the real labels for the classification task.
+ labels: A 1-D array representing the real labels for the classification task.
num_classes: The possible number of labels the classification task can
have. If this value is not provided, it will be calculated
- using both predictions and targets array.
+ using both predictions and labels array.
name: Scope name.
Returns:
@@ -60,22 +63,21 @@ def confusion_matrix(predictions, targets, num_classes=None, name=None):
possible labels in the classification task.
Raises:
- ValueError: If both predictions and targets are not 1-D vectors and do not
+ ValueError: If both predictions and labels are not 1-D vectors and do not
have the same size.
"""
- with ops.op_scope([predictions, targets, num_classes], name,
+ with ops.op_scope([predictions, labels, num_classes], name,
'confusion_matrix') as name:
predictions = ops.convert_to_tensor(
predictions, name='predictions', dtype=dtypes.int64)
- targets = ops.convert_to_tensor(targets, name='targets', dtype=dtypes.int64)
+ labels = ops.convert_to_tensor(labels, name='labels', dtype=dtypes.int64)
if num_classes is None:
num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
- math_ops.reduce_max(targets)) + 1
+ math_ops.reduce_max(labels)) + 1
shape = array_ops.pack([num_classes, num_classes])
- indices = array_ops.transpose(
- array_ops.pack([predictions, targets]))
+ indices = array_ops.transpose(array_ops.pack([predictions, labels]))
values = array_ops.ones_like(predictions, dtype=dtypes.int32)
cm_sparse = ops.SparseTensor(
indices=indices, values=values, shape=shape)
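The docstring example, made runnable. With predictions `[1, 2, 3]` and labels `[2, 2, 3]`, `num_classes` is inferred as 4; rows index predictions and columns index labels:

```python
import tensorflow as tf

predictions = [1, 2, 3]
labels = [2, 2, 3]

with tf.Session() as sess:
    cm = tf.contrib.metrics.confusion_matrix(predictions, labels)
    print(sess.run(cm))
    # [[0 0 0 0]
    #  [0 0 1 0]
    #  [0 0 1 0]
    #  [0 0 0 1]]
```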
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index c9754f518c..1be21fa954 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1291,6 +1291,7 @@ tf_cc_tests_gpu(
exclude = [
# Run by tests below
"common_runtime/gpu/gpu_allocator_retry_test.cc",
+ "common_runtime/gpu/gpu_debug_allocator_test.cc",
"common_runtime/gpu/gpu_stream_util_test.cc",
"common_runtime/gpu/gpu_tracer_test.cc",
],
@@ -1482,6 +1483,30 @@ tf_cc_test_gpu(
)
tf_cc_test_gpu(
+ name = "common_runtime/gpu/gpu_debug_allocator_test.cc",
+ size = "medium",
+ args = ["\"--gtest_death_test_style=threadsafe\""],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session",
+ ":framework",
+ ":framework_internal",
+ ":gpu_runtime",
+ ":lib",
+ ":lib_internal",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_cc_test_gpu(
name = "common_runtime/gpu/gpu_stream_util_test.cc",
size = "small",
linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 35a16e8bfe..93ed0d8c32 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -621,16 +621,17 @@ LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice(
int64 total_memory, available_memory;
CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
- int64 allocated_memory = available_memory;
+ int64 allocated_memory;
double config_memory_fraction =
options.config.gpu_options().per_process_gpu_memory_fraction();
if (config_memory_fraction == 0) {
+ allocated_memory = available_memory;
const int64 min_system_memory = MinSystemMemory(available_memory);
if (min_system_memory < allocated_memory) {
allocated_memory -= min_system_memory;
}
} else {
- allocated_memory *= config_memory_fraction;
+ allocated_memory = total_memory * config_memory_fraction;
}
Bytes allocated_bytes = static_cast<Bytes>(allocated_memory);
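The fix changes what the fraction is taken of: `per_process_gpu_memory_fraction` now bounds the allocator against total device memory instead of whatever memory happened to be free at startup. From Python the option is set as before:

```python
import tensorflow as tf

config = tf.ConfigProto()
# After this fix: reserve 40% of *total* GPU memory, regardless of how
# much is currently free on the device.
config.gpu_options.per_process_gpu_memory_fraction = 0.4

sess = tf.Session(config=config)
```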
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 4e4b6d3703..3267a200f0 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -264,7 +264,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
// partitions_ is immutable after RegisterPartitions() call
// finishes. RunPartitions() can access partitions_ safely without
- // acquring locks.
+ // acquiring locks.
std::vector<Part> partitions_;
mutable mutex mu_;
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 0057db6967..e4e5610448 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
+#include <algorithm>
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -839,13 +840,13 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->allocate_output(0, input_shape, &in_backprop));
const int padding_rows =
- (padding_ == VALID)
- ? 0
- : (output_rows - 1) * stride_rows + filter_rows - input_rows;
+ (padding_ == VALID) ? 0
+ : std::max<int>(0, (output_rows - 1) * stride_rows +
+ filter_rows - input_rows);
const int padding_cols =
- (padding_ == VALID)
- ? 0
- : (output_cols - 1) * stride_cols + filter_cols - input_cols;
+ (padding_ == VALID) ? 0
+ : std::max<int>(0, (output_cols - 1) * stride_cols +
+ filter_cols - input_cols);
// TODO(keveman): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -889,6 +890,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
", n=", n, ", k=", k));
}
+
return;
}
@@ -1058,6 +1060,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
"cuDNN Backward Data function launch failure : input shape(",
input_shape.DebugString(), ") filter shape(",
filter_shape.DebugString(), ")"));
+ return;
}
if (rows_odd || cols_odd) {
@@ -1148,13 +1151,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
context->allocate_output(0, filter_shape, &filter_backprop));
const int padding_rows =
- (padding_ == VALID)
- ? 0
- : (output_rows - 1) * stride_rows + filter_rows - input_rows;
+ (padding_ == VALID) ? 0
+ : std::max<int>(0, (output_rows - 1) * stride_rows +
+ filter_rows - input_rows);
const int padding_cols =
- (padding_ == VALID)
- ? 0
- : (output_cols - 1) * stride_cols + filter_cols - input_cols;
+ (padding_ == VALID) ? 0
+ : std::max<int>(0, (output_cols - 1) * stride_cols +
+ filter_cols - input_cols);
// TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -1387,6 +1390,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
"cuDNN Backward Filter function launch failure : input shape(",
input_shape.DebugString(), ") filter shape(",
filter_shape.DebugString(), ")"));
+ return;
}
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
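The two hunks above clamp the computed backprop padding at zero; without the `std::max`, the expression can go negative when the filter plus the stride window is smaller than the input. The arithmetic, as a sketch in Python:

```python
def backprop_padding(output_size, stride, filter_size, input_size):
    # Matches the clamped C++ expression:
    # max(0, (output - 1) * stride + filter - input)
    return max(0, (output_size - 1) * stride + filter_size - input_size)

# Unclamped, this case evaluates to (2 - 1) * 3 + 2 - 6 = -1:
print(backprop_padding(2, 3, 2, 6))  # 0 after clamping
```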
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index de8d8e784d..4eb67268c2 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -312,6 +312,7 @@ struct LaunchConvOp<GPUDevice, T> {
ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
", n=", n, ", k=", k));
}
+
return;
}
int padding_rows = 0;
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.h b/tensorflow/core/kernels/cudnn_pooling_gpu.h
index d1982299e2..59ae794737 100644
--- a/tensorflow/core/kernels/cudnn_pooling_gpu.h
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.h
@@ -18,6 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
#define TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
+#include <array>
+
#include "tensorflow/core/framework/op_kernel.h"
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc
index ff9db97fe9..e745143339 100644
--- a/tensorflow/core/kernels/cwise_op_abs.cc
+++ b/tensorflow/core/kernels/cwise_op_abs.cc
@@ -19,8 +19,7 @@ namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Abs", functor::abs, float, Eigen::half, double, int32,
int64);
#if !defined(__ANDROID__)
-REGISTER_KERNEL_BUILDER(Name("ComplexAbs").Device(DEVICE_CPU),
- UnaryOp<CPUDevice, functor::abs<complex64>>);
+REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128);
#endif
#if GOOGLE_CUDA
REGISTER4(UnaryOp, GPU, "Abs", functor::abs, float, Eigen::half, double, int64);
diff --git a/tensorflow/core/kernels/cwise_op_add.cc b/tensorflow/core/kernels/cwise_op_add.cc
index 1457b74e6f..4aa3761ffe 100644
--- a/tensorflow/core/kernels/cwise_op_add.cc
+++ b/tensorflow/core/kernels/cwise_op_add.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER9(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
- int64, int8, int16, complex64, string);
+REGISTER10(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
+ int64, int8, int16, complex64, complex128, string);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double,
int64);
diff --git a/tensorflow/core/kernels/cwise_op_complex.cc b/tensorflow/core/kernels/cwise_op_complex.cc
index d7e4638ff1..f8796f72d8 100644
--- a/tensorflow/core/kernels/cwise_op_complex.cc
+++ b/tensorflow/core/kernels/cwise_op_complex.cc
@@ -16,10 +16,20 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_CPU),
- BinaryOp<CPUDevice, functor::make_complex<float>>);
+#define REGISTER_COMPLEX(D, R, C) \
+ REGISTER_KERNEL_BUILDER(Name("Complex") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<R>("T") \
+ .TypeConstraint<C>("Tout"), \
+ BinaryOp<D##Device, functor::make_complex<R>>);
+
+REGISTER_COMPLEX(CPU, float, complex64);
+REGISTER_COMPLEX(CPU, double, complex128);
+
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_GPU),
- BinaryOp<GPUDevice, functor::make_complex<float>>);
+REGISTER_COMPLEX(GPU, float, complex64);
+REGISTER_COMPLEX(GPU, double, complex128);
#endif
+
+#undef REGISTER_COMPLEX
} // namespace tensorflow
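The `Tout` type constraint added above is what lets the `Complex` op produce complex128 from float64 inputs; the companion registrations in this commit extend abs, conj, cos, div, exp, and friends to the new dtype. A quick sketch from Python:

```python
import tensorflow as tf

real = tf.constant([1.0, -2.0], dtype=tf.float64)
imag = tf.constant([3.0, 0.5], dtype=tf.float64)

z = tf.complex(real, imag)     # now yields dtype complex128
magnitude = tf.complex_abs(z)  # ComplexAbs, newly registered for complex128
conjugate = tf.conj(z)

with tf.Session() as sess:
    print(sess.run([z, magnitude, conjugate]))
```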
diff --git a/tensorflow/core/kernels/cwise_op_conj.cc b/tensorflow/core/kernels/cwise_op_conj.cc
index 37ab71dd65..e4cb6aabe3 100644
--- a/tensorflow/core/kernels/cwise_op_conj.cc
+++ b/tensorflow/core/kernels/cwise_op_conj.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_CPU),
- UnaryOp<CPUDevice, functor::conj<complex64>>);
+
+REGISTER2(UnaryOp, CPU, "Conj", functor::conj, complex64, complex128);
#if GOOGLE_CUDA
// REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_GPU),
// UnaryOp<GPUDevice, functor::conj<complex64>>);
diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc
index 6958fa22b8..cd7dd976db 100644
--- a/tensorflow/core/kernels/cwise_op_cos.cc
+++ b/tensorflow/core/kernels/cwise_op_cos.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Cos", functor::cos, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Cos", functor::cos, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Cos", functor::cos, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index c4899473e6..617d663e8d 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(BinaryOp, CPU, "Div", functor::div, float, Eigen::half, double,
- complex64);
+REGISTER5(BinaryOp, CPU, "Div", functor::div, float, Eigen::half, double,
+ complex64, complex128);
REGISTER4(BinaryOp, CPU, "Div", functor::safe_div, uint8, int16, int32, int64);
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
diff --git a/tensorflow/core/kernels/cwise_op_equal_to.cc b/tensorflow/core/kernels/cwise_op_equal_to.cc
index c1d5b2f4ed..d9cd413534 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to.cc
@@ -16,8 +16,9 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER11(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half,
- double, uint8, int8, int16, int32, int64, complex64, string, bool);
+REGISTER12(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half,
+ double, uint8, int8, int16, int32, int64, complex64, complex128,
+ string, bool);
#if GOOGLE_CUDA
REGISTER8(BinaryOp, GPU, "Equal", functor::equal_to, float, Eigen::half, double,
uint8, int8, int16, int64, bool);
diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc
index 2d7df89149..4191018174 100644
--- a/tensorflow/core/kernels/cwise_op_exp.cc
+++ b/tensorflow/core/kernels/cwise_op_exp.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Exp", functor::exp, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Exp", functor::exp, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Exp", functor::exp, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc
index cb6fafaad3..c7433c2de9 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_BINARY1(make_complex, float);
+DEFINE_BINARY2(make_complex, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc
index c183425254..0d2c299e75 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc
@@ -19,8 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_BINARY9(equal_to, float, Eigen::half, double, uint8, int8, int16, int64,
- complex64, bool);
+DEFINE_BINARY10(equal_to, float, Eigen::half, double, uint8, int8, int16, int64,
+ complex64, complex128, bool);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc
index 58cdccfb30..f43aee786c 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_UNARY1(get_imag, complex64);
+DEFINE_UNARY2(get_imag, complex64, complex128);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc
index 171c1f9e40..ecd91cfd13 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc
@@ -19,8 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_BINARY9(not_equal_to, float, Eigen::half, double, uint8, int8, int16,
- int64, complex64, bool);
+DEFINE_BINARY10(not_equal_to, float, Eigen::half, double, uint8, int8, int16,
+ int64, complex64, complex128, bool);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc
index 6ad3505422..f8cbb33703 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_UNARY1(get_real, complex64);
+DEFINE_UNARY2(get_real, complex64, complex128);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
index f29d07e2db..6867a76c98 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
@@ -77,6 +77,7 @@ SELECT_FUNCTOR(double);
SELECT_FUNCTOR(int32);
SELECT_FUNCTOR(int64);
SELECT_FUNCTOR(complex64);
+SELECT_FUNCTOR(complex128);
#undef SELECT_FUNCTOR
diff --git a/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc
new file mode 100644
index 0000000000..3bbc4ec39a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc
@@ -0,0 +1,27 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY2(zeta, float, double);
+DEFINE_BINARY2(polygamma, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
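(Editor's aside.) The functors defined here wrap Eigen's scalar_zeta_op and scalar_polygamma_op. scipy exposes the same special functions, which gives an easy reference check for the definitions used in the op docs further down:

```python
import numpy as np
from scipy.special import zeta, polygamma

# Hurwitz zeta: zeta(x, q) = sum_{n>=0} (q + n)^(-x)
print(zeta(4.0, 1.0))   # for q = 1 this is the Riemann zeta(4)
print(np.pi**4 / 90)    # closed form, should match

# Polygamma: n-th derivative of the digamma function
print(polygamma(1, 1.0))  # psi'(1) == pi**2 / 6
print(np.pi**2 / 6)
```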
diff --git a/tensorflow/core/kernels/cwise_op_imag.cc b/tensorflow/core/kernels/cwise_op_imag.cc
index 7366f7342f..18bf8d83fe 100644
--- a/tensorflow/core/kernels/cwise_op_imag.cc
+++ b/tensorflow/core/kernels/cwise_op_imag.cc
@@ -16,10 +16,20 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_CPU),
- UnaryOp<CPUDevice, functor::get_imag<complex64>>);
+#define REGISTER_COMPLEX(D, R, C) \
+ REGISTER_KERNEL_BUILDER(Name("Imag") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<C>("T") \
+ .TypeConstraint<R>("Tout"), \
+ UnaryOp<D##Device, functor::get_imag<C>>);
+
+REGISTER_COMPLEX(CPU, float, complex64);
+REGISTER_COMPLEX(CPU, double, complex128);
+
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_GPU),
- UnaryOp<GPUDevice, functor::get_imag<complex64>>);
+REGISTER_COMPLEX(GPU, float, complex64);
+REGISTER_COMPLEX(GPU, double, complex128);
#endif
+
+#undef REGISTER_COMPLEX
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_inverse.cc b/tensorflow/core/kernels/cwise_op_inverse.cc
index 05834996c0..168ab268b4 100644
--- a/tensorflow/core/kernels/cwise_op_inverse.cc
+++ b/tensorflow/core/kernels/cwise_op_inverse.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Inv", functor::inverse, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Inv", functor::inverse, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER4(UnaryOp, GPU, "Inv", functor::inverse, float, Eigen::half, double,
int64);
diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc
index ab6a1f9778..587b4f9eca 100644
--- a/tensorflow/core/kernels/cwise_op_log.cc
+++ b/tensorflow/core/kernels/cwise_op_log.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_mul.cc b/tensorflow/core/kernels/cwise_op_mul.cc
index 395cea5d7f..e1783950a7 100644
--- a/tensorflow/core/kernels/cwise_op_mul.cc
+++ b/tensorflow/core/kernels/cwise_op_mul.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER9(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, uint8,
- int8, int16, int32, int64, complex64);
+REGISTER10(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, uint8,
+ int8, int16, int32, int64, complex64, complex128);
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "Mul", functor::mul, float, Eigen::half, double, uint8,
int8, int16, int64);
diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg.cc
index 2b672285b3..7c580c57e1 100644
--- a/tensorflow/core/kernels/cwise_op_neg.cc
+++ b/tensorflow/core/kernels/cwise_op_neg.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
- complex64, int64);
+REGISTER7(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
+ complex64, int64, complex128);
#if GOOGLE_CUDA
REGISTER4(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64);
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to.cc b/tensorflow/core/kernels/cwise_op_not_equal_to.cc
index da4a40ad9d..a4fbd8a280 100644
--- a/tensorflow/core/kernels/cwise_op_not_equal_to.cc
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to.cc
@@ -16,8 +16,9 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER11(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
- double, uint8, int8, int16, int32, int64, complex64, string, bool);
+REGISTER12(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
+ double, uint8, int8, int16, int32, int64, complex64, complex128,
+ string, bool);
#if GOOGLE_CUDA
REGISTER8(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
double, uint8, int8, int16, int64, bool);
diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc
index 8bb71c03d8..3c136d7e2c 100644
--- a/tensorflow/core/kernels/cwise_op_pow.cc
+++ b/tensorflow/core/kernels/cwise_op_pow.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32,
- int64, complex64);
+REGISTER7(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32,
+ int64, complex64, complex128);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double,
int64);
diff --git a/tensorflow/core/kernels/cwise_op_real.cc b/tensorflow/core/kernels/cwise_op_real.cc
index f8619b6c69..214fb96291 100644
--- a/tensorflow/core/kernels/cwise_op_real.cc
+++ b/tensorflow/core/kernels/cwise_op_real.cc
@@ -16,10 +16,21 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_CPU),
- UnaryOp<CPUDevice, functor::get_real<complex64>>);
+
+#define REGISTER_COMPLEX(D, R, C) \
+ REGISTER_KERNEL_BUILDER(Name("Real") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<C>("T") \
+ .TypeConstraint<R>("Tout"), \
+ UnaryOp<D##Device, functor::get_real<C>>);
+
+REGISTER_COMPLEX(CPU, float, complex64);
+REGISTER_COMPLEX(CPU, double, complex128);
+
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_GPU),
- UnaryOp<GPUDevice, functor::get_real<complex64>>);
+REGISTER_COMPLEX(GPU, float, complex64);
+REGISTER_COMPLEX(GPU, double, complex128);
#endif
+
+#undef REGISTER_COMPLEX
} // namespace tensorflow
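(Editor's aside.) Real and Imag are now keyed on both the complex input type T and the real output type Tout, so complex64 pairs with float and complex128 with double. The dtype relationship, sketched with numpy:

```python
import numpy as np

z64 = np.array([1 + 2j], dtype=np.complex64)
z128 = np.array([1 + 2j], dtype=np.complex128)
print(z64.real.dtype, z64.imag.dtype)    # float32 float32
print(z128.real.dtype, z128.imag.dtype)  # float64 float64
# In graph terms: Real/Imag with T=complex128 now selects Tout=double.
```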
diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc
index ff3e8d778f..ff6ae3bd2f 100644
--- a/tensorflow/core/kernels/cwise_op_rsqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index f9df4ad94e..4308bfd253 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -133,6 +133,7 @@ REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64);
REGISTER_SELECT_GPU(complex64);
+REGISTER_SELECT_GPU(complex128);
#undef REGISTER_SELECT_GPU
diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc
index 574866b3f0..3b51a293c7 100644
--- a/tensorflow/core/kernels/cwise_op_sigmoid.cc
+++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half,
double);
diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc
index 7107970332..1d17a4a066 100644
--- a/tensorflow/core/kernels/cwise_op_sign.cc
+++ b/tensorflow/core/kernels/cwise_op_sign.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64,
- complex64, Eigen::half);
+REGISTER7(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64,
+ complex64, Eigen::half, complex128);
#if GOOGLE_CUDA
REGISTER4(UnaryOp, GPU, "Sign", functor::sign, float, Eigen::half, double,
int64);
diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc
index 123a251c62..85ceab95d0 100644
--- a/tensorflow/core/kernels/cwise_op_sin.cc
+++ b/tensorflow/core/kernels/cwise_op_sin.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Sin", functor::sin, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc
index daaa4c55c6..652e12851b 100644
--- a/tensorflow/core/kernels/cwise_op_sqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_sqrt.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_square.cc b/tensorflow/core/kernels/cwise_op_square.cc
index 80c8423cf9..f03d2f11d4 100644
--- a/tensorflow/core/kernels/cwise_op_square.cc
+++ b/tensorflow/core/kernels/cwise_op_square.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(UnaryOp, CPU, "Square", functor::square, float, Eigen::half, double,
- int32, complex64, int64);
+REGISTER7(UnaryOp, CPU, "Square", functor::square, float, Eigen::half, double,
+ int32, complex64, complex128, int64);
#if GOOGLE_CUDA
REGISTER4(UnaryOp, GPU, "Square", functor::square, float, Eigen::half, double,
int64);
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
index 245f13ad68..ac744aeb2b 100644
--- a/tensorflow/core/kernels/cwise_op_sub.cc
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32,
- int64, complex64);
+REGISTER7(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32,
+ int64, complex64, complex128);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double,
int64);
diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc
index a42b6fe6e4..edb89f799c 100644
--- a/tensorflow/core/kernels/cwise_op_tanh.cc
+++ b/tensorflow/core/kernels/cwise_op_tanh.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER4(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
- complex64);
+REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_zeta.cc b/tensorflow/core/kernels/cwise_op_zeta.cc
new file mode 100644
index 0000000000..a960475396
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_zeta.cc
@@ -0,0 +1,21 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(BinaryOp, CPU, "Zeta", functor::zeta, float, double);
+REGISTER2(BinaryOp, CPU, "Polygamma", functor::polygamma, float, double);
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 71d65f806b..7d45fb8511 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -480,6 +480,12 @@ template <typename T>
struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};
template <typename T>
+struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};
+
+template <typename T>
+struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};
+
+template <typename T>
struct squared_difference
: base<T, Eigen::internal::scalar_compose_op<
T, Eigen::internal::scalar_square_op<T>,
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index f86f206cb1..05fb4afb63 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -386,6 +386,9 @@ struct UnaryFunctor<CPUDevice, Functor> {
REGISTER(OP, D, N, F, T0)
#define REGISTER11(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) \
REGISTER(OP, D, N, F, T0)
+#define REGISTER12(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, \
+ T11) \
+ REGISTER(OP, D, N, F, T0)
#else // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
REGISTER(OP, D, N, F, T0) \
@@ -417,6 +420,10 @@ struct UnaryFunctor<CPUDevice, Functor> {
#define REGISTER11(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) \
REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
REGISTER6(OP, D, N, F, T5, T6, T7, T8, T9, T10)
+#define REGISTER12(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, \
+ T11) \
+ REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
+ REGISTER6(OP, D, N, F, T6, T7, T8, T9, T10, T11)
#endif // defined(__ANDROID_TYPES_SLIM__)
} // end namespace tensorflow
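(Editor's aside.) A hypothetical Python model of what the REGISTERn macro family does; the names below are invented for illustration and are not TensorFlow API. Each macro fans a type list out into one kernel registration per type, except under __ANDROID_TYPES_SLIM__, where only the first type is kept to shrink the binary:

```python
# Hypothetical model of the REGISTERn macros; not TensorFlow API.
KERNELS = {}

def register(op, device, name, functor, *dtypes, slim=False):
    # Under the slim Android build only the first dtype is registered.
    for dtype in dtypes[:1] if slim else dtypes:
        KERNELS[(name, device, dtype)] = (op, functor)

register("BinaryOp", "CPU", "NotEqual", "not_equal_to",
         "float", "half", "double", "uint8", "int8", "int16",
         "int32", "int64", "complex64", "complex128", "string", "bool")
print(len([k for k in KERNELS if k[0] == "NotEqual"]))  # 12 registrations
```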
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
index d91d0faa86..6b23fb5785 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
@@ -34,6 +34,7 @@ namespace functor {
typedef Eigen::GpuDevice GPUDevice;
typedef std::complex<float> complex64;
+typedef std::complex<double> complex128;
// Partial specialization of UnaryFunctor<Device=GPUDevice, Functor>.
template <typename Functor>
@@ -149,6 +150,9 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS, has_errors> {
#define DEFINE_BINARY9(F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
DEFINE_BINARY4(F, T0, T1, T2, T3); \
DEFINE_BINARY5(F, T4, T5, T6, T7, T8)
+#define DEFINE_BINARY10(F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) \
+ DEFINE_BINARY5(F, T0, T1, T2, T3, T4); \
+ DEFINE_BINARY5(F, T5, T6, T7, T8, T9)
} // end namespace functor
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/diag_op.cc b/tensorflow/core/kernels/diag_op.cc
index df738f00cb..252fc338b9 100644
--- a/tensorflow/core/kernels/diag_op.cc
+++ b/tensorflow/core/kernels/diag_op.cc
@@ -122,6 +122,7 @@ REGISTER_DIAGOP(double);
REGISTER_DIAGOP(float);
REGISTER_DIAGOP(int32);
REGISTER_DIAGOP(int64);
+REGISTER_DIAGOP(complex64);
#undef REGISTER_DIAGOP
@@ -188,6 +189,7 @@ REGISTER_DIAGPARTOP(double);
REGISTER_DIAGPARTOP(float);
REGISTER_DIAGPARTOP(int32);
REGISTER_DIAGPARTOP(int64);
+REGISTER_DIAGPARTOP(complex64);
#undef REGISTER_DIAGPARTOP
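(Editor's aside.) Diag builds a diagonal tensor from its input and DiagPart inverts it; the patch adds complex64 to both. The rank-1 round trip in numpy terms:

```python
import numpy as np

d = np.array([1 + 1j, 2 + 2j], dtype=np.complex64)
m = np.diag(d)      # Diag: vector -> diagonal matrix
print(np.diag(m))   # DiagPart: recovers the original vector
```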
diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc
index b1953adeae..b4d14e8c62 100644
--- a/tensorflow/core/kernels/edit_distance_op.cc
+++ b/tensorflow/core/kernels/edit_distance_op.cc
@@ -179,7 +179,7 @@ class EditDistanceOp : public OpKernel {
if (g_truth == g_hypothesis) {
auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
- output_strides.begin(), 0);
+ output_strides.begin(), int64{0});
output_t(loc) =
gtl::LevenshteinDistance<T>(truth_seq, hypothesis_seq, cmp);
if (normalize_) output_t(loc) /= truth_seq.size();
@@ -188,13 +188,13 @@ class EditDistanceOp : public OpKernel {
++truth_iter;
} else if (g_truth > g_hypothesis) { // missing truth @ this hypothesis
auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
- output_strides.begin(), 0);
+ output_strides.begin(), int64{0});
output_t(loc) = hypothesis_seq.size();
if (normalize_) output_t(loc) /= 0.0;
++hypothesis_iter;
} else { // missing hypothesis @ this truth
auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
- output_strides.begin(), 0);
+ output_strides.begin(), int64{0});
output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
++truth_iter;
}
@@ -204,7 +204,7 @@ class EditDistanceOp : public OpKernel {
std::vector<int64> g_hypothesis = hypothesis_j.group();
auto hypothesis_seq = hypothesis_j.values<T>();
auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
- output_strides.begin(), 0);
+ output_strides.begin(), int64{0});
output_t(loc) = hypothesis_seq.size();
if (normalize_) output_t(loc) /= 0.0;
++hypothesis_iter;
@@ -214,7 +214,7 @@ class EditDistanceOp : public OpKernel {
std::vector<int64> g_truth = truth_i.group();
auto truth_seq = truth_i.values<T>();
auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
- output_strides.begin(), 0);
+ output_strides.begin(), int64{0});
output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
++truth_iter;
}
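(Editor's aside.) The int64{0} seed matters because std::inner_product accumulates in the type of its initial value: a plain int 0 computes the flat output index in 32 bits, which can wrap for large outputs. The same pitfall rendered in numpy:

```python
import numpy as np

coords = np.array([3, 1], dtype=np.int64)
strides = np.array([2**31, 1], dtype=np.int64)  # large outer stride
prods = coords * strides
print(prods.sum(dtype=np.int32))  # 32-bit accumulation wraps: bogus index
print(prods.sum(dtype=np.int64))  # 64-bit accumulation: 6442450945
```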
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 08b699c8fa..f6420aaad8 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -205,8 +205,9 @@ struct MatMulFunctor<CPUDevice, T> {
REGISTER_CPU(float);
REGISTER_CPU(double);
REGISTER_CPU(int32);
-REGISTER_CPU(complex64);
REGISTER_CPU(Eigen::half);
+REGISTER_CPU(complex64);
+REGISTER_CPU(complex128);
#if GOOGLE_CUDA
REGISTER_GPU(float);
// REGISTER_GPU(double);
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
index 6886165f98..1ba130692a 100644
--- a/tensorflow/core/kernels/queue_base.cc
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -371,6 +371,7 @@ Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
HANDLE_TYPE(DT_QINT32);
HANDLE_TYPE(DT_QINT16);
HANDLE_TYPE(DT_QUINT16);
+ HANDLE_TYPE(DT_COMPLEX128);
#undef HANDLE_TYPE
return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
parent.dtype());
@@ -399,6 +400,7 @@ Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
HANDLE_TYPE(DT_QINT32);
HANDLE_TYPE(DT_QINT16);
HANDLE_TYPE(DT_QUINT16);
+ HANDLE_TYPE(DT_COMPLEX128);
#undef HANDLE_TYPE
return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
element.dtype());
diff --git a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
index 69694fbd86..8ff724f5eb 100644
--- a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
@@ -90,6 +90,7 @@ DEFINE_FOR_ALL_REDUCERS(double);
#undef DEFINE_FOR_ALL_REDUCERS
DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::SumReducer<complex64>);
+DEFINE_FOR_TYPE_AND_R(complex128, Eigen::internal::SumReducer<complex128>);
DEFINE_FOR_TYPE_AND_R(bool, Eigen::internal::AndReducer);
DEFINE_FOR_TYPE_AND_R(bool, Eigen::internal::OrReducer);
#undef DEFINE_FOR_TYPE_AND_R
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index 9661f6a523..72dd718754 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -22,14 +22,12 @@ namespace tensorflow {
Name("Sum").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ReductionOp<CPUDevice, type, Eigen::internal::SumReducer<type>>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
-#undef REGISTER_CPU_KERNELS
-
// NOTE: We should have mean(complex64,int32), too. But that needs to
// change Eigen::internal::MeanReducer to cast int to complex<float>.
// We don't see immediate need of mean(complex64,int32) anyway.
-REGISTER_KERNEL_BUILDER(
- Name("Sum").Device(DEVICE_CPU).TypeConstraint<complex64>("T"),
- ReductionOp<CPUDevice, complex64, Eigen::internal::SumReducer<complex64>>);
+REGISTER_CPU_KERNELS(complex64);
+REGISTER_CPU_KERNELS(complex128);
+#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
@@ -43,15 +41,10 @@ REGISTER_KERNEL_BUILDER(
REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
+REGISTER_GPU_KERNELS(complex64);
+REGISTER_GPU_KERNELS(complex128);
#undef REGISTER_GPU_KERNELS
-REGISTER_KERNEL_BUILDER(
- Name("Sum")
- .Device(DEVICE_GPU)
- .TypeConstraint<complex64>("T")
- .HostMemory("reduction_indices"),
- ReductionOp<GPUDevice, complex64, Eigen::internal::SumReducer<complex64>>);
-
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
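(Editor's aside.) With the one-off complex64 registrations folded into REGISTER_CPU_KERNELS and REGISTER_GPU_KERNELS, Sum now covers both complex types on both devices. The reduction itself, in numpy terms:

```python
import numpy as np

x = np.array([[1 + 1j, 2 - 1j],
              [3 + 0j, 4 + 2j]], dtype=np.complex128)
print(x.sum(axis=0))  # columnwise sum, stays complex128
print(x.sum())        # full reduction: (10+2j)
```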
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 9fba08d6b1..e60f11b246 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -171,6 +171,7 @@ REGISTER_CPU(float);
REGISTER_CPU(double);
REGISTER_CPU(int32);
REGISTER_CPU(complex64);
+REGISTER_CPU(complex128);
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index ca4268c78d..0c20aa1f01 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -99,6 +99,10 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
internal::Transpose<Device, uint64>(d, in, perm, out);
break;
+ case DT_COMPLEX128:
+ internal::Transpose<Device, complex128>(d, in, perm, out);
+ break;
+
case DT_STRING:
internal::Transpose<Device, string>(d, in, perm, out);
break;
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index 3d8746fd7c..703254ce90 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -108,6 +108,10 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T"),
SoftmaxXentWithLogitsOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T"),
+ SoftmaxXentWithLogitsOp<GPUDevice, double>);
#endif // GOOGLE_CUDA
} // namespace tensorflow
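(Editor's aside.) For reference, a numerically stable float64 sketch of the computation the new double kernel performs; this is plain numpy, not the TensorFlow source:

```python
import numpy as np

def softmax_xent(logits, labels):
    # Stable log-softmax: shift by the rowwise max before exponentiating.
    z = logits - logits.max(axis=-1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -(labels * log_softmax).sum(axis=-1)

logits = np.array([[2.0, 1.0, 0.1]], dtype=np.float64)
labels = np.array([[1.0, 0.0, 0.0]], dtype=np.float64)
print(softmax_xent(logits, labels))  # float64, matching the new kernel
```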
diff --git a/tensorflow/core/kernels/xent_op_gpu.cu.cc b/tensorflow/core/kernels/xent_op_gpu.cu.cc
index 18b36bfcf1..1b707be007 100644
--- a/tensorflow/core/kernels/xent_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/xent_op_gpu.cu.cc
@@ -42,9 +42,10 @@ struct XentFunctor<GPUDevice, T> {
};
} // end namespace functor
-// Instantiate the GPU implementation for half and float.
+// Instantiate the GPU implementation for half, float and double.
template struct functor::XentFunctor<GPUDevice, Eigen::half>;
template struct functor::XentFunctor<GPUDevice, float>;
+template struct functor::XentFunctor<GPUDevice, double>;
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index fc4d2e6f13..fe0fa34337 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -167,7 +167,7 @@ y: a tensor of the same shape and type as x but filled with zeros.
REGISTER_OP("Diag")
.Input("diagonal: T")
.Output("output: T")
- .Attr("T: {float, double, int32, int64}")
+ .Attr("T: {float, double, int32, int64, complex64}")
.Doc(R"doc(
Returns a diagonal tensor with given diagonal values.
@@ -196,7 +196,7 @@ diagonal: Rank k tensor where k is at most 3.
REGISTER_OP("DiagPart")
.Input("input: T")
.Output("diagonal: T")
- .Attr("T: {float, double, int32, int64}")
+ .Attr("T: {float, double, int32, int64, complex64}")
.Doc(R"doc(
Returns the diagonal part of the tensor.
@@ -660,14 +660,14 @@ For example:
```prettyprint
# tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
# tensor 't' has shape [9]
-reshape(t, [3, 3]) ==> [[1, 2, 3]
- [4, 5, 6]
+reshape(t, [3, 3]) ==> [[1, 2, 3],
+ [4, 5, 6],
[7, 8, 9]]
-# tensor 't' is [[[1, 1], [2, 2]]
+# tensor 't' is [[[1, 1], [2, 2]],
# [[3, 3], [4, 4]]]
# tensor 't' has shape [2, 2, 2]
-reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
+reshape(t, [2, 4]) ==> [[1, 1, 2, 2],
[3, 3, 4, 4]]
# tensor 't' is [[[1, 1, 1],
@@ -679,9 +679,22 @@ reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
# tensor 't' has shape [3, 2, 3]
# pass '[-1]' to flatten 't'
reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
-# -1 can also be used with higher dimensional shapes
+
+# -1 can also be used to infer the shape
+
+# -1 is inferred to be 9:
reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
[4, 4, 4, 5, 5, 5, 6, 6, 6]]
+# -1 is inferred to be 2:
+reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
+ [4, 4, 4, 5, 5, 5, 6, 6, 6]]
+# -1 is inferred to be 3:
+reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1],
+ [2, 2, 2],
+ [3, 3, 3]],
+ [[4, 4, 4],
+ [5, 5, 5],
+ [6, 6, 6]]]
# tensor 't' is [7]
# shape `[]` reshapes to a scalar
@@ -1394,10 +1407,10 @@ The output tensor has shape `[4, 1, 1, 3]` and value:
(3) For the following input of shape `[1, 4, 4, 1]` and block_size of 2:
```prettyprint
-x = [[[1], [2], [3], [4]],
- [[5], [6], [7], [8]],
- [[9], [10], [11], [12]],
- [[13], [14], [15], [16]]]
+x = [[[[1], [2], [3], [4]],
+ [[5], [6], [7], [8]],
+ [[9], [10], [11], [12]],
+ [[13], [14], [15], [16]]]]
```
The output tensor has shape `[4, 2, 2, 1]` and value:
@@ -1591,10 +1604,10 @@ This operation, for block_size of 2, will return the following tensor of shape
Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2:
```prettyprint
-x = [[ [1], [2], [5], [6]],
- [ [3], [4], [7], [8]],
- [ [9], [10], [13], [14]],
- [ [11], [12], [15], [16]]]
+x = [[[[1], [2], [5], [6]],
+ [[3], [4], [7], [8]],
+ [[9], [10], [13], [14]],
+ [[11], [12], [15], [16]]]]
```
the operator will return the following tensor of shape `[1 2 2 4]`:
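(Editor's aside.) The inference rule behind the new reshape examples: at most one dimension may be -1, and it is solved as the total size divided by the product of the known dimensions. numpy follows the same convention:

```python
import numpy as np

t = np.arange(18)                 # 18 elements, like the 3x2x3 tensor above
print(t.reshape(2, -1).shape)     # -1 solved as 18 / 2 = 9  -> (2, 9)
print(t.reshape(-1, 9).shape)     # -1 solved as 18 / 9 = 2  -> (2, 9)
print(t.reshape(2, -1, 3).shape)  # -1 solved as 18 / 6 = 3  -> (2, 3, 3)
```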
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index d290580077..7b4b28159d 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -284,7 +284,7 @@ REGISTER_OP_GRADIENT("Sub", SubGrad);
Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
DataType T;
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
- if (T == DT_COMPLEX64) {
+ if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
return GradForBinaryCwise(
g, {
{{"cy"}, "Conj", {"y"}, {}, {"dz"}},
@@ -543,7 +543,7 @@ Status MatMulGradCommon(const string& opname, const string& attr_adj_x,
FunctionDef* g) {
DataType T;
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
- if (T == DT_COMPLEX64) {
+ if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
return errors::Unimplemented(
"MatMul gradient for complex is not supported yet.");
}
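(Editor's aside.) MulGrad takes the Conj branch for both complex types because, under TensorFlow's gradient convention for complex values, the adjoint of z = x * y routes the upstream gradient through the conjugate of the other operand. A numpy sketch:

```python
import numpy as np

x = np.array([1 + 2j], dtype=np.complex128)
y = np.array([3 - 1j], dtype=np.complex128)
dz = np.array([1 + 0j], dtype=np.complex128)  # upstream gradient

dx = dz * np.conj(y)  # the {"cy"} = Conj(y) step in the function def
dy = dz * np.conj(x)
print(dx, dy)
```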
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 861ed74b1f..fdb490df9e 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -37,7 +37,7 @@ REGISTER_OP("BatchMatMul")
.Input("x: T")
.Input("y: T")
.Output("output: T")
- .Attr("T: {half, float, double, int32, complex64}")
+ .Attr("T: {half, float, double, int32, complex64, complex128}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.Doc(R"doc(
@@ -111,15 +111,17 @@ an output element, this operation computes \\(y = |x|\\).
)doc");
REGISTER_OP("ComplexAbs")
- .Input("x: complex64")
- .Output("y: float")
+ .Input("x: T")
+ .Output("y: Tout")
+ .Attr("T: {complex64, complex128} = DT_COMPLEX64")
+ .Attr("Tout: {float, double} = DT_FLOAT")
.Doc(R"doc(
Computes the complex absolute value of a tensor.
Given a tensor `x` of complex numbers, this operation returns a tensor of type
-`float` that is the absolute value of each element in `x`. All elements in `x`
-must be complex numbers of the form \\(a + bj\\). The absolute value is
-computed as \\( \sqrt{a^2 + b^2}\\).
+`float` or `double` that is the absolute value of each element in `x`. All
+elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute
+value is computed as \\( \sqrt{a^2 + b^2}\\).
For example:
@@ -132,7 +134,7 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
// Declares cwise unary operations signature: 't -> 't
#define UNARY() \
Input("x: T").Output("y: T").Attr( \
- "T: {half, float, double, int32, complex64, int64}")
+ "T: {half, float, double, int32, int64, complex64, complex128}")
REGISTER_OP("Neg")
.UNARY()
@@ -262,7 +264,7 @@ Returns which elements of x are finite.
REGISTER_OP("Sign")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double, int32, int64, complex64}")
+ .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
.Doc(R"doc(
Returns an element-wise indication of the sign of a number.
@@ -291,11 +293,11 @@ Returns element-wise smallest integer in not less than x.
#define BINARY_MORE() \
Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, float, double, uint8, int8, int16, int32, int64, complex64}")
+ "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, complex128}")
#define BINARY_FEWER() \
Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, float, double, int32, complex64, int64}")
+ "T: {half, float, double, int32, int64, complex64, complex128}")
// TODO(mrry): Restore `SetIsCommutative()` for non-string types.
REGISTER_OP("Add")
@@ -304,7 +306,7 @@ REGISTER_OP("Add")
.Output("z: T")
.Attr(
"T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
- "string}")
+ "complex128, string}")
.Doc(R"doc(
Returns x + y element-wise.
@@ -373,7 +375,7 @@ REGISTER_OP("Pow")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, float, double, int32, complex64, int64}")
+ .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
.Doc(R"doc(
Computes the power of one value to another.
@@ -433,6 +435,37 @@ Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
Gamma function.
)doc");
+REGISTER_OP("Zeta")
+ .Input("x: T")
+ .Input("q: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
+
+The Hurwitz zeta function is defined as:
+
+```
+\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}
+```
+)doc");
+
+REGISTER_OP("Polygamma")
+ .Input("a: T")
+ .Input("x: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Compute the polygamma function \\(\psi^{(n)}(x)\\).
+
+The polygamma function is defined as:
+
+```
+\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)
+```
+where \\(\psi(x)\\) is the digamma function.
+)doc");
+
// --------------------------------------------------------------------------
// Declares cwise binary comparison operations signature: 't, 't -> bool,
@@ -471,7 +504,7 @@ Returns the truth value of (x >= y) element-wise.
#define EQUALITY_COMPARISON() \
Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \
"T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " \
- "quint8, qint8, qint32, string, bool}")
+ "quint8, qint8, qint32, string, bool, complex128}")
REGISTER_OP("Equal")
.EQUALITY_COMPARISON()
@@ -577,7 +610,7 @@ REGISTER_OP("MatMul")
.Output("product: T")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
- .Attr("T: {half, float, double, int32, complex64}")
+ .Attr("T: {half, float, double, int32, complex64, complex128}")
.Doc(R"doc(
Multiply the matrix "a" by the matrix "b".
@@ -1141,9 +1174,11 @@ output: 1-D. The generated values.
)doc");
REGISTER_OP("Complex")
- .Input("real: float")
- .Input("imag: float")
- .Output("output: complex64")
+ .Input("real: T")
+ .Input("imag: T")
+ .Output("out: Tout")
+ .Attr("T: {float, double} = DT_FLOAT")
+ .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
.Doc(R"doc(
Converts two real numbers to a complex number.
@@ -1163,7 +1198,12 @@ tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
```
)doc");
-REGISTER_OP("Real").Input("input: complex64").Output("output: float").Doc(R"doc(
+REGISTER_OP("Real")
+ .Input("input: T")
+ .Output("output: Tout")
+ .Attr("T: {complex64, complex128} = DT_COMPLEX64")
+ .Attr("Tout: {float, double} = DT_FLOAT")
+ .Doc(R"doc(
Returns the real part of a complex number.
Given a tensor `input` of complex numbers, this operation returns a tensor of
@@ -1179,7 +1219,12 @@ tf.real(input) ==> [-2.25, 3.25]
```
)doc");
-REGISTER_OP("Imag").Input("input: complex64").Output("output: float").Doc(R"doc(
+REGISTER_OP("Imag")
+ .Input("input: T")
+ .Output("output: Tout")
+ .Attr("T: {complex64, complex128} = DT_COMPLEX64")
+ .Attr("Tout: {float, double} = DT_FLOAT")
+ .Doc(R"doc(
Returns the imaginary part of a complex number.
Given a tensor `input` of complex numbers, this operation returns a tensor of
@@ -1196,8 +1241,9 @@ tf.imag(input) ==> [4.75, 5.75]
)doc");
REGISTER_OP("Conj")
- .Input("input: complex64")
- .Output("output: complex64")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {complex64, complex128} = DT_COMPLEX64")
.Doc(R"doc(
Returns the complex conjugate of a complex number.
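(Editor's aside.) End to end, the generalized signatures let float64 pairs build complex128 tensors and come back out again. A usage sketch, assuming the ops are exposed as tf.complex, tf.real, tf.imag, and tf.conj and run under the session API of this era:

```python
import numpy as np
import tensorflow as tf

real = tf.constant(np.array([2.25, 3.25]))  # float64 inputs
imag = tf.constant(np.array([4.75, 5.75]))
z = tf.complex(real, imag)                  # Tout inferred as complex128

with tf.Session() as sess:
    print(sess.run(tf.conj(z)))                        # stays complex128
    print(sess.run(tf.real(z)), sess.run(tf.imag(z)))  # float64 parts
```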
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d22df658c8..34e38b97d7 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -3594,6 +3594,7 @@ op {
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
+ type: DT_COMPLEX64
}
}
}
@@ -3621,6 +3622,7 @@ op {
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
+ type: DT_COMPLEX64
}
}
}
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 8469c3823b..f357735fd6 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -154,7 +154,7 @@ message RunGraphRequest {
// A unique ID to distinguish different runs of the same graph.
//
- // The master generates a global unique `step_id` to dinstinguish
+ // The master generates a globally unique `step_id` to distinguish
// different runs of the graph computation. Subgraphs communicate
// (e.g., send/recv ops) with each other using `step_id` to
// distinguish tensors generated by different runs.
diff --git a/tensorflow/examples/skflow/README.md b/tensorflow/examples/skflow/README.md
index 8bb2739e6c..756351c0b6 100644
--- a/tensorflow/examples/skflow/README.md
+++ b/tensorflow/examples/skflow/README.md
@@ -7,21 +7,28 @@ known Scikit Learn API.
To run these examples, you need to have the `scikit learn` library installed (`sudo pip install sklearn`).
Some examples use the `pandas` library for data processing (`sudo pip install pandas`).
+## Basics
+
* [Deep Neural Network Regression with Boston Data](boston.py)
* [Convolutional Neural Networks with Digits Data](digits.py)
* [Deep Neural Network Classification with Iris Data](iris.py)
-* [Grid search and Deep Neural Network Classification](iris_gridsearch_cv.py)
-* [Deep Neural Network with Customized Decay Function](iris_custom_decay_dnn.py)
+* [Deep Neural Network Autoencoder with Iris Data](dnn_autoencoder_iris.py)
* [Building A Custom Model](iris_custom_model.py)
* [Accessing Weights and Biases in A Custom Model](mnist_weights.py)
-* [Building A Custom Model Using Multiple GPUs](multiple_gpu.py)
-* [Building A Model Using Different GPU Configurations](iris_config_addon.py)
-* [Using skflow with Pipeline](iris_with_pipeline.py)
+* [Building A Model Using Different GPU Configurations](iris_run_config.py)
* [Example of Saving and Restoring Models](iris_save_restore.py)
* [Multi-output Deep Neural Network regression](multioutput_regression.py)
+
+
+## Techniques
+
* [Improving Performance Using Early Stopping with Iris Data](iris_val_based_early_stopping.py)
+* [Using skflow with Pipeline](iris_with_pipeline.py)
+* [Building A Custom Model Using Multiple GPUs](multiple_gpu.py)
+* [Grid search and Deep Neural Network Classification](iris_gridsearch_cv.py)
+* [Deep Neural Network with Customized Decay Function](iris_custom_decay_dnn.py)
* [Out-of-core Data Classification Using Dask](out_of_core_data_classification.py)
-
+* [Handling Large HDF5 Dataset](hdf5_classification.py)
## Image classification
diff --git a/tensorflow/examples/skflow/boston.py b/tensorflow/examples/skflow/boston.py
index 2349baf72e..bf2066770c 100644
--- a/tensorflow/examples/skflow/boston.py
+++ b/tensorflow/examples/skflow/boston.py
@@ -40,6 +40,6 @@ regressor = skflow.TensorFlowDNNRegressor(hidden_units=[10, 10],
regressor.fit(X_train, y_train)
# Predict and score
-score = metrics.mean_squared_error(regressor.predict(scaler.fit_transform(X_test)), y_test)
+score = metrics.mean_squared_error(regressor.predict(scaler.transform(X_test)), y_test)
print('MSE: {0:f}'.format(score))
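(Editor's aside.) The one-line change above fixes a classic preprocessing bug: calling fit_transform on the test set re-estimates the scaler's statistics from test data, so train and test no longer share a scale. The correct pattern, as a self-contained sketch:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[0.0], [2.0], [4.0]])
X_test = np.array([[100.0]])

scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)  # fit mean/var on training data only
X_test_s = scaler.transform(X_test)        # reuse them; do NOT refit on test
print(X_test_s)  # scored on the training scale, as the fixed example does
```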
diff --git a/tensorflow/examples/skflow/digits.py b/tensorflow/examples/skflow/digits.py
index f61df6c572..b3c684b7df 100644
--- a/tensorflow/examples/skflow/digits.py
+++ b/tensorflow/examples/skflow/digits.py
@@ -18,8 +18,8 @@ from __future__ import print_function
from sklearn import datasets, cross_validation, metrics
import tensorflow as tf
-from tensorflow.contrib import skflow
-from tensorflow.contrib.skflow import monitors
+from tensorflow.contrib import learn
+from tensorflow.contrib.learn import monitors
# Load dataset
@@ -45,13 +45,13 @@ X_train, X_val, y_train, y_val = cross_validation.train_test_split(X_train,
def conv_model(X, y):
X = tf.expand_dims(X, 3)
- features = tf.reduce_max(skflow.ops.conv2d(X, 12, [3, 3]), [1, 2])
+ features = tf.reduce_max(learn.ops.conv2d(X, 12, [3, 3]), [1, 2])
features = tf.reshape(features, [-1, 12])
- return skflow.models.logistic_regression(features, y)
+ return learn.models.logistic_regression(features, y)
val_monitor = monitors.ValidationMonitor(X_val, y_val, every_n_steps=50)
# Create a classifier, train and predict.
-classifier = skflow.TensorFlowEstimator(model_fn=conv_model, n_classes=10,
+classifier = learn.TensorFlowEstimator(model_fn=conv_model, n_classes=10,
steps=1000, learning_rate=0.05,
batch_size=128)
classifier.fit(X_train, y_train, val_monitor)
diff --git a/tensorflow/examples/skflow/dnn_autoencoder_iris.py b/tensorflow/examples/skflow/dnn_autoencoder_iris.py
new file mode 100644
index 0000000000..c4383ae608
--- /dev/null
+++ b/tensorflow/examples/skflow/dnn_autoencoder_iris.py
@@ -0,0 +1,35 @@
+# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+import tensorflow as tf
+from tensorflow.contrib import learn
+from tensorflow.contrib.learn import datasets
+
+# Load Iris Data
+iris = datasets.load_iris()
+
+# Initialize a deep neural network autoencoder
+# You can also add noise and dropout if needed.
+# See the TensorFlowDNNAutoencoder documentation for details.
+autoencoder = learn.TensorFlowDNNAutoencoder(hidden_units=[10, 20])
+
+# Fit with Iris data
+transformed = autoencoder.fit_transform(iris.data)
+
+print(transformed)
diff --git a/tensorflow/examples/skflow/hdf5_classification.py b/tensorflow/examples/skflow/hdf5_classification.py
new file mode 100644
index 0000000000..0a4a7fd731
--- /dev/null
+++ b/tensorflow/examples/skflow/hdf5_classification.py
@@ -0,0 +1,49 @@
+# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from sklearn import metrics, cross_validation
+from tensorflow.contrib import learn
+import h5py
+
+# Load dataset.
+iris = learn.datasets.load_dataset('iris')
+X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
+ test_size=0.2, random_state=42)
+
+# Note that we save and load the iris data in HDF5 format here as a simple demonstration.
+h5f = h5py.File('test_hdf5.h5', 'w')
+h5f.create_dataset('X_train', data=X_train)
+h5f.create_dataset('X_test', data=X_test)
+h5f.create_dataset('y_train', data=y_train)
+h5f.create_dataset('y_test', data=y_test)
+h5f.close()
+
+h5f = h5py.File('test_hdf5.h5', 'r')
+X_train = h5f['X_train']
+X_test = h5f['X_test']
+y_train = h5f['y_train']
+y_test = h5f['y_test']
+
+# Build 3 layer DNN with 10, 20, 10 units respectively.
+classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
+ n_classes=3, steps=200)
+
+# Fit and predict.
+classifier.fit(X_train, y_train)
+score = metrics.accuracy_score(y_test, classifier.predict(X_test))
+print('Accuracy: {0:f}'.format(score))
+
diff --git a/tensorflow/examples/skflow/iris.py b/tensorflow/examples/skflow/iris.py
index 3ea1ef1f4c..c6c566b10f 100644
--- a/tensorflow/examples/skflow/iris.py
+++ b/tensorflow/examples/skflow/iris.py
@@ -17,15 +17,15 @@ from __future__ import print_function
from sklearn import metrics, cross_validation
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
# Load dataset.
-iris = skflow.datasets.load_dataset('iris')
+iris = learn.datasets.load_dataset('iris')
X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
test_size=0.2, random_state=42)
# Build 3 layer DNN with 10, 20, 10 units respectively.
-classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
+classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
n_classes=3, steps=200)
# Fit and predict.
diff --git a/tensorflow/examples/skflow/iris_custom_model.py b/tensorflow/examples/skflow/iris_custom_model.py
index f142a7db8c..11c7bc88d6 100644
--- a/tensorflow/examples/skflow/iris_custom_model.py
+++ b/tensorflow/examples/skflow/iris_custom_model.py
@@ -16,7 +16,7 @@ from __future__ import division
from __future__ import print_function
from sklearn import datasets, metrics, cross_validation
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
@@ -24,10 +24,10 @@ X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data,
def my_model(X, y):
"""This is DNN with 10, 20, 10 hidden layers, and dropout of 0.1 probability."""
- layers = skflow.ops.dnn(X, [10, 20, 10], dropout=0.1)
- return skflow.models.logistic_regression(layers, y)
+ layers = learn.ops.dnn(X, [10, 20, 10], dropout=0.1)
+ return learn.models.logistic_regression(layers, y)
-classifier = skflow.TensorFlowEstimator(model_fn=my_model, n_classes=3,
+classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3,
steps=1000)
classifier.fit(X_train, y_train)
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
diff --git a/tensorflow/examples/skflow/iris_run_config.py b/tensorflow/examples/skflow/iris_run_config.py
index c85fcec224..4f057f817d 100644
--- a/tensorflow/examples/skflow/iris_run_config.py
+++ b/tensorflow/examples/skflow/iris_run_config.py
@@ -17,7 +17,7 @@ from __future__ import print_function
from sklearn import datasets, metrics, cross_validation
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
# Load dataset.
@@ -27,10 +27,10 @@ X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data,
# You can define your configurations by providing a RunConfig object to the
# estimator to control session configuration, e.g. num_cores and gpu_memory_fraction
-run_config = skflow.estimators.RunConfig(num_cores=3, gpu_memory_fraction=0.6)
+run_config = learn.estimators.RunConfig(num_cores=3, gpu_memory_fraction=0.6)
# Build 3 layer DNN with 10, 20, 10 units respectively.
-classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
+classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
n_classes=3, steps=200, config=run_config)
# Fit and predict.
diff --git a/tensorflow/examples/skflow/iris_save_restore.py b/tensorflow/examples/skflow/iris_save_restore.py
index f9613c96fe..d29237a26f 100644
--- a/tensorflow/examples/skflow/iris_save_restore.py
+++ b/tensorflow/examples/skflow/iris_save_restore.py
@@ -18,13 +18,13 @@ from __future__ import print_function
import shutil
from sklearn import datasets, metrics, cross_validation
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
test_size=0.2, random_state=42)
-classifier = skflow.TensorFlowLinearClassifier(n_classes=3)
+classifier = learn.TensorFlowLinearClassifier(n_classes=3)
classifier.fit(X_train, y_train)
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))
@@ -40,6 +40,6 @@ classifier.save('/tmp/skflow_examples/iris_custom_model')
classifier = None
## Restore everything
-new_classifier = skflow.TensorFlowEstimator.restore('/tmp/skflow_examples/iris_custom_model')
+new_classifier = learn.TensorFlowEstimator.restore('/tmp/skflow_examples/iris_custom_model')
score = metrics.accuracy_score(y_test, new_classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
index 74e02561f3..3ded960bac 100644
--- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py
+++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
@@ -18,7 +18,7 @@ from __future__ import print_function
from sklearn import datasets, metrics
from sklearn.cross_validation import train_test_split
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
iris = datasets.load_iris()
@@ -29,17 +29,17 @@ X_train, X_test, y_train, y_test = train_test_split(iris.data,
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train,
test_size=0.2, random_state=42)
-val_monitor = skflow.monitors.ValidationMonitor(X_val, y_val,
- early_stopping_rounds=200)
+val_monitor = learn.monitors.ValidationMonitor(X_val, y_val,
+ early_stopping_rounds=200)
# classifier with early stopping on training data
-classifier1 = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
+classifier1 = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
n_classes=3, steps=2000)
classifier1.fit(X_train, y_train, logdir='/tmp/iris_model/')
score1 = metrics.accuracy_score(y_test, classifier1.predict(X_test))
# classifier with early stopping on validation data
-classifier2 = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
+classifier2 = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
n_classes=3, steps=2000)
classifier2.fit(X_train, y_train, val_monitor, logdir='/tmp/iris_model_val/')
score2 = metrics.accuracy_score(y_test, classifier2.predict(X_test))
diff --git a/tensorflow/examples/skflow/iris_with_pipeline.py b/tensorflow/examples/skflow/iris_with_pipeline.py
index f6408f84a8..3ba5739250 100644
--- a/tensorflow/examples/skflow/iris_with_pipeline.py
+++ b/tensorflow/examples/skflow/iris_with_pipeline.py
@@ -20,7 +20,7 @@ from sklearn.datasets import load_iris
from sklearn import cross_validation
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
iris = load_iris()
X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
@@ -30,7 +30,7 @@ X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data,
scaler = StandardScaler()
# DNN classifier
-DNNclassifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200)
+DNNclassifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200)
pipeline = Pipeline([('scaler', scaler), ('DNNclassifier', DNNclassifier)])
diff --git a/tensorflow/examples/skflow/language_model.py b/tensorflow/examples/skflow/language_model.py
index 439d6f3198..ccc1b04c5b 100644
--- a/tensorflow/examples/skflow/language_model.py
+++ b/tensorflow/examples/skflow/language_model.py
@@ -22,7 +22,7 @@ import math
import numpy as np
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Training data
@@ -53,7 +53,7 @@ def unpack_xy(iter_obj):
return (item[0] for item in X), (item[1] for item in y)
-byte_processor = skflow.preprocessing.ByteProcessor(
+byte_processor = learn.preprocessing.ByteProcessor(
max_document_length=MAX_DOC_LENGTH)
data = training_data(CORPUS_FILENAME)
@@ -68,31 +68,31 @@ HIDDEN_SIZE = 10
def seq_autoencoder(X, y):
"""Sequence auto-encoder with RNN."""
- inputs = skflow.ops.one_hot_matrix(X, 256)
- in_X, in_y, out_y = skflow.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH)
+ inputs = learn.ops.one_hot_matrix(X, 256)
+ in_X, in_y, out_y = learn.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH)
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
- decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell)
- return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding)
+ decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell)
+ return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
def get_language_model(hidden_size):
"""Returns a language model with given hidden size."""
def language_model(X, y):
- inputs = skflow.ops.one_hot_matrix(X, 256)
- inputs = skflow.ops.split_squeeze(1, MAX_DOC_LENGTH, inputs)
- target = skflow.ops.split_squeeze(1, MAX_DOC_LENGTH, y)
+ inputs = learn.ops.one_hot_matrix(X, 256)
+ inputs = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, inputs)
+ target = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, y)
encoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(hidden_size),256)
output, _ = tf.nn.rnn(encoder_cell, inputs, dtype=tf.float32)
- return skflow.ops.sequence_classifier(output, target)
+ return learn.ops.sequence_classifier(output, target)
return language_model
### Training model.
-estimator = skflow.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE),
+estimator = learn.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE),
n_classes=256,
optimizer='Adam', learning_rate=0.01,
steps=1000, batch_size=64, continue_training=True)
diff --git a/tensorflow/examples/skflow/mnist.py b/tensorflow/examples/skflow/mnist.py
index 504a5ae9f8..082ecb2f83 100644
--- a/tensorflow/examples/skflow/mnist.py
+++ b/tensorflow/examples/skflow/mnist.py
@@ -24,15 +24,15 @@ from __future__ import print_function
from sklearn import metrics
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Download and load MNIST data.
-mnist = skflow.datasets.load_dataset('mnist')
+mnist = learn.datasets.load_dataset('mnist')
### Linear classifier.
-classifier = skflow.TensorFlowLinearClassifier(
+classifier = learn.TensorFlowLinearClassifier(
n_classes=10, batch_size=100, steps=1000, learning_rate=0.01)
classifier.fit(mnist.train.images, mnist.train.labels)
score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
@@ -50,22 +50,22 @@ def conv_model(X, y):
X = tf.reshape(X, [-1, 28, 28, 1])
# first conv layer will compute 32 features for each 5x5 patch
with tf.variable_scope('conv_layer1'):
- h_conv1 = skflow.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
+ h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
bias=True, activation=tf.nn.relu)
h_pool1 = max_pool_2x2(h_conv1)
# second conv layer will compute 64 features for each 5x5 patch
with tf.variable_scope('conv_layer2'):
- h_conv2 = skflow.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
+ h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
bias=True, activation=tf.nn.relu)
h_pool2 = max_pool_2x2(h_conv2)
# reshape tensor into a batch of vectors
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
# densely connected layer with 1024 neurons
- h_fc1 = skflow.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
- return skflow.models.logistic_regression(h_fc1, y)
+ h_fc1 = learn.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
+ return learn.models.logistic_regression(h_fc1, y)
# Training and predicting
-classifier = skflow.TensorFlowEstimator(
+classifier = learn.TensorFlowEstimator(
model_fn=conv_model, n_classes=10, batch_size=100, steps=20000,
learning_rate=0.001)
classifier.fit(mnist.train.images, mnist.train.labels)
diff --git a/tensorflow/examples/skflow/mnist_weights.py b/tensorflow/examples/skflow/mnist_weights.py
index 7d4acd24af..b0c2ea583e 100644
--- a/tensorflow/examples/skflow/mnist_weights.py
+++ b/tensorflow/examples/skflow/mnist_weights.py
@@ -24,15 +24,15 @@ from __future__ import print_function
from sklearn import metrics
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Download and load MNIST data.
-mnist = skflow.datasets.load_dataset('mnist')
+mnist = learn.datasets.load_dataset('mnist')
### Linear classifier.
-classifier = skflow.TensorFlowLinearClassifier(
+classifier = learn.TensorFlowLinearClassifier(
n_classes=10, batch_size=100, steps=1000, learning_rate=0.01)
classifier.fit(mnist.train.images, mnist.train.labels)
score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
@@ -50,22 +50,22 @@ def conv_model(X, y):
X = tf.reshape(X, [-1, 28, 28, 1])
# first conv layer will compute 32 features for each 5x5 patch
with tf.variable_scope('conv_layer1'):
- h_conv1 = skflow.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
+ h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
bias=True, activation=tf.nn.relu)
h_pool1 = max_pool_2x2(h_conv1)
# second conv layer will compute 64 features for each 5x5 patch
with tf.variable_scope('conv_layer2'):
- h_conv2 = skflow.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
+ h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
bias=True, activation=tf.nn.relu)
h_pool2 = max_pool_2x2(h_conv2)
# reshape tensor into a batch of vectors
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
# densely connected layer with 1024 neurons
- h_fc1 = skflow.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
- return skflow.models.logistic_regression(h_fc1, y)
+ h_fc1 = learn.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
+ return learn.models.logistic_regression(h_fc1, y)
# Training and predicting
-classifier = skflow.TensorFlowEstimator(
+classifier = learn.TensorFlowEstimator(
model_fn=conv_model, n_classes=10, batch_size=100, steps=20000,
learning_rate=0.001)
classifier.fit(mnist.train.images, mnist.train.labels)
diff --git a/tensorflow/examples/skflow/multioutput_regression.py b/tensorflow/examples/skflow/multioutput_regression.py
index a4300a5508..c0ddf1cf30 100644
--- a/tensorflow/examples/skflow/multioutput_regression.py
+++ b/tensorflow/examples/skflow/multioutput_regression.py
@@ -26,7 +26,7 @@ import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.metrics import mean_squared_error
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
# Create random dataset.
rng = np.random.RandomState(1)
@@ -38,11 +38,11 @@ regressors = []
options = [[2], [10, 10], [20, 20]]
for hidden_units in options:
def tanh_dnn(X, y):
- features = skflow.ops.dnn(X, hidden_units=hidden_units,
- activation=skflow.tf.tanh)
- return skflow.models.linear_regression(features, y)
+ features = learn.ops.dnn(X, hidden_units=hidden_units,
+ activation=learn.tf.tanh)
+ return learn.models.linear_regression(features, y)
- regressor = skflow.TensorFlowEstimator(model_fn=tanh_dnn, n_classes=0,
+ regressor = learn.TensorFlowEstimator(model_fn=tanh_dnn, n_classes=0,
steps=500, learning_rate=0.1, batch_size=100)
regressor.fit(X, y)
score = mean_squared_error(regressor.predict(X), y)
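
A minimal sketch of the regression setup above, assuming `n_classes=0` is what switches `TensorFlowEstimator` into regression mode (as the example's usage suggests):

```python
from tensorflow.contrib import learn

# model_fn returns predictions and loss via linear_regression; n_classes=0
# tells the estimator to treat the task as regression rather than
# classification.
def model_fn(X, y):
    return learn.models.linear_regression(X, y)

regressor = learn.TensorFlowEstimator(model_fn=model_fn, n_classes=0,
                                      steps=500, learning_rate=0.1,
                                      batch_size=100)
```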
diff --git a/tensorflow/examples/skflow/multiple_gpu.py b/tensorflow/examples/skflow/multiple_gpu.py
index 4afa667e6d..1168184a38 100644
--- a/tensorflow/examples/skflow/multiple_gpu.py
+++ b/tensorflow/examples/skflow/multiple_gpu.py
@@ -17,7 +17,7 @@ from __future__ import print_function
from sklearn import datasets, metrics, cross_validation
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
@@ -31,11 +31,11 @@ def my_model(X, y):
CUDNN 6.5 V2 from NVIDIA needs to be installed beforehand.
"""
with tf.device('/gpu:1'):
- layers = skflow.ops.dnn(X, [10, 20, 10], dropout=0.5)
+ layers = learn.ops.dnn(X, [10, 20, 10], dropout=0.5)
with tf.device('/gpu:2'):
- return skflow.models.logistic_regression(layers, y)
+ return learn.models.logistic_regression(layers, y)
-classifier = skflow.TensorFlowEstimator(model_fn=my_model, n_classes=3)
+classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3)
classifier.fit(X_train, y_train)
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))
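
The placement mechanism used above can be tried in isolation; a rough sketch, assuming `allow_soft_placement` falls back to an available device when `/gpu:1` does not exist:

```python
import tensorflow as tf

# Pin ops to a specific device, exactly as my_model() does above.
with tf.device('/gpu:1'):
    a = tf.constant([[1.0, 2.0]])
    b = tf.constant([[3.0], [4.0]])
    c = tf.matmul(a, b)

# allow_soft_placement lets the graph run even on machines without a
# second GPU; log_device_placement prints where each op actually ran.
config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
with tf.Session(config=config) as sess:
    print(sess.run(c))  # [[ 11.]]
```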
diff --git a/tensorflow/examples/skflow/neural_translation.py b/tensorflow/examples/skflow/neural_translation.py
index d0c8c33b93..7832767145 100644
--- a/tensorflow/examples/skflow/neural_translation.py
+++ b/tensorflow/examples/skflow/neural_translation.py
@@ -22,7 +22,7 @@ import os
import numpy as np
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
# Get training data
@@ -88,15 +88,15 @@ MAX_DOCUMENT_LENGTH = 30
HIDDEN_SIZE = 100
def translate_model(X, y):
- byte_list = skflow.ops.one_hot_matrix(X, 256)
- in_X, in_y, out_y = skflow.ops.seq2seq_inputs(
+ byte_list = learn.ops.one_hot_matrix(X, 256)
+ in_X, in_y, out_y = learn.ops.seq2seq_inputs(
byte_list, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH)
cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
- decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, cell)
- return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding)
+ decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, cell)
+ return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
-vocab_processor = skflow.preprocessing.ByteProcessor(
+vocab_processor = learn.preprocessing.ByteProcessor(
max_document_length=MAX_DOCUMENT_LENGTH)
x_iter = vocab_processor.transform(X_train)
@@ -107,9 +107,9 @@ ygold = list(y_test)[:20]
PATH = '/tmp/tf_examples/ntm/'
if os.path.exists(PATH):
- translator = skflow.TensorFlowEstimator.restore(PATH)
+ translator = learn.TensorFlowEstimator.restore(PATH)
else:
- translator = skflow.TensorFlowEstimator(model_fn=translate_model,
+ translator = learn.TensorFlowEstimator(model_fn=translate_model,
n_classes=256,
optimizer='Adam', learning_rate=0.01, batch_size=128,
continue_training=True)
diff --git a/tensorflow/examples/skflow/neural_translation_word.py b/tensorflow/examples/skflow/neural_translation_word.py
index e917c97a9b..90c73f0ba5 100644
--- a/tensorflow/examples/skflow/neural_translation_word.py
+++ b/tensorflow/examples/skflow/neural_translation_word.py
@@ -25,7 +25,7 @@ import random
import numpy as np
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
# Get training data
@@ -93,9 +93,9 @@ X_test, y_test = Xy(read_iterator('test.data'))
MAX_DOCUMENT_LENGTH = 10
if not (os.path.exists('en.vocab') and os.path.exists('fr.vocab')):
- X_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
+ X_vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
min_frequency=5)
- y_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
+ y_vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
min_frequency=5)
Xtrainff, ytrainff = Xy(read_iterator('train.data'))
print('Fitting dictionary for English...')
@@ -126,24 +126,24 @@ HIDDEN_SIZE = 20
EMBEDDING_SIZE = 20
def translate_model(X, y):
- word_vectors = skflow.ops.categorical_variable(X, n_classes=n_en_words,
+ word_vectors = learn.ops.categorical_variable(X, n_classes=n_en_words,
embedding_size=EMBEDDING_SIZE, name='words')
- in_X, in_y, out_y = skflow.ops.seq2seq_inputs(
+ in_X, in_y, out_y = learn.ops.seq2seq_inputs(
word_vectors, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH)
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(
tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), n_fr_words)
- decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y,
+ decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y,
encoder_cell, decoder_cell=decoder_cell)
- return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding)
+ return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding)
PATH = '/tmp/tf_examples/ntm_words/'
if os.path.exists(os.path.join(PATH, 'graph.pbtxt')):
- translator = skflow.TensorFlowEstimator.restore(PATH)
+ translator = learn.TensorFlowEstimator.restore(PATH)
else:
- translator = skflow.TensorFlowEstimator(model_fn=translate_model,
+ translator = learn.TensorFlowEstimator(model_fn=translate_model,
n_classes=n_fr_words,
optimizer='Adam', learning_rate=0.01, batch_size=128,
continue_training=True, steps=100)
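
A small sketch of the `VocabularyProcessor` flow used above: fitting builds the word-to-id table (dropping words rarer than `min_frequency`), and transforming maps sentences to fixed-length id sequences. The sentences here are illustrative.

```python
import numpy as np
from tensorflow.contrib import learn

processor = learn.preprocessing.VocabularyProcessor(
    max_document_length=10, min_frequency=0)
docs = ['hello world', 'hello tensorflow']

# fit_transform yields one fixed-length id array per document; short
# documents are padded with zeros up to max_document_length.
ids = np.array(list(processor.fit_transform(docs)))
print(ids.shape)  # (2, 10)
```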
diff --git a/tensorflow/examples/skflow/out_of_core_data_classification.py b/tensorflow/examples/skflow/out_of_core_data_classification.py
index 9c09f436c6..6328941e6d 100644
--- a/tensorflow/examples/skflow/out_of_core_data_classification.py
+++ b/tensorflow/examples/skflow/out_of_core_data_classification.py
@@ -20,7 +20,7 @@ from sklearn import datasets, metrics, cross_validation
import pandas as pd
import dask.dataframe as dd
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
# Sometimes, when your dataset is too large to hold in memory,
# you may want to load it into an out-of-core dataframe as provided by the dask library.
@@ -41,7 +41,7 @@ X_train, y_train, X_test, y_test = [pd.DataFrame(data) for data in [X_train, y_t
X_train, y_train, X_test, y_test = [dd.from_pandas(data, npartitions=2) for data in [X_train, y_train, X_test, y_test]]
# Initialize a TensorFlow linear classifier
-classifier = skflow.TensorFlowLinearClassifier(n_classes=3)
+classifier = learn.TensorFlowLinearClassifier(n_classes=3)
# Fit the model using training set
classifier.fit(X_train, y_train)
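
The conversion pipeline above can be sketched on its own: numpy arrays are wrapped in pandas frames and then partitioned into dask dataframes, so the feeder only needs to materialize one partition at a time. Shapes here are illustrative.

```python
import numpy as np
import pandas as pd
import dask.dataframe as dd

X = pd.DataFrame(np.random.rand(1000, 4))
y = pd.DataFrame(np.random.randint(0, 3, size=(1000, 1)))

# npartitions controls how many chunks each dataframe is split into;
# larger datasets would use more partitions.
X_dask = dd.from_pandas(X, npartitions=2)
y_dask = dd.from_pandas(y, npartitions=2)
```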
diff --git a/tensorflow/examples/skflow/resnet.py b/tensorflow/examples/skflow/resnet.py
index e4dab3fd25..f1f39568d4 100644
--- a/tensorflow/examples/skflow/resnet.py
+++ b/tensorflow/examples/skflow/resnet.py
@@ -30,7 +30,7 @@ from math import sqrt
from sklearn import metrics
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
def res_net(x, y, activation=tf.nn.relu):
@@ -62,7 +62,7 @@ def res_net(x, y, activation=tf.nn.relu):
# First convolution expands to 64 channels
with tf.variable_scope('conv_layer1'):
- net = skflow.ops.conv2d(x, 64, [7, 7], batch_norm=True,
+ net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
activation=activation, bias=False)
# Max pool
@@ -71,7 +71,7 @@ def res_net(x, y, activation=tf.nn.relu):
# First chain of resnets
with tf.variable_scope('conv_layer2'):
- net = skflow.ops.conv2d(net, blocks[0].num_filters,
+ net = learn.ops.conv2d(net, blocks[0].num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID', bias=True)
@@ -83,7 +83,7 @@ def res_net(x, y, activation=tf.nn.relu):
# 1x1 convolution responsible for reducing dimension
with tf.variable_scope(name + '/conv_in'):
- conv = skflow.ops.conv2d(net, block.bottleneck_size,
+ conv = learn.ops.conv2d(net, block.bottleneck_size,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
@@ -91,7 +91,7 @@ def res_net(x, y, activation=tf.nn.relu):
bias=False)
with tf.variable_scope(name + '/conv_bottleneck'):
- conv = skflow.ops.conv2d(conv, block.bottleneck_size,
+ conv = learn.ops.conv2d(conv, block.bottleneck_size,
[3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
@@ -100,7 +100,7 @@ def res_net(x, y, activation=tf.nn.relu):
# 1x1 convolution responsible for restoring dimension
with tf.variable_scope(name + '/conv_out'):
- conv = skflow.ops.conv2d(conv, block.num_filters,
+ conv = learn.ops.conv2d(conv, block.num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
@@ -115,7 +115,7 @@ def res_net(x, y, activation=tf.nn.relu):
# upscale to the next block size
next_block = blocks[block_i + 1]
with tf.variable_scope('block_%d/conv_upscale' % block_i):
- net = skflow.ops.conv2d(net, next_block.num_filters,
+ net = learn.ops.conv2d(net, next_block.num_filters,
[1, 1], [1, 1, 1, 1],
bias=False,
padding='SAME')
@@ -130,7 +130,7 @@ def res_net(x, y, activation=tf.nn.relu):
net_shape = net.get_shape().as_list()
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
- return skflow.models.logistic_regression(net, y)
+ return learn.models.logistic_regression(net, y)
# Download and load MNIST data.
@@ -138,10 +138,10 @@ mnist = input_data.read_data_sets('MNIST_data')
# Restore model if graph is saved into a folder.
if os.path.exists("models/resnet/graph.pbtxt"):
- classifier = skflow.TensorFlowEstimator.restore("models/resnet/")
+ classifier = learn.TensorFlowEstimator.restore("models/resnet/")
else:
# Create a new resnet classifier.
- classifier = skflow.TensorFlowEstimator(
+ classifier = learn.TensorFlowEstimator(
model_fn=res_net, n_classes=10, batch_size=100, steps=100,
learning_rate=0.001, continue_training=True)
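
The restore-or-create idiom above generalizes to any estimator; a minimal sketch, with a trivial `model_fn` standing in for `res_net` and an illustrative path:

```python
import os
from tensorflow.contrib import learn

def model_fn(X, y):
    return learn.models.logistic_regression(X, y)

PATH = 'models/resnet/'
# graph.pbtxt is written when the estimator is saved, so its presence
# serves as the "a saved model exists" signal.
if os.path.exists(os.path.join(PATH, 'graph.pbtxt')):
    classifier = learn.TensorFlowEstimator.restore(PATH)
else:
    classifier = learn.TensorFlowEstimator(
        model_fn=model_fn, n_classes=10, batch_size=100, steps=100,
        learning_rate=0.001, continue_training=True)
```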
diff --git a/tensorflow/examples/skflow/text_classification.py b/tensorflow/examples/skflow/text_classification.py
index 190d4a1464..08ef507f17 100644
--- a/tensorflow/examples/skflow/text_classification.py
+++ b/tensorflow/examples/skflow/text_classification.py
@@ -20,12 +20,12 @@ from sklearn import metrics
import pandas
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Training data
# Downloads, unpacks and reads DBpedia dataset.
-dbpedia = skflow.datasets.load_dataset('dbpedia')
+dbpedia = learn.datasets.load_dataset('dbpedia')
X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target)
X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target)
@@ -33,7 +33,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t
MAX_DOCUMENT_LENGTH = 10
-vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
+vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(vocab_processor.fit_transform(X_train)))
X_test = np.array(list(vocab_processor.transform(X_test)))
@@ -45,10 +45,10 @@ print('Total words: %d' % n_words)
EMBEDDING_SIZE = 50
def average_model(X, y):
- word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
+ word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
features = tf.reduce_max(word_vectors, reduction_indices=1)
- return skflow.models.logistic_regression(features, y)
+ return learn.models.logistic_regression(features, y)
def rnn_model(X, y):
"""Recurrent neural network model to predict from sequence of words
@@ -57,11 +57,11 @@ def rnn_model(X, y):
# This creates an embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
- word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
+ word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
# Split into a list of per-word embeddings, while removing the doc length dim.
# word_list ends up being a list of tensors [batch_size, EMBEDDING_SIZE].
- word_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
+ word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
# Create an unrolled Recurrent Neural Network up to the length of
@@ -70,9 +70,9 @@ def rnn_model(X, y):
# Given the encoding of the RNN, take the encoding of the last step (i.e. the
# hidden state of the network at the last step) and pass it as features for
# logistic regression over the output classes.
- return skflow.models.logistic_regression(encoding, y)
+ return learn.models.logistic_regression(encoding, y)
-classifier = skflow.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
+classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
steps=1000, optimizer='Adam', learning_rate=0.01, continue_training=True)
# Continuously train for 1000 steps & predict on test set.
diff --git a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
index 9eb7dcfb97..d4824d52c6 100644
--- a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
+++ b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
@@ -20,12 +20,12 @@ from sklearn import metrics
import pandas
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Training data
# Downloads, unpacks and reads DBpedia dataset.
-dbpedia = skflow.datasets.load_dataset('dbpedia')
+dbpedia = learn.datasets.load_dataset('dbpedia')
X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target)
X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target)
@@ -33,7 +33,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t
MAX_DOCUMENT_LENGTH = 10
-vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
+vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(vocab_processor.fit_transform(X_train)))
X_test = np.array(list(vocab_processor.transform(X_test)))
@@ -50,15 +50,15 @@ def input_op_fn(X):
# This creates an embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
- word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
+ word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
# Split into a list of per-word embeddings, while removing the doc length dim.
# word_list ends up being a list of tensors [batch_size, EMBEDDING_SIZE].
- word_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
+ word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
return word_list
# Single direction GRU with a single layer
-classifier = skflow.TensorFlowRNNClassifier(rnn_size=EMBEDDING_SIZE,
+classifier = learn.TensorFlowRNNClassifier(rnn_size=EMBEDDING_SIZE,
n_classes=15, cell_type='gru', input_op_fn=input_op_fn,
num_layers=1, bidirectional=False, sequence_length=None,
steps=1000, optimizer='Adam', learning_rate=0.01, continue_training=True)
diff --git a/tensorflow/examples/skflow/text_classification_character_cnn.py b/tensorflow/examples/skflow/text_classification_character_cnn.py
index 71dec5b4ee..f447ddb3e7 100644
--- a/tensorflow/examples/skflow/text_classification_character_cnn.py
+++ b/tensorflow/examples/skflow/text_classification_character_cnn.py
@@ -32,12 +32,12 @@ from sklearn import metrics
import pandas
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Training data
# Downloads, unpacks and reads DBpedia dataset.
-dbpedia = skflow.datasets.load_dataset('dbpedia')
+dbpedia = learn.datasets.load_dataset('dbpedia')
X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target)
X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target)
@@ -45,7 +45,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t
MAX_DOCUMENT_LENGTH = 100
-char_processor = skflow.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
+char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(char_processor.fit_transform(X_train)))
X_test = np.array(list(char_processor.transform(X_test)))
@@ -59,11 +59,11 @@ POOLING_STRIDE = 2
def char_cnn_model(X, y):
"""Character level convolutional neural network model to predict classes."""
- byte_list = tf.reshape(skflow.ops.one_hot_matrix(X, 256),
+ byte_list = tf.reshape(learn.ops.one_hot_matrix(X, 256),
[-1, MAX_DOCUMENT_LENGTH, 256, 1])
with tf.variable_scope('CNN_Layer1'):
# Apply Convolution filtering on input sequence.
- conv1 = skflow.ops.conv2d(byte_list, N_FILTERS, FILTER_SHAPE1, padding='VALID')
+ conv1 = learn.ops.conv2d(byte_list, N_FILTERS, FILTER_SHAPE1, padding='VALID')
# Add a ReLU for non-linearity.
conv1 = tf.nn.relu(conv1)
# Max pooling across the output of Convolution+ReLU.
@@ -73,14 +73,14 @@ def char_cnn_model(X, y):
pool1 = tf.transpose(pool1, [0, 1, 3, 2])
with tf.variable_scope('CNN_Layer2'):
# Second level of convolution filtering.
- conv2 = skflow.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2,
+ conv2 = learn.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2,
padding='VALID')
# Max across each filter to get useful features for classification.
pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1])
# Apply regular WX + B and classification.
- return skflow.models.logistic_regression(pool2, y)
+ return learn.models.logistic_regression(pool2, y)
-classifier = skflow.TensorFlowEstimator(model_fn=char_cnn_model, n_classes=15,
+classifier = learn.TensorFlowEstimator(model_fn=char_cnn_model, n_classes=15,
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
# Continuously train for 1000 steps & predict on test set.
diff --git a/tensorflow/examples/skflow/text_classification_character_rnn.py b/tensorflow/examples/skflow/text_classification_character_rnn.py
index af1e37641b..6281850af4 100644
--- a/tensorflow/examples/skflow/text_classification_character_rnn.py
+++ b/tensorflow/examples/skflow/text_classification_character_rnn.py
@@ -32,12 +32,12 @@ from sklearn import metrics
import pandas
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Training data
# Downloads, unpacks and reads DBpedia dataset.
-dbpedia = skflow.datasets.load_dataset('dbpedia')
+dbpedia = learn.datasets.load_dataset('dbpedia')
X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target)
X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target)
@@ -45,7 +45,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t
MAX_DOCUMENT_LENGTH = 100
-char_processor = skflow.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
+char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(char_processor.fit_transform(X_train)))
X_test = np.array(list(char_processor.transform(X_test)))
@@ -54,13 +54,13 @@ X_test = np.array(list(char_processor.transform(X_test)))
HIDDEN_SIZE = 20
def char_rnn_model(X, y):
- byte_list = skflow.ops.one_hot_matrix(X, 256)
- byte_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, byte_list)
+ byte_list = learn.ops.one_hot_matrix(X, 256)
+ byte_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, byte_list)
cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
_, encoding = tf.nn.rnn(cell, byte_list, dtype=tf.float32)
- return skflow.models.logistic_regression(encoding, y)
+ return learn.models.logistic_regression(encoding, y)
-classifier = skflow.TensorFlowEstimator(model_fn=char_rnn_model, n_classes=15,
+classifier = learn.TensorFlowEstimator(model_fn=char_rnn_model, n_classes=15,
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
# Continuously train for 1000 steps & predict on test set.
diff --git a/tensorflow/examples/skflow/text_classification_cnn.py b/tensorflow/examples/skflow/text_classification_cnn.py
index c42a12819e..de14695e0d 100644
--- a/tensorflow/examples/skflow/text_classification_cnn.py
+++ b/tensorflow/examples/skflow/text_classification_cnn.py
@@ -20,12 +20,12 @@ from sklearn import metrics
import pandas
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Training data
# Downloads, unpacks and reads DBpedia dataset.
-dbpedia = skflow.datasets.load_dataset('dbpedia')
+dbpedia = learn.datasets.load_dataset('dbpedia')
X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target)
X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target)
@@ -33,7 +33,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t
MAX_DOCUMENT_LENGTH = 100
-vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
+vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(vocab_processor.fit_transform(X_train)))
X_test = np.array(list(vocab_processor.transform(X_test)))
@@ -57,12 +57,12 @@ def cnn_model(X, y):
# This creates an embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
- word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
+ word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
word_vectors = tf.expand_dims(word_vectors, 3)
with tf.variable_scope('CNN_Layer1'):
# Apply Convolution filtering on input sequence.
- conv1 = skflow.ops.conv2d(word_vectors, N_FILTERS, FILTER_SHAPE1, padding='VALID')
+ conv1 = learn.ops.conv2d(word_vectors, N_FILTERS, FILTER_SHAPE1, padding='VALID')
# Add a ReLU for non-linearity.
conv1 = tf.nn.relu(conv1)
# Max pooling across the output of Convolution+ReLU.
@@ -72,15 +72,15 @@ def cnn_model(X, y):
pool1 = tf.transpose(pool1, [0, 1, 3, 2])
with tf.variable_scope('CNN_Layer2'):
# Second level of convolution filtering.
- conv2 = skflow.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2,
+ conv2 = learn.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2,
padding='VALID')
# Max across each filter to get useful features for classification.
pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1])
# Apply regular WX + B and classification.
- return skflow.models.logistic_regression(pool2, y)
+ return learn.models.logistic_regression(pool2, y)
-classifier = skflow.TensorFlowEstimator(model_fn=cnn_model, n_classes=15,
+classifier = learn.TensorFlowEstimator(model_fn=cnn_model, n_classes=15,
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
# Continuously train for 1000 steps & predict on test set.
diff --git a/tensorflow/examples/skflow/text_classification_save_restore.py b/tensorflow/examples/skflow/text_classification_save_restore.py
index 71551762b9..3fcde2b308 100644
--- a/tensorflow/examples/skflow/text_classification_save_restore.py
+++ b/tensorflow/examples/skflow/text_classification_save_restore.py
@@ -21,12 +21,12 @@ from sklearn import metrics
import pandas
import tensorflow as tf
-from tensorflow.contrib import skflow
+from tensorflow.contrib import learn
### Training data
# Downloads, unpacks and reads DBpedia dataset.
-dbpedia = skflow.datasets.load_dataset('dbpedia')
+dbpedia = learn.datasets.load_dataset('dbpedia')
X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target)
X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target)
@@ -34,7 +34,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t
MAX_DOCUMENT_LENGTH = 10
-vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
+vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(vocab_processor.fit_transform(X_train)))
X_test = np.array(list(vocab_processor.transform(X_test)))
@@ -46,10 +46,10 @@ print('Total words: %d' % n_words)
EMBEDDING_SIZE = 50
def average_model(X, y):
- word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
+ word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
features = tf.reduce_max(word_vectors, reduction_indices=1)
- return skflow.models.logistic_regression(features, y)
+ return learn.models.logistic_regression(features, y)
def rnn_model(X, y):
"""Recurrent neural network model to predict from sequence of words
@@ -58,11 +58,11 @@ def rnn_model(X, y):
# This creates an embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
- word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
+ word_vectors = learn.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
# Split into a list of per-word embeddings, while removing the doc length dim.
# word_list ends up being a list of tensors [batch_size, EMBEDDING_SIZE].
- word_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
+ word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
# Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
# Create an unrolled Recurrent Neural Network up to the length of
@@ -71,13 +71,13 @@ def rnn_model(X, y):
# Given the encoding of the RNN, take the encoding of the last step (i.e. the
# hidden state of the network at the last step) and pass it as features for
# logistic regression over the output classes.
- return skflow.models.logistic_regression(encoding, y)
+ return learn.models.logistic_regression(encoding, y)
model_path = '/tmp/skflow_examples/text_classification'
if os.path.exists(model_path):
- classifier = skflow.TensorFlowEstimator.restore(model_path)
+ classifier = learn.TensorFlowEstimator.restore(model_path)
else:
- classifier = skflow.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
+ classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)
# Continuously train for 1000 steps
diff --git a/tensorflow/examples/udacity/1_notmnist.ipynb b/tensorflow/examples/udacity/1_notmnist.ipynb
index 2265445815..32ece419b0 100644
--- a/tensorflow/examples/udacity/1_notmnist.ipynb
+++ b/tensorflow/examples/udacity/1_notmnist.ipynb
@@ -110,11 +110,31 @@
},
"source": [
"url = 'http://commondatastorage.googleapis.com/books1000/'\n",
+ "last_percent_reported = None\n",
"\n",
+ "def download_progress_hook(count, blockSize, totalSize):\n",
+ " \"\"\"A hook to report the progress of a download. This is mostly intended for users with\n",
+ " slow internet connections. Reports every 1% change in download progress.\n",
+ " \"\"\"\n",
+ " global last_percent_reported\n",
+ " percent = int(count * blockSize * 100 / totalSize)\n",
+ "\n",
+ " if last_percent_reported != percent:\n",
+ " if percent % 5 == 0:\n",
+ " sys.stdout.write(\"%s%%\" % percent)\n",
+ " sys.stdout.flush()\n",
+ " else:\n",
+ " sys.stdout.write(\".\")\n",
+ " sys.stdout.flush()\n",
+ " \n",
+ " last_percent_reported = percent\n",
+ " \n",
"def maybe_download(filename, expected_bytes, force=False):\n",
" \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n",
" if force or not os.path.exists(filename):\n",
- " filename, _ = urlretrieve(url + filename, filename)\n",
+ " print('Attempting to download:', filename) \n",
+ " filename, _ = urlretrieve(url + filename, filename, reporthook=download_progress_hook)\n",
+ " print('\\nDownload Complete!')\n",
" statinfo = os.stat(filename)\n",
" if statinfo.st_size == expected_bytes:\n",
" print('Found and verified', filename)\n",
diff --git a/tensorflow/examples/udacity/5_word2vec.ipynb b/tensorflow/examples/udacity/5_word2vec.ipynb
index 62dbec4e11..f932f62e28 100644
--- a/tensorflow/examples/udacity/5_word2vec.ipynb
+++ b/tensorflow/examples/udacity/5_word2vec.ipynb
@@ -446,6 +446,11 @@
" train_labels, num_sampled, vocabulary_size))\n",
"\n",
" # Optimizer.\n",
+ " # Note: The optimizer will optimize the softmax_weights AND the embeddings.\n",
+ " # This is because the embeddings are defined as a variable quantity and the\n",
+ " # optimizer's `minimize` method will by default modify all variable quantities \n",
+ " # that contribute to the tensor it is passed.\n",
+ " # See docs on `tf.train.Optimizer.minimize()` for more details.\n",
" optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)\n",
" \n",
" # Compute the similarity between minibatch examples and all embeddings.\n",
diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md
index 008174f9d6..1aaf39bd50 100644
--- a/tensorflow/g3doc/api_docs/python/constant_op.md
+++ b/tensorflow/g3doc/api_docs/python/constant_op.md
@@ -60,7 +60,7 @@ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
* <b>`tensor`</b>: A `Tensor`.
* <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
* <b>`name`</b>: A name for the operation (optional).
@@ -119,7 +119,7 @@ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
* <b>`tensor`</b>: A `Tensor`.
* <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
index 6a139fb6d3..da6262a6f7 100644
--- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md
+++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
@@ -417,7 +417,7 @@ Returns the truth value of (x == y) element-wise.
##### Args:
-* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `quint8`, `qint8`, `qint32`, `string`, `bool`.
+* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `complex128`, `quint8`, `qint8`, `qint32`, `string`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
@@ -435,7 +435,7 @@ Returns the truth value of (x != y) element-wise.
##### Args:
-* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `quint8`, `qint8`, `qint32`, `string`, `bool`.
+* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `complex128`, `quint8`, `qint8`, `qint32`, `string`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md
index 394e351ad7..94eb6a6717 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md
@@ -23,7 +23,7 @@ tf.diag(diagonal) ==> [[1, 0, 0, 0]
##### Args:
-* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`.
Rank k tensor where k is at most 3.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md
index 182aa5264b..249eb80e50 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md
@@ -24,7 +24,7 @@ tf.diag_part(input) ==> [1, 2, 3, 4]
##### Args:
-* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`.
Rank k tensor where k is 2, 4, or 6.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md
index 5ee3dcead4..7fac1aad70 100644
--- a/tensorflow/g3doc/api_docs/python/math_ops.md
+++ b/tensorflow/g3doc/api_docs/python/math_ops.md
@@ -931,7 +931,7 @@ tf.diag(diagonal) ==> [[1, 0, 0, 0]
##### Args:
-* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`.
Rank k tensor where k is at most 3.
* <b>`name`</b>: A name for the operation (optional).
@@ -968,7 +968,7 @@ tf.diag_part(input) ==> [1, 2, 3, 4]
##### Args:
-* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`.
Rank k tensor where k is 2, 4, or 6.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 4eabd326f0..8a40d364e1 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -163,7 +163,7 @@ case where both types are quantized.
* <b>`value`</b>: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
- `int16`, `int8`, or `complex64`.
+ `int16`, `int8`, `complex64`, or `complex128`.
* <b>`bias`</b>: A 1-D `Tensor` with size matching the last dimension of `value`.
Must be the same type as `value` unless `value` is a quantized type,
in which case a different quantized type may be used.
@@ -186,7 +186,7 @@ Specifically, `y = 1 / (1 + exp(-x))`.
##### Args:
-* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).
@@ -205,7 +205,7 @@ Computes hyperbolic tangent of `x` element-wise.
##### Args:
-* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index bf270ddfd4..c2aa1c01fa 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -948,7 +948,7 @@ saver.restore(...checkpoint filename...)
Creates a new ExponentialMovingAverage object.
-The `Apply()` method has to be called to create shadow variables and add
+The `apply()` method has to be called to create shadow variables and add
ops to maintain moving averages.
The optional `num_updates` parameter allows one to tweak the decay rate
@@ -965,7 +965,7 @@ move faster. If passed, the actual decay rate used is:
* <b>`decay`</b>: Float. The decay to use.
* <b>`num_updates`</b>: Optional count of number of updates applied to variables.
* <b>`name`</b>: String. Optional prefix name to use for the name of ops added in
- `Apply()`.
+ `apply()`.
- - -
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 3a9b883bc4..50309a5df9 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -59,11 +59,11 @@ $ sudo easy_install pip
Install TensorFlow:
```bash
-# Ubuntu/Linux 64-bit, CPU only:
+# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For
-# other versions, see "Install from sources" below.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# For other versions, see "Install from sources" below.
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
# Mac OS X, CPU only:
@@ -74,11 +74,11 @@ $ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tenso
For python3:
```bash
-# Ubuntu/Linux 64-bit, CPU only:
+# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For
-# other versions, see "Install from sources" below.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# For other versions, see "Install from sources" below.
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
# Mac OS X, CPU only:
@@ -134,11 +134,11 @@ $ source ~/tensorflow/bin/activate # If using bash
$ source ~/tensorflow/bin/activate.csh # If using csh
(tensorflow)$ # Your prompt should change
-# Ubuntu/Linux 64-bit, CPU only:
+# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For
-# other versions, see "Install from sources" below.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# For other versions, see "Install from sources" below.
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
# Mac OS X, CPU only:
@@ -152,11 +152,11 @@ $ source ~/tensorflow/bin/activate # If using bash
$ source ~/tensorflow/bin/activate.csh # If using csh
(tensorflow)$ # Your prompt should change
-# Ubuntu/Linux 64-bit, CPU only:
+# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For
-# other versions, see "Install from sources" below.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# For other versions, see "Install from sources" below.
(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
# Mac OS X, CPU only:
@@ -225,15 +225,15 @@ Use the `--ignore-installed` flag to prevent errors about `easy_install`.
$ source activate tensorflow
(tensorflow)$ # Your prompt should change
-# Ubuntu/Linux 64-bit, CPU only:
-(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp27-none-linux_x86_64.whl
+# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
+(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For
-# other versions, see "Install from sources" below.
-(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0rc0-cp27-none-linux_x86_64.whl
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# For other versions, see "Install from sources" below.
+(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
# Mac OS X, CPU only:
-(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0rc0-py2-none-any.whl
+(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0-py2-none-any.whl
```
and again for Python 3:
@@ -242,15 +242,15 @@ and again for Python 3:
$ source activate tensorflow
(tensorflow)$ # Your prompt should change
-# Ubuntu/Linux 64-bit, CPU only:
-(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp34-cp34m-linux_x86_64.whl
+# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
+(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For
-# other versions, see "Install from sources" below.
-(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0rc0-cp34-cp34m-linux_x86_64.whl
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# For other versions, see "Install from sources" below.
+(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
# Mac OS X, CPU only:
-(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0rc0-py3-none-any.whl
+(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0-py3-none-any.whl
```
With the conda environment activated, you can now
@@ -309,9 +309,13 @@ After Docker is installed, launch a Docker container with the TensorFlow binary
image as follows.
```bash
-$ docker run -it gcr.io/tensorflow/tensorflow
+$ docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow
```
+The option `-p 8888:8888` publishes the Docker container's internal port to the host machine; in this case it makes the Jupyter notebook server reachable from the host.
+
+The format of the port mapping is `hostPort:containerPort`. You can specify any valid port number for the host port, but the container port must be `8888`.
+
If you're using a container with GPU support, some additional flags must be
passed to expose the GPU device to the container. For the default config, we
include a
@@ -437,7 +441,10 @@ binary path.
#### Install other dependencies
```bash
-$ sudo apt-get install python-numpy swig python-dev
+# For Python 2.7:
+$ sudo apt-get install python-numpy swig python-dev python-wheel
+# For Python 3.x:
+$ sudo apt-get install python3-numpy swig python3-dev python3-wheel
```
#### Configure the installation
@@ -627,7 +634,7 @@ $ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_pack
$ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# The name of the .whl file will depend on your platform.
-$ pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-linux_x86_64.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl
```
## Setting up TensorFlow for Development
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md
index 7b067ebb12..c1516ea487 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/index.md
+++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md
@@ -815,7 +815,7 @@ expressions:
```c++
REGISTER_OP("BuiltInTypesExample")
.Input("integers: int32")
- .Input("complex_numbers: scomplex64");
+ .Input("complex_numbers: complex64");
```
* `<attr-type>`, where `<attr-type>` is the name of an [Attr](#attrs) with type
diff --git a/tensorflow/g3doc/how_tos/distributed/index.md b/tensorflow/g3doc/how_tos/distributed/index.md
index 2adb7b3eb0..88f334cb53 100644
--- a/tensorflow/g3doc/how_tos/distributed/index.md
+++ b/tensorflow/g3doc/how_tos/distributed/index.md
@@ -21,7 +21,7 @@ $ python
```
The
-[`tf.train.Server.create_local_server()`](../../api_docs/train.md#Server.create_local_server)
+[`tf.train.Server.create_local_server()`](../../api_docs/python/train.md#Server.create_local_server)
method creates a single-process cluster, with an in-process server.
## Create a cluster
@@ -110,7 +110,7 @@ which you'd like to see support, please raise a
## Specifying distributed devices in your model
To place operations on a particular process, you can use the same
-[`tf.device()`](https://www.tensorflow.org/versions/master/api_docs/python/framework.html#device)
+[`tf.device()`](../../api_docs/python/framework.md#device)
function that is used to specify whether ops run on the CPU or GPU. For example:
```python
@@ -158,7 +158,7 @@ simplify the work of specifying a replicated model. Possible approaches include:
for each `/job:worker` task, typically in the same process as the worker
task. Each client builds a similar graph containing the parameters (pinned to
`/job:ps` as before using
- [`tf.train.replica_device_setter()`](../../api_docs/train.md#replica_device_setter)
+ [`tf.train.replica_device_setter()`](../../api_docs/python/train.md#replica_device_setter)
to map them deterministically to the same tasks); and a single copy of the
compute-intensive part of the model, pinned to the local task in
`/job:worker`.
@@ -199,7 +199,7 @@ FLAGS = tf.app.flags.FLAGS
def main(_):
ps_hosts = FLAGS.ps_hosts.split(",")
- worker_hosts = FLAGS.worker_hosts(",")
+ worker_hosts = FLAGS.worker_hosts.split(",")
# Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
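
With the corrected `.split(",")`, the flag parsing and cluster construction look like the following sketch (host names are illustrative):

```python
import tensorflow as tf

ps_hosts = "ps0.example.com:2222".split(",")
worker_hosts = "worker0.example.com:2222,worker1.example.com:2222".split(",")

# Map each job name to its list of host:port tasks.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
# Starting a server binds this task's port on the local machine.
server = tf.train.Server(cluster, job_name="worker", task_index=0)
```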
diff --git a/tensorflow/g3doc/how_tos/threading_and_queues/index.md b/tensorflow/g3doc/how_tos/threading_and_queues/index.md
index 9d1fbddb68..ec17518da0 100644
--- a/tensorflow/g3doc/how_tos/threading_and_queues/index.md
+++ b/tensorflow/g3doc/how_tos/threading_and_queues/index.md
@@ -167,10 +167,10 @@ try:
break
sess.run(train_op)
except Exception, e:
- # Report exceptions to the coordinator.
- coord.request_stop(e)
-
-# Terminate as usual. It is innocuous to request stop twice.
-coord.request_stop()
-coord.join(threads)
+ # Report exceptions to the coordinator.
+ coord.request_stop(e)
+finally:
+ # Terminate as usual. It is innocuous to request stop twice.
+ coord.request_stop()
+ coord.join(threads)
```
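
The corrected pattern above, written out as a runnable sketch (Python 3 `except ... as e` syntax is used in place of the page's Python 2 form; the training op is a stand-in):

```python
import tensorflow as tf

sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        break  # stand-in for sess.run(train_op)
except Exception as e:
    # Report exceptions to the coordinator.
    coord.request_stop(e)
finally:
    # Terminate as usual. It is innocuous to request stop twice.
    coord.request_stop()
    coord.join(threads)
```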
diff --git a/tensorflow/g3doc/resources/dims_types.md b/tensorflow/g3doc/resources/dims_types.md
index 8e55e609a0..76110190d9 100644
--- a/tensorflow/g3doc/resources/dims_types.md
+++ b/tensorflow/g3doc/resources/dims_types.md
@@ -62,6 +62,7 @@ Data type | Python type | Description
`DT_STRING` | `tf.string` | Variable length byte arrays. Each element of a Tensor is a byte array.
`DT_BOOL` | `tf.bool` | Boolean.
`DT_COMPLEX64` | `tf.complex64` | Complex number made of two 32-bit floating point numbers: real and imaginary parts.
+`DT_COMPLEX128` | `tf.complex128` | Complex number made of two 64-bit floating point numbers: real and imaginary parts.
`DT_QINT8` | `tf.qint8` | 8 bits signed integer used in quantized Ops.
`DT_QINT32` | `tf.qint32` | 32 bits signed integer used in quantized Ops.
`DT_QUINT8` | `tf.quint8` | 8 bits unsigned integer used in quantized Ops.
diff --git a/tensorflow/g3doc/tutorials/mandelbrot/index.md b/tensorflow/g3doc/tutorials/mandelbrot/index.md
index 6b1c070791..8009e32d84 100755
--- a/tensorflow/g3doc/tutorials/mandelbrot/index.md
+++ b/tensorflow/g3doc/tutorials/mandelbrot/index.md
@@ -20,9 +20,8 @@ import numpy as np
# Imports for visualization
import PIL.Image
-from cStringIO import StringIO
-from IPython.display import clear_output, Image, display
-import scipy.ndimage as nd
+from io import BytesIO
+from IPython.display import Image, display
```
Now we'll define a function to actually display the image once we have
@@ -39,7 +38,7 @@ def DisplayFractal(a, fmt='jpeg'):
img[a==a.max()] = 0
a = img
a = np.uint8(np.clip(a, 0, 255))
- f = StringIO()
+ f = BytesIO()
PIL.Image.fromarray(a).save(f, fmt)
display(Image(data=f.getvalue()))
```
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 5bef94e789..ef025917b6 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -39,6 +39,10 @@ import traceback
# the mode to RTLD_GLOBAL to make the symbols visible, so libraries such
# as the ones implementing custom ops can have access to tensorflow
# framework's symbols.
+# one catch is that numpy *must* be imported before the call to
+# setdlopenflags(), or there is a risk that later c modules will segfault
+# when importing numpy (gh-2034).
+import numpy as np
_default_dlopen_flags = sys.getdlopenflags()
sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_GLOBAL)
from tensorflow.python import pywrap_tensorflow
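
The fix above amounts to the following ordering; a sketch with a hypothetical C-extension module (`_pywrap_example`) standing in for `pywrap_tensorflow`:

```python
import ctypes
import sys

# numpy must be imported *before* setdlopenflags(); otherwise later C
# modules may segfault when importing numpy (gh-2034).
import numpy as np

_default_dlopen_flags = sys.getdlopenflags()
sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_GLOBAL)
# import _pywrap_example  # hypothetical C++ extension imported with RTLD_GLOBAL
sys.setdlopenflags(_default_dlopen_flags)
```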
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index d15d3562dd..74e91b826a 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -465,6 +465,21 @@ class Conv2DTest(tf.test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
+ def testConv2DStrideTwoFilterOneSameBackpropInput(self):
+ expected_output = [1.0, 0.0, 2.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ 3.0, 0.0, 4.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0]
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropInput(input_sizes=[1, 4, 4, 1],
+ filter_sizes=[1, 1, 1, 1],
+ output_sizes=[1, 2, 2, 1],
+ strides=[2, 2],
+ padding="SAME",
+ expected=expected_output,
+ data_format=data_format,
+ use_gpu=use_gpu)
+
# Testing for backprops
def _RunAndVerifyBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
strides, padding, expected, data_format,
@@ -568,6 +583,18 @@ class Conv2DTest(tf.test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
+ def testConv2DStrideTwoFilterOneSameBackpropFilter(self):
+ expected_output = [78.]
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropFilter(input_sizes=[1, 4, 4, 1],
+ filter_sizes=[1, 1, 1, 1],
+ output_sizes=[1, 2, 2, 1],
+ strides=[2, 2],
+ padding="SAME",
+ expected=expected_output,
+ data_format=data_format,
+ use_gpu=use_gpu)
+
# Gradient checkers
def ConstructAndTestGradient(self, batch, input_rows, input_cols, filter_rows,
filter_cols, in_depth, out_depth, stride_rows,
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 8ce81f1b14..596390bf42 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -64,12 +64,8 @@ class UnaryOpTest(tf.test.TestCase):
else:
self.assertAllClose(np_ans, tf_cpu)
- # TODO(ebrevdo): consider adding polygamma function
- if tf_func in (tf.digamma,):
- return # Return early
-
- if x.dtype == np.complex64 and tf_func in (
- tf.sign, tf.sqrt, tf.rsqrt, tf.log):
+ if (x.dtype in (np.complex64, np.complex128) and
+ tf_func in (tf.sign, tf.sqrt, tf.rsqrt, tf.log)):
return # Return early
if x.dtype == np.float16:
@@ -89,7 +85,7 @@ class UnaryOpTest(tf.test.TestCase):
x_init_value=xf)
jacob_n = jacob_n.astype(np.float16)
self.assertAllClose(jacob_t, jacob_n, rtol=5e-3, atol=5e-3)
- elif x.dtype == np.float32 or x.dtype == np.complex64:
+ elif x.dtype in (np.float32, np.complex64):
s = list(np.shape(x))
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
@@ -97,7 +93,7 @@ class UnaryOpTest(tf.test.TestCase):
s,
x_init_value=x)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
- elif x.dtype == np.float64:
+ elif x.dtype in (np.float64, np.complex128):
s = list(np.shape(x))
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
@@ -290,6 +286,30 @@ class UnaryOpTest(tf.test.TestCase):
return x / np.abs(x)
self._compareCpu(y, complex_sign, tf.sign)
+ def testComplex128Basic(self):
+ x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+ np.complex128)
+ y = x + 0.5 # no zeros
+ self._compareCpu(x, np.abs, tf.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareCpu(x, np.negative, tf.neg)
+ self._compareCpu(x, np.negative, _NEG)
+ self._compareCpu(y, self._inv, tf.inv)
+ self._compareCpu(x, np.square, tf.square)
+ self._compareCpu(x, np.sqrt, tf.sqrt)
+ self._compareCpu(y, self._rsqrt, tf.rsqrt)
+ self._compareCpu(x, np.exp, tf.exp)
+ self._compareCpu(y, np.log, tf.log)
+ self._compareCpu(x, np.tanh, tf.tanh)
+ self._compareCpu(x, self._sigmoid, tf.sigmoid)
+ self._compareCpu(x, np.sin, tf.sin)
+ self._compareCpu(x, np.cos, tf.cos)
+
+ # Numpy uses an incorrect definition of sign; use the right one instead.
+ def complex_sign(x):
+ return x / np.abs(x)
+ self._compareCpu(y, complex_sign, tf.sign)
+
class BinaryOpTest(tf.test.TestCase):
@@ -397,10 +417,10 @@ class BinaryOpTest(tf.test.TestCase):
def _compareBoth(self, x, y, np_func, tf_func):
self._compareCpu(x, y, np_func, tf_func)
if x.dtype in (np.float16, np.float32, np.float64):
- if tf_func not in (_FLOORDIV, tf.floordiv, tf.igamma, tf.igammac):
+ if tf_func not in (_FLOORDIV, tf.floordiv, tf.igamma, tf.igammac, tf.zeta, tf.polygamma):
self._compareGradientX(x, y, np_func, tf_func)
self._compareGradientY(x, y, np_func, tf_func)
- if tf_func in (tf.igamma, tf.igammac):
+ if tf_func in (tf.igamma, tf.igammac, tf.zeta, tf.polygamma):
# These methods only support gradients in the second parameter
self._compareGradientY(x, y, np_func, tf_func)
self._compareGpu(x, y, np_func, tf_func)
@@ -424,6 +444,10 @@ class BinaryOpTest(tf.test.TestCase):
x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
self._compareBoth(a_pos_small, x_pos_small, special.gammainc, tf.igamma)
self._compareBoth(a_pos_small, x_pos_small, special.gammaincc, tf.igammac)
+ # Need x > 1
+ self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta, tf.zeta)
+ n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(n_small, x_pos_small, special.polygamma, tf.polygamma)
except ImportError as e:
tf.logging.warn("Cannot test special functions: %s" % str(e))
@@ -520,6 +544,20 @@ class BinaryOpTest(tf.test.TestCase):
self._compareCpu(x, y, np.multiply, _MUL)
self._compareCpu(x, y + 0.1, np.true_divide, _TRUEDIV)
+ def testComplex128Basic(self):
+ x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+ np.complex128)
+ y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+ np.complex128)
+ self._compareCpu(x, y, np.add, tf.add)
+ self._compareCpu(x, y, np.subtract, tf.sub)
+ self._compareCpu(x, y, np.multiply, tf.mul)
+ self._compareCpu(x, y + 0.1, np.true_divide, tf.truediv)
+ self._compareCpu(x, y, np.add, _ADD)
+ self._compareCpu(x, y, np.subtract, _SUB)
+ self._compareCpu(x, y, np.multiply, _MUL)
+ self._compareCpu(x, y + 0.1, np.true_divide, _TRUEDIV)
+
def testStringComparison(self):
x = np.array([["abc", "bh"], ["c", ""]])
y = np.array([["abc", "bh"], ["def", "hi"]])
@@ -571,11 +609,13 @@ class BinaryOpTest(tf.test.TestCase):
np.float64,
np.int32,
np.int64,
- np.complex64
+ np.complex64,
+ np.complex128,
]
for dtype in dtypes:
for (np_func, tf_func) in funcs:
- if dtype == np.complex64 and tf_func in (_FLOORDIV, tf.floordiv):
+ if (dtype in (np.complex64, np.complex128) and
+ tf_func in (_FLOORDIV, tf.floordiv)):
continue # floordiv makes no sense for complex numbers
self._compareBCast(xs, ys, dtype, np_func, tf_func)
self._compareBCast(ys, xs, dtype, np_func, tf_func)
@@ -827,18 +867,18 @@ class ComparisonOpTest(tf.test.TestCase):
for t in dtypes:
for x in data:
for y in data:
- self.assertEqual(self._compare(tf.less, x, y, t),
- x < y)
- self.assertEqual(self._compare(tf.less_equal, x, y, t),
- x <= y)
- self.assertEqual(self._compare(tf.greater, x, y, t),
- x > y)
- self.assertEqual(self._compare(tf.greater_equal, x, y, t),
- x >= y)
- self.assertEqual(self._compare(tf.equal, x, y, t),
- x == y)
- self.assertEqual(self._compare(tf.not_equal, x, y, t),
- x != y)
+ self.assertEqual(self._compare(tf.less, x, y, t), x < y)
+ self.assertEqual(self._compare(tf.less_equal, x, y, t), x <= y)
+ self.assertEqual(self._compare(tf.greater, x, y, t), x > y)
+ self.assertEqual(self._compare(tf.greater_equal, x, y, t), x >= y)
+ self.assertEqual(self._compare(tf.equal, x, y, t), x == y)
+ self.assertEqual(self._compare(tf.not_equal, x, y, t), x != y)
+ data = [-1, 0, 1, -1j, 1j, 1 + 1j, 1 - 1j]
+ for t in [np.complex64, np.complex128]:
+ for x in data:
+ for y in data:
+ self.assertEqual(self._compare(tf.equal, x, y, t), x == y)
+ self.assertEqual(self._compare(tf.not_equal, x, y, t), x != y)
def _compareCpu(self, x, y, np_func, tf_func):
np_ans = np_func(x, y)
@@ -872,10 +912,9 @@ class ComparisonOpTest(tf.test.TestCase):
self._compareBoth(xt, yt, np.equal, tf.equal)
self._compareBoth(xt, yt, np.not_equal, tf.not_equal)
# TODO(zhifengc): complex64 doesn't work on GPU yet.
- self._compareCpu(x.astype(np.complex64), y.astype(np.complex64),
- np.equal, tf.equal)
- self._compareCpu(x.astype(np.complex64), y.astype(np.complex64),
- np.not_equal, tf.not_equal)
+ for t in [np.complex64, np.complex128]:
+ self._compareCpu(x.astype(t), y.astype(t), np.equal, tf.equal)
+ self._compareCpu(x.astype(t), y.astype(t), np.not_equal, tf.not_equal)
def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
x = np.linspace(-15, 15, np.prod(xs)).astype(dtype).reshape(xs)
@@ -1115,7 +1154,7 @@ class SelectOpTest(tf.test.TestCase):
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 2) * 100
for t in [np.float16, np.float32, np.float64, np.int32, np.int64,
- np.complex64]:
+ np.complex64, np.complex128]:
xt = x.astype(t)
yt = y.astype(t)
self._compare(c, xt, yt, use_gpu=False)
@@ -1242,7 +1281,7 @@ class BatchSelectOpTest(tf.test.TestCase):
x = np.random.rand(16, 2, 8) * 100
y = np.random.rand(16, 2, 8) * 100
for t in [np.float16, np.float32, np.float64, np.int32, np.int64,
- np.complex64]:
+ np.complex64, np.complex128]:
xt = x.astype(t)
yt = y.astype(t)
self._compare(c, xt, yt, use_gpu=False)
@@ -1273,7 +1312,7 @@ class BatchSelectOpTest(tf.test.TestCase):
x = np.random.rand(16, 3, 2) * 100
y = np.random.rand(16, 3, 2) * 100
for t in [np.float16, np.float32, np.float64, np.int32, np.int64,
- np.complex64]:
+ np.complex64, np.complex128]:
xt = x.astype(t)
yt = y.astype(t)
with self.assertRaises(ValueError):
@@ -1394,6 +1433,7 @@ class MathOpsOverloadTest(tf.test.TestCase):
tf.int32,
tf.int64,
tf.complex64,
+ tf.complex128,
]
funcs = [
(np.add, _ADD),
@@ -1405,7 +1445,7 @@ class MathOpsOverloadTest(tf.test.TestCase):
]
for dtype in dtypes:
for np_func, tf_func in funcs:
- if dtype == tf.complex64 and tf_func == _FLOORDIV:
+ if dtype in (tf.complex64, tf.complex128) and tf_func == _FLOORDIV:
continue # floordiv makes no sense for complex
self._compareBinary(10, 5, dtype, np_func, tf_func)
# Mod only works for int32 and int64.
@@ -1537,10 +1577,17 @@ class ComplexMakeRealImagTest(tf.test.TestCase):
self.assertShapeEqual(np_real, tf_real)
self.assertShapeEqual(np_imag, tf_imag)
- def testRealImag(self):
+ def testRealImag64(self):
real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
- cplx = real + (1j) * imag
+ cplx = real + 1j * imag
+ self._compareRealImag(cplx, use_gpu=False)
+ self._compareRealImag(cplx, use_gpu=True)
+
+ def testRealImag128(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64)
+ cplx = real + 1j * imag
self._compareRealImag(cplx, use_gpu=False)
self._compareRealImag(cplx, use_gpu=True)
@@ -1553,10 +1600,17 @@ class ComplexMakeRealImagTest(tf.test.TestCase):
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, tf_conj)
- def testConj(self):
+ def testConj64(self):
real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
- cplx = real + (1j) * imag
+ cplx = real + 1j * imag
+ self._compareConj(cplx, use_gpu=False)
+ self._compareConj(cplx, use_gpu=True)
+
+ def testConj128(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64)
+ cplx = real + 1j * imag
self._compareConj(cplx, use_gpu=False)
self._compareConj(cplx, use_gpu=True)
@@ -1585,8 +1639,12 @@ class ComplexMakeRealImagTest(tf.test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon)
def testGradient(self):
+ # complex64
data = np.arange(1, 2, 0.10).reshape([5, 2]).astype(np.float32)
self._compareGradient(data)
+ # complex128
+ data = np.arange(1, 2, 0.10).reshape([5, 2]).astype(np.float64)
+ self._compareGradient(data)
def _compareMulGradient(self, data):
# data is a float matrix of shape [n, 4]. data[:, 0], data[:, 1],
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index 04be325dfd..20a7760fd2 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -165,6 +165,14 @@ class DiagTest(tf.test.TestCase):
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
+ def testRankOneComplexTensor(self):
+ x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=np.complex64)
+ expected_ans = np.array(
+ [[1.1 + 1.1j, 0 + 0j, 0 + 0j],
+ [0 + 0j, 2.2 + 2.2j, 0 + 0j],
+ [0 + 0j, 0 + 0j, 3.3 + 3.3j]], dtype=np.complex64)
+ self.diagOp(x, np.complex64, expected_ans)
+
def testRankTwoIntTensor(self):
x = np.array([[1, 2, 3], [4, 5, 6]])
expected_ans = np.array(
@@ -189,6 +197,19 @@ class DiagTest(tf.test.TestCase):
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
+ def testRankTwoComplexTensor(self):
+ x = np.array([[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
+ [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], dtype=np.complex64)
+ expected_ans = np.array(
+ [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
+ [[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]],
+ [[[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
+ dtype=np.complex64)
+ self.diagOp(x, np.complex64, expected_ans)
+
def testRankThreeFloatTensor(self):
x = np.array([[[1.1, 2.2], [3.3, 4.4]],
[[5.5, 6.6], [7.7, 8.8]]])
@@ -204,6 +225,30 @@ class DiagTest(tf.test.TestCase):
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
+ def testRankThreeComplexTensor(self):
+ x = np.array([[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
+ [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
+ dtype=np.complex64)
+ expected_ans = np.array(
+ [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
+ [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]],
+ [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
+ [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]],
+ [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]]],
+ [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
+ [[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]]],
+ [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
+ [[0 + 0j, 6.6 + 6.6j], [0 + 0j, 0 + 0j]]]],
+ [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j], [7.7 + 7.7j, 0 + 0j]]],
+ [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
+ dtype=np.complex64)
+ self.diagOp(x, np.complex64, expected_ans)
+
class DiagPartOpTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/gradient_checker.py b/tensorflow/python/kernel_tests/gradient_checker.py
index 935446b1b5..5ca529cc73 100644
--- a/tensorflow/python/kernel_tests/gradient_checker.py
+++ b/tensorflow/python/kernel_tests/gradient_checker.py
@@ -188,7 +188,7 @@ def _compute_gradient(x,
"""Computes the theoretical and numerical jacobian."""
t = dtypes.as_dtype(x.dtype)
allowed_types = [dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.complex64]
+ dtypes.complex64, dtypes.complex128]
assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name
t2 = dtypes.as_dtype(y.dtype)
assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name
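A minimal sketch (values hypothetical) of how the checker can now be exercised with `complex128`, mirroring the unary-op tests earlier in this patch:

```python
import numpy as np
import tensorflow as tf

# Assumes the era's tf.test.compute_gradient, which runs in the default session.
x = (np.arange(1, 7) * (1 + 0.5j)).reshape(3, 2).astype(np.complex128)
with tf.Session():
  inx = tf.convert_to_tensor(x)
  y = tf.square(inx)
  s = list(x.shape)
  jacob_t, jacob_n = tf.test.compute_gradient(inx, s, y, s, x_init_value=x)
  # Theoretical and numerical Jacobians should agree.
  np.testing.assert_allclose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
```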
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 3d39af8d85..0e7f2efe61 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -68,10 +68,14 @@ class MatMulTest(tf.test.TestCase):
self.assertAllEqual(np_ans.shape, tf_ans.shape)
def _randMatrix(self, rows, cols, dtype):
- if dtype is np.complex64:
- real = self._randMatrix(rows, cols, np.float32)
- imag = self._randMatrix(rows, cols, np.float32)
- return real + np.complex(0, 1) * imag
+ if dtype in (np.complex64, np.complex128):
+ float_dtype = np.float32 if dtype == np.complex64 else np.float64
+ real = self._randMatrix(rows, cols, float_dtype)
+ imag = self._randMatrix(rows, cols, float_dtype)
+ return real + 1j * imag
else:
return np.random.uniform(low=1.0, high=100.0, size=rows * cols).reshape(
[rows, cols]).astype(dtype)
@@ -106,11 +110,16 @@ class MatMulTest(tf.test.TestCase):
y = np.arange(1., 3.).reshape([1, 2]).astype(np.int32)
self._testCpuMatmul(x, y)
- def testSComplexBasic(self):
+ def testComplex64Basic(self):
x = np.arange(1., 5.).reshape([4, 1]).astype(np.complex64)
y = np.arange(1., 3.).reshape([1, 2]).astype(np.complex64)
self._testCpuMatmul(x, y)
+ def testComplex128Basic(self):
+ x = np.arange(1., 5.).reshape([4, 1]).astype(np.complex128)
+ y = np.arange(1., 3.).reshape([1, 2]).astype(np.complex128)
+ self._testCpuMatmul(x, y)
+
# Tests testing random sized matrices.
def testFloatRandom(self):
for _ in range(10):
@@ -145,13 +154,20 @@ class MatMulTest(tf.test.TestCase):
y = self._randMatrix(k, m, np.int32)
self._testCpuMatmul(x, y)
- def testSComplexRandom(self):
+ def testComplex64Random(self):
for _ in range(10):
n, k, m = np.random.randint(1, 100, size=3)
x = self._randMatrix(n, k, np.complex64)
y = self._randMatrix(k, m, np.complex64)
self._testCpuMatmul(x, y)
+ def testComplex128Random(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = self._randMatrix(n, k, np.complex128)
+ y = self._randMatrix(k, m, np.complex128)
+ self._testCpuMatmul(x, y)
+
# Test the cases that transpose the matrices before multiplying.
# NOTE(keveman): The cases where only one of the inputs is
# transposed are covered by tf.matmul's gradient function.
diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
index d3fecf4f44..bdf0f5778f 100644
--- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
@@ -1259,7 +1259,7 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
def testDtypes(self):
with self.test_session() as sess:
dtypes = [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8,
- tf.int64, tf.bool, tf.complex64]
+ tf.int64, tf.bool, tf.complex64, tf.complex128]
shape = (32, 4, 128)
q = tf.PaddingFIFOQueue(32, dtypes, [shape[1:]] * len(dtypes))
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index b3b25bf031..cf0bcbd50f 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -79,6 +79,7 @@ class SparseTensorDenseMatMulTest(tf.test.TestCase):
self._testBasic(np.float32)
self._testBasic(np.float64)
self._testBasic(np.complex64)
+ self._testBasic(np.complex128)
# Tests setting one dimension to be a high value.
def _testLarge(self, np_dtype):
@@ -102,6 +103,7 @@ class SparseTensorDenseMatMulTest(tf.test.TestCase):
self._testLarge(np.float32)
self._testLarge(np.float64)
self._testLarge(np.complex64)
+ self._testLarge(np.complex128)
# Tests random sized matrices.
def testFloatRandom(self):
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index e03c54c953..0389f27622 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -77,6 +77,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayWritePack(tf.int32)
self._testTensorArrayWritePack(tf.int64)
self._testTensorArrayWritePack(tf.complex64)
+ self._testTensorArrayWritePack(tf.complex128)
self._testTensorArrayWritePack(tf.string)
def _testTensorArrayWriteConcat(self, tf_dtype):
@@ -111,6 +112,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayWriteConcat(tf.int32)
self._testTensorArrayWriteConcat(tf.int64)
self._testTensorArrayWriteConcat(tf.complex64)
+ self._testTensorArrayWriteConcat(tf.complex128)
self._testTensorArrayWriteConcat(tf.string)
def testTensorArrayUnpackWrongMajorSizeFails(self):
@@ -173,6 +175,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayUnpackRead(tf.int32)
self._testTensorArrayUnpackRead(tf.int64)
self._testTensorArrayUnpackRead(tf.complex64)
+ self._testTensorArrayUnpackRead(tf.complex128)
self._testTensorArrayUnpackRead(tf.string)
def _testTensorArraySplitRead(self, tf_dtype):
@@ -231,6 +234,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArraySplitRead(tf.int32)
self._testTensorArraySplitRead(tf.int64)
self._testTensorArraySplitRead(tf.complex64)
+ self._testTensorArraySplitRead(tf.complex128)
self._testTensorArraySplitRead(tf.string)
def testTensorGradArrayWriteRead(self):
@@ -468,8 +472,9 @@ class TensorArrayCPUTest(tf.test.TestCase):
wb1_grad.flow.eval()
def testTensorArrayWriteGradientAddMultipleAdds(self):
- for dtype in [tf.int32, tf.int64, tf.float32, tf.float64, tf.complex64]:
- self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
+ for dtype in (tf.int32, tf.int64, tf.float32,
+ tf.float64, tf.complex64, tf.complex128):
+ self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
def testMultiTensorArray(self):
with self.test_session(use_gpu=self._use_gpu):
@@ -540,7 +545,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual(c(-2.0), grad_vals[1])
def testTensorArrayGradientWriteRead(self):
- for dtype in (np.float32, np.float64, np.int32, np.int64, np.complex64):
+ for dtype in (np.float32, np.float64, np.int32,
+ np.int64, np.complex64, np.complex128):
self._testTensorArrayGradientWriteReadType(dtype)
def testTensorArrayGradientWritePackConcatAndRead(self):
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index b2cad3ee05..6ea8ae3280 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -130,10 +130,9 @@ class XentTest(tf.test.TestCase):
np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32))
def testDouble(self):
- self._testXent(
+ self._testAll(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
- np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64),
- use_gpu=False)
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64))
def testGradient(self):
with self.test_session():
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index ff340d7f87..9ac80960b2 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -516,7 +516,7 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
```python
# 2-D example
- a = [[1, 2], [3, 4], [5, 6]]
+ tensor = [[1, 2], [3, 4], [5, 6]]
mask = [True, False, True]
boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]]
```
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index ab71e28cab..009ed8653a 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -287,8 +287,11 @@ def _LgammaGrad(op, grad):
@ops.RegisterGradient("Digamma")
-def _DigammaGrad(op, grad): # pylint: disable=unused-argument
- raise NotImplementedError("grad(Digamma) == Polygamma(1) is not implemented")
+def _DigammaGrad(op, grad):
+ """Compute gradient of the digamma function with respect to its argument."""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
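The identity behind this gradient is d/dx digamma(x) = polygamma(1, x); a quick numerical spot check against scipy (assuming scipy's conventions):

```python
import numpy as np
from scipy import special

# Central-difference check of d/dx digamma(x) = polygamma(1, x).
x = np.linspace(0.5, 5.0, 10)
eps = 1e-6
numeric = (special.digamma(x + eps) - special.digamma(x - eps)) / (2 * eps)
np.testing.assert_allclose(numeric, special.polygamma(1, x), rtol=1e-4)
```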
@ops.RegisterGradient("Igamma")
@@ -314,6 +317,40 @@ def _IgammacGrad(op, grad):
return [-1 * g if g is not None else None for g in _IgammaGrad(op, grad)]
+@ops.RegisterGradient("Zeta")
+def _ZetaGrad(op, grad):
+ """Returns gradient of zeta(x, q) with respect to x and q."""
+ # TODO(tillahoffmann): Add derivative with respect to x
+ x = op.inputs[0]
+ q = op.inputs[1]
+ # Broadcast gradients
+ sx = array_ops.shape(x)
+ sq = array_ops.shape(q)
+ unused_rx, rq = gen_array_ops._broadcast_gradient_args(sx, sq)
+ # Evaluate gradient
+ with ops.control_dependencies([grad.op]):
+ partial_q = -x * math_ops.zeta(x + 1, q)
+ return (None,
+ array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))
+
+
+@ops.RegisterGradient("Polygamma")
+def _PolygammaGrad(op, grad):
+ """Returns gradient of psi(n, x) with respect to n and x."""
+ # TODO(tillahoffmann): Add derivative with respect to n
+ n = op.inputs[0]
+ x = op.inputs[1]
+ # Broadcast gradients
+ sn = array_ops.shape(n)
+ sx = array_ops.shape(x)
+ unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx)
+ # Evaluate gradient
+ with ops.control_dependencies([grad.op]):
+ partial_x = math_ops.polygamma(n + 1, x)
+ return (None,
+ array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
+
+
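The closed forms used by these two gradients can be spot-checked numerically against scipy, assuming its `zeta`/`polygamma` conventions:

```python
import numpy as np
from scipy import special

# Spot checks of the partial derivatives used above:
#   d/dq zeta(x, q) = -x * zeta(x + 1, q)
#   d/dx polygamma(n, x) = polygamma(n + 1, x)
q = np.linspace(0.5, 4.0, 8)
x, n, eps = 3.0, 2, 1e-6
dq = (special.zeta(x, q + eps) - special.zeta(x, q - eps)) / (2 * eps)
np.testing.assert_allclose(dq, -x * special.zeta(x + 1, q), rtol=1e-4)
dx = (special.polygamma(n, q + eps) - special.polygamma(n, q - eps)) / (2 * eps)
np.testing.assert_allclose(dx, special.polygamma(n + 1, q), rtol=1e-4)
```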
@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
"""Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index 39b7d1193f..04ee61d640 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -65,7 +65,7 @@ class AbsOpTest(tf.test.TestCase):
def _testGrad(self, shape, dtype=None, max_error=None, bias=None, sigma=None):
np.random.seed(7)
- if dtype == tf.complex64:
+ if dtype in (tf.complex64, tf.complex128):
value = tf.complex(self._biasedRandN(shape, bias=bias, sigma=sigma),
self._biasedRandN(shape, bias=bias, sigma=sigma))
else:
@@ -74,7 +74,7 @@ class AbsOpTest(tf.test.TestCase):
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu):
- if dtype == tf.complex64:
+ if dtype in (tf.complex64, tf.complex128):
output = tf.complex_abs(value)
else:
output = tf.abs(value)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c622d83490..3075c8b9ef 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -58,6 +58,8 @@ mathematical functions to your graph.
@@squared_difference
@@igamma
@@igammac
+@@zeta
+@@polygamma
## Matrix Math Functions
@@ -240,11 +242,36 @@ def abs(x, name=None):
"""
with ops.op_scope([x], name, "Abs") as name:
x = ops.convert_to_tensor(x, name="x")
- if x.dtype == dtypes.complex64:
- return gen_math_ops.complex_abs(x, name=name)
+ if x.dtype in (dtypes.complex64, dtypes.complex128):
+ return gen_math_ops.complex_abs(x, Tout=x.dtype.real_dtype, name=name)
return gen_math_ops._abs(x, name=name)
+def complex_abs(x, name=None):
+ r"""Computes the complex absolute value of a tensor.
+
+ Given a tensor `x` of complex numbers, this operation returns a tensor of type
+ `float` or `double` that is the absolute value of each element in `x`. All
+ elements in `x` must be complex numbers of the form \\(a + bj\\). The
+ absolute value is computed as \\( \sqrt{a^2 + b^2}\\).
+
+ For example:
+
+ ```
+ # tensor 'x' is [[-2.25 + 4.75j], [-3.25 + 5.75j]]
+ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
+ ```
+
+ Args:
+ x: A `Tensor` of type `complex64` or `complex128`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `float32` or `float64`.
+ """
+ return gen_math_ops.complex_abs(x, Tout=x.dtype.real_dtype, name=name)
+
+
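A quick check of the documented behavior for both complex dtypes; values are hypothetical and a default session is assumed:

```python
import numpy as np
import tensorflow as tf

# complex_abs should match sqrt(a**2 + b**2) elementwise for both dtypes.
with tf.Session():
  for dtype in (np.complex64, np.complex128):
    x = np.array([-2.25 + 4.75j, -3.25 + 5.75j], dtype=dtype)
    tf_abs = tf.complex_abs(tf.convert_to_tensor(x)).eval()
    np.testing.assert_allclose(tf_abs, np.sqrt(x.real**2 + x.imag**2),
                               rtol=1e-6)
```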
def scalar_mul(scalar, x):
"""Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
@@ -302,29 +329,94 @@ def complex(real, imag, name=None):
Given a tensor `real` representing the real part of a complex number, and a
tensor `imag` representing the imaginary part of a complex number, this
- operation computes complex numbers elementwise of the form \\\\(a + bj\\\\),
- where *a* represents the `real` part and *b* represents the `imag` part.
+ operation returns complex numbers elementwise of the form \\(a + bj\\), where
+ *a* represents the `real` part and *b* represents the `imag` part.
- The input tensors `real` and `imag` must be the same shape.
+ The input tensors `real` and `imag` must have the same shape.
For example:
```
# tensor 'real' is [2.25, 3.25]
# tensor `imag` is [4.75, 5.75]
- tf.complex(real, imag) ==> [[2.25 + 4.74j], [3.25 + 5.75j]]
+ tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
```
Args:
- real: A `Tensor` of type `float`.
- imag: A `Tensor` of type `float`.
+ real: A `Tensor`. Must be one of the following types: `float32`, `float64`.
+ imag: A `Tensor`. Must have the same type as `real`.
name: A name for the operation (optional).
Returns:
- A `Tensor` of type `complex64`.
+ A `Tensor` of type `complex64` or `complex128`.
"""
+ real = ops.convert_to_tensor(real, name="real")
+ imag = ops.convert_to_tensor(imag, name="imag")
with ops.op_scope([real, imag], name, "Complex") as name:
- return gen_math_ops._complex(real, imag, name=name)
+ input_types = (real.dtype, imag.dtype)
+ if input_types == (dtypes.float64, dtypes.float64):
+ Tout = dtypes.complex128
+ elif input_types == (dtypes.float32, dtypes.float32):
+ Tout = dtypes.complex64
+ else:
+ raise TypeError("Types of real and imag don't match: "
+ "{} {}".format(real.dtype.name, imag.dtype.name))
+ return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
+
+
+def real(input, name=None):
+ """Returns the real part of a complex number.
+
+ Given a tensor `input` of complex numbers, this operation returns a tensor of
+ type `float` or `double` that is the real part of each element in `input`.
+ All elements in `input` must be complex numbers of the form \\(a + bj\\),
+ where *a* is the real part returned by this operation and *b* is the
+ imaginary part.
+
+ For example:
+
+ ```
+ # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+ tf.real(input) ==> [-2.25, 3.25]
+ ```
+
+ Args:
+ input: A `Tensor`. Must be one of the following types: `complex64`,
+ `complex128`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `float` or `double`.
+ """
+ with ops.op_scope([input], name, "Real") as name:
+ return gen_math_ops.real(input, Tout=input.dtype.real_dtype, name=name)
+
+
+def imag(input, name=None):
+ """Returns the imaginary part of a complex number.
+
+ Given a tensor `input` of complex numbers, this operation returns a tensor of
+ type `float` or `double` that is the imaginary part of each element in
+ `input`. All elements in `input` must be complex numbers of the form \\(a +
+ bj\\), where *a* is the real part and *b* is the imaginary part returned by
+ this operation.
+
+ For example:
+
+ ```
+ # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+ tf.imag(input) ==> [4.75, 5.75]
+ ```
+
+ Args:
+ input: A `Tensor`. Must be one of the following types: `complex64`,
+ `complex128`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `float` or `double`.
+ """
+ with ops.op_scope([input], name, "Imag") as name:
+ return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
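Putting the three ops together, a sketch of the promotion rule implemented above, with `real`/`imag` recovering the parts (a default session is assumed):

```python
import numpy as np
import tensorflow as tf

# float32 pairs yield complex64; float64 pairs yield complex128.
with tf.Session():
  for float_dtype in (np.float32, np.float64):
    re = np.array([2.25, 3.25], dtype=float_dtype)
    im = np.array([4.75, 5.75], dtype=float_dtype)
    z = tf.complex(re, im)
    print(z.dtype)  # complex64, then complex128
    np.testing.assert_allclose(tf.real(z).eval(), re)
    np.testing.assert_allclose(tf.imag(z).eval(), im)
```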
def round(x, name=None):
@@ -559,6 +651,7 @@ _TRUEDIV_TABLE = {
dtypes.float32: None,
dtypes.float64: None,
dtypes.complex64: None,
+ dtypes.complex128: None,
}
@@ -1379,6 +1472,8 @@ ops.RegisterShape("BatchIFFT3D")(common_shapes.unchanged_shape)
@ops.RegisterShape("GreaterEqual")
@ops.RegisterShape("Igamma")
@ops.RegisterShape("Igammac")
+@ops.RegisterShape("Zeta")
+@ops.RegisterShape("Polygamma")
@ops.RegisterShape("Less")
@ops.RegisterShape("LessEqual")
@ops.RegisterShape("LogicalAnd")
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 6555096eae..662405d924 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -68,15 +68,15 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
the amount of computation.
For a description of atrous convolution and how it can be used for dense
- feature extraction, please see: (Semantic Image Segmentation with Deep
- Convolutional Nets and Fully Connected CRFs)[http://arxiv.org/abs/1412.7062].
- The same operation is investigated further in (Multi-Scale Context Aggregation
- by Dilated Convolutions)[http://arxiv.org/abs/1511.07122]. Previous works
+ feature extraction, please see: [Semantic Image Segmentation with Deep
+ Convolutional Nets and Fully Connected CRFs](http://arxiv.org/abs/1412.7062).
+ The same operation is investigated further in [Multi-Scale Context Aggregation
+ by Dilated Convolutions](http://arxiv.org/abs/1511.07122). Previous works
that effectively use atrous convolution in different ways are, among others,
- (OverFeat: Integrated Recognition, Localization and Detection using
- Convolutional Networks) [http://arxiv.org/abs/1312.6229] and (Fast Image
- Scanning with Deep Max-Pooling Convolutional Neural Networks)
- [http://arxiv.org/abs/1302.1700]. Atrous convolution is also closely related
+ [OverFeat: Integrated Recognition, Localization and Detection using
+ Convolutional Networks](http://arxiv.org/abs/1312.6229) and [Fast Image
+ Scanning with Deep Max-Pooling Convolutional Neural
+ Networks](http://arxiv.org/abs/1302.1700). Atrous convolution is also
+ closely related
to the so-called noble identities in multi-rate signal processing.
There are many different ways to implement atrous convolution (see the refs
@@ -227,8 +227,8 @@ def conv2d_transpose(value,
name=None):
"""The transpose of `conv2d`.
- This operation is sometimes called "deconvolution" after (Deconvolutional
- Networks)[http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf], but is
+ This operation is sometimes called "deconvolution" after [Deconvolutional
+ Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
actually the transpose (gradient) of `conv2d` rather than an actual
deconvolution.
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index c6fbea1933..bfd0758883 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -120,10 +120,11 @@ class RNNCell(object):
class BasicRNNCell(RNNCell):
"""The most basic RNN cell."""
- def __init__(self, num_units, input_size=None):
+ def __init__(self, num_units, input_size=None, activation=tanh):
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated." % self)
self._num_units = num_units
+ self._activation = activation
@property
def state_size(self):
@@ -134,19 +135,20 @@ class BasicRNNCell(RNNCell):
return self._num_units
def __call__(self, inputs, state, scope=None):
- """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
+ """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
- output = tanh(_linear([inputs, state], self._num_units, True))
+ output = self._activation(_linear([inputs, state], self._num_units, True))
return output, output
class GRUCell(RNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
- def __init__(self, num_units, input_size=None):
+ def __init__(self, num_units, input_size=None, activation=tanh):
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated." % self)
self._num_units = num_units
+ self._activation = activation
@property
def state_size(self):
@@ -165,7 +167,8 @@ class GRUCell(RNNCell):
2 * self._num_units, True, 1.0))
r, u = sigmoid(r), sigmoid(u)
with vs.variable_scope("Candidate"):
- c = tanh(_linear([inputs, r * state], self._num_units, True))
+ c = self._activation(_linear([inputs, r * state],
+ self._num_units, True))
new_h = u * state + (1 - u) * c
return new_h, new_h
@@ -185,7 +188,7 @@ class BasicLSTMCell(RNNCell):
"""
def __init__(self, num_units, forget_bias=1.0, input_size=None,
- state_is_tuple=False):
+ state_is_tuple=False, activation=tanh):
"""Initialize the basic LSTM cell.
Args:
@@ -195,6 +198,7 @@ class BasicLSTMCell(RNNCell):
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. By default (False), they are concatenated
along the column axis. This default behavior will soon be deprecated.
+ activation: Activation function of the inner states.
"""
if not state_is_tuple:
logging.warn(
@@ -205,6 +209,7 @@ class BasicLSTMCell(RNNCell):
self._num_units = num_units
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
+ self._activation = activation
@property
def state_size(self):
@@ -228,8 +233,9 @@ class BasicLSTMCell(RNNCell):
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(1, 4, concat)
- new_c = c * sigmoid(f + self._forget_bias) + sigmoid(i) * tanh(j)
- new_h = tanh(new_c) * sigmoid(o)
+ new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
+ self._activation(j))
+ new_h = self._activation(new_c) * sigmoid(o)
if self._state_is_tuple:
new_state = (new_c, new_h)
@@ -300,7 +306,8 @@ class LSTMCell(RNNCell):
use_peepholes=False, cell_clip=None,
initializer=None, num_proj=None,
num_unit_shards=1, num_proj_shards=1,
- forget_bias=1.0, state_is_tuple=False):
+ forget_bias=1.0, state_is_tuple=False,
+ activation=tanh):
"""Initialize the parameters for an LSTM cell.
Args:
@@ -323,6 +330,7 @@ class LSTMCell(RNNCell):
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. By default (False), they are concatenated
along the column axis. This default behavior will soon be deprecated.
+ activation: Activation function of the inner states.
"""
if not state_is_tuple:
logging.warn(
@@ -339,6 +347,7 @@ class LSTMCell(RNNCell):
self._num_proj_shards = num_proj_shards
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
+ self._activation = activation
if num_proj:
self._state_size = (
@@ -420,9 +429,10 @@ class LSTMCell(RNNCell):
if self._use_peepholes:
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
- sigmoid(i + w_i_diag * c_prev) * tanh(j))
+ sigmoid(i + w_i_diag * c_prev) * self._activation(j))
else:
- c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
+ c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+ self._activation(j))
if self._cell_clip is not None:
# pylint: disable=invalid-unary-operand-type
@@ -430,9 +440,9 @@ class LSTMCell(RNNCell):
# pylint: enable=invalid-unary-operand-type
if self._use_peepholes:
- m = sigmoid(o + w_o_diag * c) * tanh(c)
+ m = sigmoid(o + w_o_diag * c) * self._activation(c)
else:
- m = sigmoid(o) * tanh(c)
+ m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
concat_w_proj = _get_concat_variable(
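A hypothetical usage sketch of the new `activation` argument threaded through the cells in this patch; any elementwise unary op should work in place of the default `tanh`:

```python
import tensorflow as tf

# Build cells with relu instead of the default tanh inner activation.
gru = tf.nn.rnn_cell.GRUCell(num_units=128, activation=tf.nn.relu)
lstm = tf.nn.rnn_cell.BasicLSTMCell(num_units=128, state_is_tuple=True,
                                    activation=tf.nn.relu)
```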
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 1e75bac23c..a3f4d536d8 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -210,7 +210,7 @@ class ExponentialMovingAverage(object):
def __init__(self, decay, num_updates=None, name="ExponentialMovingAverage"):
"""Creates a new ExponentialMovingAverage object.
- The `Apply()` method has to be called to create shadow variables and add
+ The `apply()` method has to be called to create shadow variables and add
ops to maintain moving averages.
The optional `num_updates` parameter allows one to tweak the decay rate
@@ -225,7 +225,7 @@ class ExponentialMovingAverage(object):
decay: Float. The decay to use.
num_updates: Optional count of number of updates applied to variables.
name: String. Optional prefix name to use for the name of ops added in
- `Apply()`.
+ `apply()`.
"""
self._decay = decay
self._num_updates = num_updates
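A minimal sketch of the documented flow, under the corrected casing: `apply()` creates the shadow variables and returns the op that maintains the averages.

```python
import tensorflow as tf

var = tf.Variable(0.0)
ema = tf.train.ExponentialMovingAverage(decay=0.999)
maintain_op = ema.apply([var])  # must run after each training step
shadow = ema.average(var)       # reads the current moving average
```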
diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md
index fba6c3144a..be54f97b9e 100644
--- a/tensorflow/tools/docker/README.md
+++ b/tensorflow/tools/docker/README.md
@@ -40,10 +40,14 @@ accomplished via
$ export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | xargs -I{} echo '-v {}:{}')
$ export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
- $ docker run -it -p 8888:8888 $CUDA_SO $DEVICES gcr.io/tensorflow/tensorflow-devel-gpu
+ $ docker run -it -p 8888:8888 $CUDA_SO $DEVICES gcr.io/tensorflow/tensorflow-gpu
Alternately, you can use the `docker_run_gpu.sh` script in this directory.
+To require a password for the Jupyter Notebook, provide the `-e PASSWORD=pass` option:
+
+ $ docker run -it -p 8888:8888 $CUDA_SO $DEVICES -e PASSWORD=pass gcr.io/tensorflow/tensorflow-gpu
+
## Rebuilding the containers
Just pick the dockerfile corresponding to the container you want to build, and run;
diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py
index 9bf6f98ce3..bcc4234459 100644
--- a/tensorflow/tools/docker/jupyter_notebook_config.py
+++ b/tensorflow/tools/docker/jupyter_notebook_config.py
@@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+import os
+from IPython.lib import passwd
c.NotebookApp.ip = '*'
c.NotebookApp.port = 8888
c.NotebookApp.open_browser = False
c.MultiKernelManager.default_kernel_name = 'python2'
+
+# sets a password if PASSWORD is set in the environment
+if 'PASSWORD' in os.environ:
+ c.NotebookApp.password = passwd(os.environ['PASSWORD'])
+ del os.environ['PASSWORD']
diff --git a/third_party/gpus/cuda/platform.bzl b/third_party/gpus/cuda/platform.bzl
index 20ab441bf4..06f3d0cff4 100644
--- a/third_party/gpus/cuda/platform.bzl
+++ b/third_party/gpus/cuda/platform.bzl
@@ -1,7 +1,5 @@
CUDA_VERSION = ""
-
CUDNN_VERSION = ""
-
PLATFORM = ""
def cuda_sdk_version():