diff options
author | 2016-05-23 11:39:39 -0800 | |
---|---|---|
committer | 2016-05-23 12:42:36 -0700 | |
commit | 892ca4ddc12852a7b4633fd08f163941356cb4e6 (patch) | |
tree | be913f46bb9323685c5a807a89fca6dc52a25504 | |
parent | 76d90938f95a14a5723c253ec8529e93939a25e2 (diff) |
Merge changes from github.
Change: 123026122
136 files changed, 1402 insertions, 509 deletions
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index bfad5515e6..629072ed7e 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -22,8 +22,6 @@ #include <sys/wait.h> #include <unistd.h> -#include <string> -#include <tuple> #include <vector> #include "tensorflow/core/lib/io/path.h" @@ -40,6 +38,7 @@ namespace { const char kFfmpegExecutable[] = "ffmpeg"; const int32 kDefaultProbeSize = 5000000; // 5MB + std::vector<string> FfmpegCommandLine(const string& input_filename, const string& output_filename, const string& input_format_id, diff --git a/tensorflow/contrib/learn/python/learn/README.md b/tensorflow/contrib/learn/python/learn/README.md index f557999828..2ab165f284 100644 --- a/tensorflow/contrib/learn/python/learn/README.md +++ b/tensorflow/contrib/learn/python/learn/README.md @@ -105,7 +105,7 @@ iris = datasets.load_iris() def my_model(X, y): """This is DNN with 10, 20, 10 hidden layers, and dropout of 0.5 probability.""" - layers = learn.ops.dnn(X, [10, 20, 10], keep_prob=0.5) + layers = learn.ops.dnn(X, [10, 20, 10], dropout=0.5) return learn.models.logistic_regression(layers, y) classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3) diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 1ad6491a95..e714c15f2e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -16,7 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator +from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder +from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py index 491eba45a4..dcd1d81056 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py @@ -16,8 +16,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import os + import numpy as np +def _pprint(d): + return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()]) + class _BaseEstimator(object): """This is a cross-import when sklearn is not available. @@ -43,6 +49,9 @@ class _BaseEstimator(object): for key in param_names: value = getattr(self, key, None) + if isinstance(value, collections.Callable): + continue + # XXX: should we rather test if instance of estimator? if deep and hasattr(value, 'get_params'): deep_items = value.get_params().items() @@ -90,9 +99,7 @@ class _BaseEstimator(object): def __repr__(self): class_name = self.__class__.__name__ return '%s(%s)' % (class_name, - _pprint( - self.get_params(deep=False), - offset=len(class_name),),) + _pprint(self.get_params(deep=False)),) class _ClassifierMixin(): @@ -104,6 +111,8 @@ class _RegressorMixin(): """Mixin class for all regression estimators.""" pass +class _TransformerMixin(): + """Mixin class for all transformer estimators.""" class _NotFittedError(ValueError, AttributeError): """Exception class to raise if estimator is used before fitting. @@ -152,6 +161,7 @@ def _train_test_split(*args, **options): train_size = 1 - test_size train_size = train_size * args[0].shape[0] + np.random.seed(random_state) indices = np.random.permutation(args[0].shape[0]) train_idx, test_idx = indices[:train_size], indices[:train_size] result = [] @@ -159,30 +169,29 @@ def _train_test_split(*args, **options): result += [x.take(train_idx, axis=0), x.take(test_idx, axis=0)] return tuple(result) -# Try to import sklearn, if fail - use _BaseEstimator. -try: - from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin -except ImportError: + +# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn. +TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False) +if TRY_IMPORT_SKLEARN: + from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin + from sklearn.metrics import accuracy_score, log_loss, mean_squared_error + from sklearn.cross_validation import train_test_split + try: + from sklearn.exceptions import NotFittedError + except ImportError: + try: + from sklearn.utils.validation import NotFittedError + except ImportError: + NotFittedError = _NotFittedError +else: + # Naive implementations of sklearn classes and functions. BaseEstimator = _BaseEstimator ClassifierMixin = _ClassifierMixin RegressorMixin = _RegressorMixin - -# Try to import exception for not fitted error. -try: - from sklearn.exceptions import NotFittedError -except ImportError: + TransformerMixin = _TransformerMixin NotFittedError = _NotFittedError - -# Try to import metrics -try: - from sklearn.metrics import accuracy_score, log_loss, mean_squared_error -except ImportError: accuracy_score = _accuracy_score log_loss = None mean_squared_error = _mean_squared_error - -# Try to import train_test_split -try: - from sklearn.cross_validation import train_test_split -except ImportError: train_test_split = _train_test_split + diff --git a/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py b/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py new file mode 100644 index 0000000000..690bac8f19 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py @@ -0,0 +1,116 @@ +"""Deep Autoencoder estimators.""" +# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import nn +from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer +from tensorflow.contrib.learn.python.learn import models + + +class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer): + """TensorFlow Autoencoder Regressor model. + + Parameters: + hidden_units: List of hidden units per layer. + batch_size: Mini batch size. + activation: activation function used to map inner latent layer onto + reconstruction layer. + add_noise: a function that adds noise to tensor_in, + e.g. def add_noise(x): + return(x + np.random.normal(0, 0.1, (len(x), len(x[0])))) + steps: Number of steps to run over data. + optimizer: Optimizer name (or class), for example "SGD", "Adam", + "Adagrad". + learning_rate: If this is constant float value, no decay function is used. + Instead, a customized decay function can be passed that accepts + global_step as parameter and returns a Tensor. + e.g. exponential decay function: + def exp_decay(global_step): + return tf.train.exponential_decay( + learning_rate=0.1, global_step, + decay_steps=2, decay_rate=0.001) + continue_training: when continue_training is True, once initialized + model will be continuely trained on every call of fit. + config: RunConfig object that controls the configurations of the session, + e.g. num_cores, gpu_memory_fraction, etc. + verbose: Controls the verbosity, possible values: + 0: the algorithm and debug information is muted. + 1: trainer prints the progress. + 2: log device placement is printed. + dropout: When not None, the probability we will drop out a given + coordinate. + """ + def __init__(self, hidden_units, n_classes=0, batch_size=32, + steps=200, optimizer="Adagrad", learning_rate=0.1, + clip_gradients=5.0, activation=nn.relu, add_noise=None, + continue_training=False, config=None, + verbose=1, dropout=None): + self.hidden_units = hidden_units + self.dropout = dropout + self.activation = activation + self.add_noise = add_noise + super(TensorFlowDNNAutoencoder, self).__init__( + model_fn=self._model_fn, + n_classes=n_classes, + batch_size=batch_size, steps=steps, optimizer=optimizer, + learning_rate=learning_rate, clip_gradients=clip_gradients, + continue_training=continue_training, + config=config, verbose=verbose) + + def _model_fn(self, X, y): + encoder, decoder, autoencoder_estimator = models.get_autoencoder_model( + self.hidden_units, + models.linear_regression, + activation=self.activation, + add_noise=self.add_noise, + dropout=self.dropout)(X) + self.encoder = encoder + self.decoder = decoder + return autoencoder_estimator + + def generate(self, hidden=None): + """Generate new data using trained construction layer""" + if hidden is None: + last_layer = len(self.hidden_units) - 1 + bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer) + import numpy as np + hidden = np.random.normal(size=bias.shape) + hidden = np.reshape(hidden, (1, len(hidden))) + return self._session.run(self.decoder, feed_dict={self.encoder: hidden}) + + @property + def weights_(self): + """Returns weights of the autoencoder's weight layers.""" + weights = [] + for layer in range(len(self.hidden_units)): + weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer)) + for layer in range(len(self.hidden_units)): + weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer)) + weights.append(self.get_tensor_value('linear_regression/weights:0')) + return weights + + @property + def bias_(self): + """Returns bias of the autoencoder's bias layers.""" + biases = [] + for layer in range(len(self.hidden_units)): + biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer)) + for layer in range(len(self.hidden_units)): + biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer)) + biases.append(self.get_tensor_value('linear_regression/bias:0')) + return biases + diff --git a/tensorflow/contrib/learn/python/learn/estimators/base.py b/tensorflow/contrib/learn/python/learn/estimators/base.py index 9459bdbcc9..a94ee9e717 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/base.py +++ b/tensorflow/contrib/learn/python/learn/estimators/base.py @@ -360,7 +360,7 @@ class TensorFlowEstimator(estimator.Estimator): reconfigured. Returns: - Estiamator, object of the subclass of TensorFlowEstimator. + Estimator, object of the subclass of TensorFlowEstimator. """ model_def_filename = os.path.join(path, 'model.def') if not os.path.exists(model_def_filename): @@ -379,6 +379,7 @@ class TensorFlowEstimator(estimator.Estimator): new_value = locals()[key] if new_value is not None: model_def[key] = new_value + class_name = model_def.pop('class_name') if class_name == 'TensorFlowEstimator': custom_estimator = TensorFlowEstimator(model_fn=None, **model_def) @@ -392,3 +393,18 @@ class TensorFlowEstimator(estimator.Estimator): estimator = getattr(estimators, class_name)(**model_def) estimator._restore(path) return estimator + + +class TensorFlowBaseTransformer(TensorFlowEstimator, _sklearn.TransformerMixin): + """TensorFlow Base Transformer class.""" + def transform(self, X): + """Transform X using trained transformer.""" + return(super(TensorFlowBaseTransformer, self).predict(X, axis=1, batch_size=None)) + + def fit(self, X, y=None, monitor=None, logdir=None): + """Fit a transformer.""" + return(super(TensorFlowBaseTransformer, self).fit(X, y, monitors=None, logdir=None)) + + def fit_transform(self, X, y=None, monitor=None, logdir=None): + """Fit transformer and transform X using trained transformer.""" + return(self.fit(X, y, monitor=None, logdir=None).transform(X)) diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py index 866a860e0e..babb216d9d 100644 --- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py @@ -88,9 +88,8 @@ def setup_train_data_feeder( X, y = _data_type_filter(X, y) if HAS_DASK: import dask.dataframe as dd - allowed_classes = (dd.Series, dd.DataFrame) - if (isinstance(X, allowed_classes) and - (y is None or isinstance(y, allowed_classes))): + if (isinstance(X, (dd.Series, dd.DataFrame)) and + (y is None or isinstance(y, (dd.Series, dd.DataFrame)))): data_feeder_cls = DaskDataFeeder else: data_feeder_cls = DataFeeder @@ -156,7 +155,7 @@ def setup_processor_data_feeder(X): def check_array(array, dtype): - """Checks array on dtype and convers it if different. + """Checks array on dtype and converts it if different. Args: array: Input array. @@ -165,7 +164,10 @@ def check_array(array, dtype): Returns: Original array or converted. """ - array = np.array(array, dtype=dtype, order=None, copy=False) + # skip check if array is instance of other classes, e.g. h5py.Dataset + # to avoid copying array and loading whole data into memory + if isinstance(array, (np.ndarray, list)): + array = np.array(array, dtype=dtype, order=None, copy=False) return array @@ -459,10 +461,8 @@ class DaskDataFeeder(object): """Data feeder for TF trainer that reads data from dask.Series and dask.DataFrame. Numpy arrays can be serialized to disk and it's possible to do random seeks - into them. - DaskDataFeeder will remove requirement to have full dataset in the memory - and still do - random seeks for sampling of batches. + into them. DaskDataFeeder will remove requirement to have full dataset in the + memory and still do random seeks for sampling of batches. Parameters: X: iterator that returns for each element, returns features. @@ -483,10 +483,9 @@ class DaskDataFeeder(object): input_dtype: dtype of input. output_dtype: dtype of output. """ - def __init__(self, X, y, n_classes, batch_size, shuffle=True, random_state=None): import dask.dataframe as dd - # TODO: check X and y dtypes in dask_io like pandas + # TODO(terrytangyuan): check X and y dtypes in dask_io like pandas self.X = X self.y = y # save column names @@ -497,6 +496,8 @@ class DaskDataFeeder(object): # deal with cases where two DFs have overlapped default numeric colnames self.y_columns = len(self.X_columns) + 1 self.y = self.y.rename(columns={y.columns[0]: self.y_columns}) + + # TODO(terrytangyuan): deal with unsupervised cases # combine into a data frame self.df = dd.multi.concat([self.X, self.y], axis=1) self.n_classes = n_classes @@ -535,7 +536,6 @@ class DaskDataFeeder(object): A function that when called samples a random subset of batch size from X and y. """ - def _feed_dict_fn(): # TODO: option for with/without replacement (dev version of dask) sample = self.df.random_split( @@ -555,5 +555,4 @@ class DaskDataFeeder(object): encoded_out[np.arange(out.size), out] = 1 return {input_placeholder.name: inp, output_placeholder.name: encoded_out} - return _feed_dict_fn diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py index 2c83a95ab8..8cabd390fc 100644 --- a/tensorflow/contrib/learn/python/learn/models.py +++ b/tensorflow/contrib/learn/python/learn/models.py @@ -18,6 +18,7 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.ops import dnn_ops from tensorflow.contrib.learn.python.learn.ops import losses_ops +from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops as array_ops_ @@ -187,6 +188,36 @@ def get_dnn_model(hidden_units, target_predictor_fn, dropout=None): return dnn_estimator +def get_autoencoder_model(hidden_units, target_predictor_fn, + activation, add_noise=None, dropout=None): + """Returns a function that creates a Autoencoder TensorFlow subgraph with given + params. + + Args: + hidden_units: List of values of hidden units for layers. + target_predictor_fn: Function that will predict target from input + features. This can be logistic regression, + linear regression or any other model, + that takes X, y and returns predictions and loss tensors. + activation: activation function used to map inner latent layer onto + reconstruction layer. + add_noise: a function that adds noise to tensor_in, + e.g. def add_noise(x): + return(x + np.random.normal(0, 0.1, (len(x), len(x[0])))) + dropout: When not none, causes dropout regularization to be used, + with the specified probability of removing a given coordinate. + + Returns: + A function that creates the subgraph. + """ + def dnn_autoencoder_estimator(X): + """Autoencoder estimator with target predictor function on top.""" + encoder, decoder = autoencoder_ops.dnn_autoencoder( + X, hidden_units, activation, + add_noise=add_noise, dropout=dropout) + return encoder, decoder, target_predictor_fn(X, decoder) + return dnn_autoencoder_estimator + ## This will be in Tensorflow 0.7. ## TODO(ilblackdragon): Clean this up when it's released diff --git a/tensorflow/contrib/learn/python/learn/ops/__init__.py b/tensorflow/contrib/learn/python/learn/ops/__init__.py index b8d00ac27c..0bda7110aa 100644 --- a/tensorflow/contrib/learn/python/learn/ops/__init__.py +++ b/tensorflow/contrib/learn/python/learn/ops/__init__.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.ops.array_ops import * from tensorflow.contrib.learn.python.learn.ops.conv_ops import * from tensorflow.contrib.learn.python.learn.ops.dnn_ops import * +from tensorflow.contrib.learn.python.learn.ops.autoencoder_ops import * from tensorflow.contrib.learn.python.learn.ops.dropout_ops import * from tensorflow.contrib.learn.python.learn.ops.embeddings_ops import * from tensorflow.contrib.learn.python.learn.ops.losses_ops import * diff --git a/tensorflow/contrib/learn/python/learn/ops/autoencoder_ops.py b/tensorflow/contrib/learn/python/learn/ops/autoencoder_ops.py new file mode 100644 index 0000000000..e337b3b124 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/ops/autoencoder_ops.py @@ -0,0 +1,56 @@ +"""TensorFlow ops for autoencoder.""" +# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope as vs +from tensorflow.contrib.learn.python.learn.ops import dnn_ops + + +def dnn_autoencoder(tensor_in, hidden_units, + activation=nn.relu, add_noise=None, + dropout=None, scope=None): + """Creates fully connected autoencoder subgraph. + + Args: + tensor_in: tensor or placeholder for input features. + hidden_units: list of counts of hidden units in each layer. + activation: activation function used to map inner latent layer onto + reconstruction layer. + add_noise: a function that adds noise to tensor_in, + e.g. def add_noise(x): + return(x + np.random.normal(0, 0.1, (len(x), len(x[0])))) + dropout: if not None, will add a dropout layer with given + probability. + scope: the variable scope for this op. + + Returns: + Tensors for encoder and decoder. + """ + with vs.variable_op_scope([tensor_in], scope, "autoencoder"): + if add_noise is not None: + tensor_in = add_noise(tensor_in) + with vs.variable_scope('encoder'): + # build DNN encoder + encoder = dnn_ops.dnn(tensor_in, hidden_units, + activation=activation, dropout=dropout) + with vs.variable_scope('decoder'): + # reverse hidden_units and built DNN decoder + decoder = dnn_ops.dnn(encoder, hidden_units[::-1], + activation=activation, dropout=dropout) + return encoder, decoder + diff --git a/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py b/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py index 3963f5eb4e..cb459e1bb4 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_custom_decay.py @@ -45,12 +45,12 @@ class CustomDecayTest(tf.test.TestCase): classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, - steps=800, + steps=500, learning_rate=exp_decay) classifier.fit(X_train, y_train) score = accuracy_score(y_test, classifier.predict(X_test)) - self.assertGreater(score, 0.7, "Failed with score = {0}".format(score)) + self.assertGreater(score, 0.65, "Failed with score = {0}".format(score)) if __name__ == "__main__": diff --git a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py index ea361a0d68..77f9f2d6f2 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py @@ -126,6 +126,26 @@ class DataFeederTest(tf.test.TestCase): [0.60000002, 0.2]]) self.assertAllClose(feed_dict[out.name], [[0., 0., 1.], [0., 1., 0.]]) + def test_hdf5_data_feeder(self): + try: + import h5py + X = np.matrix([[1, 2], [3, 4]]) + y = np.array([1, 2]) + h5f = h5py.File('test_hdf5.h5', 'w') + h5f.create_dataset('X', data=X) + h5f.create_dataset('y', data=y) + h5f.close() + h5f = h5py.File('test_hdf5.h5', 'r') + X = h5f['X'] + y = h5f['y'] + df = data_feeder.DataFeeder(X, y, n_classes=0, batch_size=3) + inp, out = df.input_builder() + feed_dict_fn = df.get_feed_dict_fn() + feed_dict = feed_dict_fn() + self.assertAllClose(feed_dict[inp.name], [[3, 4], [1, 2]]) + self.assertAllClose(feed_dict[out.name], [2, 1]) + except ImportError: + print("Skipped test for hdf5 since it's not installed.") class SetupPredictDataFeederTest(tf.test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/tests/test_estimators.py b/tensorflow/contrib/learn/python/learn/tests/test_estimators.py index 0c93c5aa50..069a146593 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_estimators.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_estimators.py @@ -35,7 +35,6 @@ class CustomOptimizer(tf.test.TestCase): iris.target, test_size=0.2, random_state=42) - # setup exponential decay function def exp_decay(global_step): return tf.train.exponential_decay(learning_rate=0.1, @@ -48,13 +47,13 @@ class CustomOptimizer(tf.test.TestCase): classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, - steps=800, + steps=400, learning_rate=exp_decay, optimizer=custom_optimizer) classifier.fit(X_train, y_train) score = accuracy_score(y_test, classifier.predict(X_test)) - self.assertGreater(score, 0.7, "Failed with score = {0}".format(score)) + self.assertGreater(score, 0.65, "Failed with score = {0}".format(score)) if __name__ == "__main__": diff --git a/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py b/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py index ad744eb2be..1008c53c99 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_grid_search.py @@ -15,15 +15,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import random -try: - from sklearn import datasets - from sklearn.grid_search import GridSearchCV - from sklearn.metrics import accuracy_score, mean_squared_error - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False +HAS_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False) +if HAS_SKLEARN: + try: + from sklearn import datasets + from sklearn.grid_search import GridSearchCV + from sklearn.metrics import accuracy_score, mean_squared_error + except ImportError: + HAS_SKLEARN = False import tensorflow as tf diff --git a/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py b/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py index 5d50d2bbbb..7b8a27eebc 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py @@ -167,6 +167,16 @@ class NonLinearTest(tf.test.TestCase): predictions = classifier.predict(test_data) self.assertAllClose(predictions, np.array([1, 0])) + # def testDNNAutoencoder(self): + # import numpy as np + # iris = datasets.load_iris() + # autoencoder = learn.TensorFlowDNNAutoencoder(hidden_units=[10, 20]) + # transformed = autoencoder.fit_transform(iris.data[1:2]) + # expected = np.array([[ -3.57627869e-07, 1.17000043e+00, 1.01902664e+00, 1.19209290e-07, + # 0.00000000e+00, 1.19209290e-07, -5.96046448e-08, -2.38418579e-07, + # 9.74681854e-01, 1.19209290e-07]]) + # self.assertAllClose(transformed, expected) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/learn/python/learn/tests/test_saver.py b/tensorflow/contrib/learn/python/learn/tests/test_saver.py index 28e3197b4c..8592e9c0bc 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_saver.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_saver.py @@ -83,6 +83,5 @@ class SaverTest(tf.test.TestCase): with self.assertRaises(NotImplementedError): learn.TensorFlowEstimator.restore(path) - if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py b/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py index fc71d5c690..1ab27aaa90 100644 --- a/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py +++ b/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py @@ -23,15 +23,15 @@ import tensorflow as tf class ConfusionMatrixTest(tf.test.TestCase): - def _testConfMatrix(self, predictions, targets, truth): + def _testConfMatrix(self, predictions, labels, truth): with self.test_session(): - ans = tf.contrib.metrics.confusion_matrix(predictions, targets) + ans = tf.contrib.metrics.confusion_matrix(predictions, labels) tf_ans = ans.eval() self.assertAllClose(tf_ans, truth, atol=1e-10) def _testBasic(self, dtype): predictions = np.arange(5, dtype=dtype) - targets = np.arange(5, dtype=dtype) + labels = np.arange(5, dtype=dtype) truth = np.asarray( [[1, 0, 0, 0, 0], @@ -43,7 +43,7 @@ class ConfusionMatrixTest(tf.test.TestCase): self._testConfMatrix( predictions=predictions, - targets=targets, + labels=labels, truth=truth) def testInt32Basic(self, dtype=np.int32): @@ -54,7 +54,7 @@ class ConfusionMatrixTest(tf.test.TestCase): def _testDiffentLabelsInPredictionAndTarget(self, dtype): predictions = np.asarray([1, 2, 3], dtype=dtype) - targets = np.asarray([4, 5, 6], dtype=dtype) + labels = np.asarray([4, 5, 6], dtype=dtype) truth = np.asarray( [[0, 0, 0, 0, 0, 0, 0], @@ -68,7 +68,7 @@ class ConfusionMatrixTest(tf.test.TestCase): self._testConfMatrix( predictions=predictions, - targets=targets, + labels=labels, truth=truth) def testInt32DifferentLabels(self, dtype=np.int32): @@ -79,7 +79,7 @@ class ConfusionMatrixTest(tf.test.TestCase): def _testMultipleLabels(self, dtype): predictions = np.asarray([1, 1, 2, 3, 5, 6, 1, 2, 3, 4], dtype=dtype) - targets = np.asarray([1, 1, 2, 3, 5, 1, 3, 6, 3, 1], dtype=dtype) + labels = np.asarray([1, 1, 2, 3, 5, 1, 3, 6, 3, 1], dtype=dtype) truth = np.asarray( [[0, 0, 0, 0, 0, 0, 0], @@ -93,7 +93,7 @@ class ConfusionMatrixTest(tf.test.TestCase): self._testConfMatrix( predictions=predictions, - targets=targets, + labels=labels, truth=truth) def testInt32MultipleLabels(self, dtype=np.int32): @@ -104,24 +104,24 @@ class ConfusionMatrixTest(tf.test.TestCase): def testInvalidRank(self): predictions = np.asarray([[1, 2, 3]]) - targets = np.asarray([1, 2, 3]) + labels = np.asarray([1, 2, 3]) self.assertRaisesRegexp( ValueError, "are not compatible", - tf.contrib.metrics.confusion_matrix, predictions, targets) + tf.contrib.metrics.confusion_matrix, predictions, labels) predictions = np.asarray([1, 2, 3]) - targets = np.asarray([[1, 2, 3]]) + labels = np.asarray([[1, 2, 3]]) self.assertRaisesRegexp( ValueError, "are not compatible", - tf.contrib.metrics.confusion_matrix, predictions, targets) + tf.contrib.metrics.confusion_matrix, predictions, labels) def testInputDifferentSize(self): predictions = np.asarray([1, 2, 3]) - targets = np.asarray([1, 2]) + labels = np.asarray([1, 2]) self.assertRaisesRegexp( ValueError, "are not compatible", - tf.contrib.metrics.confusion_matrix, predictions, targets) + tf.contrib.metrics.confusion_matrix, predictions, labels) if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py index f8d28788ca..0f9baf2ee7 100644 --- a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py +++ b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py @@ -25,14 +25,17 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops -def confusion_matrix(predictions, targets, num_classes=None, name=None): - """Computes the confusion matrix from predictions and targets +"""Confusion matrix related metrics.""" + + +def confusion_matrix(predictions, labels, num_classes=None, name=None): + """Computes the confusion matrix from predictions and labels Calculate the Confusion Matrix for a pair of prediction and - target 1-D int arrays. + label 1-D int arrays. Considering a prediction array such as: `[1, 2, 3]` - And a target array such as: `[2, 2, 3]` + And a label array such as: `[2, 2, 3]` The confusion matrix returned would be the following one: [[0, 0, 0] @@ -41,18 +44,18 @@ def confusion_matrix(predictions, targets, num_classes=None, name=None): [0, 0, 1]] Where the matrix rows represent the prediction labels and the columns - represents the target labels. The confusion matrix is always a 2-D array + represents the real labels. The confusion matrix is always a 2-D array of shape [n, n], where n is the number of valid labels for a given - classification task. Both prediction and target must be 1-D arrays of + classification task. Both prediction and labels must be 1-D arrays of the same shape in order for this function to work. Args: predictions: A 1-D array represeting the predictions for a given classification. - targets: A 1-D represeting the real labels for the classification task. + labels: A 1-D represeting the real labels for the classification task. num_classes: The possible number of labels the classification task can have. If this value is not provided, it will be calculated - using both predictions and targets array. + using both predictions and labels array. name: Scope name. Returns: @@ -60,22 +63,21 @@ def confusion_matrix(predictions, targets, num_classes=None, name=None): possible labels in the classification task. Raises: - ValueError: If both predictions and targets are not 1-D vectors and do not + ValueError: If both predictions and labels are not 1-D vectors and do not have the same size. """ - with ops.op_scope([predictions, targets, num_classes], name, + with ops.op_scope([predictions, labels, num_classes], name, 'confusion_matrix') as name: predictions = ops.convert_to_tensor( predictions, name='predictions', dtype=dtypes.int64) - targets = ops.convert_to_tensor(targets, name='targets', dtype=dtypes.int64) + labels = ops.convert_to_tensor(labels, name='labels', dtype=dtypes.int64) if num_classes is None: num_classes = math_ops.maximum(math_ops.reduce_max(predictions), - math_ops.reduce_max(targets)) + 1 + math_ops.reduce_max(labels)) + 1 shape = array_ops.pack([num_classes, num_classes]) - indices = array_ops.transpose( - array_ops.pack([predictions, targets])) + indices = array_ops.transpose(array_ops.pack([predictions, labels])) values = array_ops.ones_like(predictions, dtype=dtypes.int32) cm_sparse = ops.SparseTensor( indices=indices, values=values, shape=shape) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c9754f518c..1be21fa954 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1291,6 +1291,7 @@ tf_cc_tests_gpu( exclude = [ # Run by tests below "common_runtime/gpu/gpu_allocator_retry_test.cc", + "common_runtime/gpu/gpu_debug_allocator_test.cc", "common_runtime/gpu/gpu_stream_util_test.cc", "common_runtime/gpu/gpu_tracer_test.cc", ], @@ -1482,6 +1483,30 @@ tf_cc_test_gpu( ) tf_cc_test_gpu( + name = "common_runtime/gpu/gpu_debug_allocator_test.cc", + size = "medium", + args = ["\"--gtest_death_test_style=threadsafe\""], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_cc_test_gpu( name = "common_runtime/gpu/gpu_stream_util_test.cc", size = "small", linkstatic = tf_kernel_tests_linkstatic(), diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 35a16e8bfe..93ed0d8c32 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -621,16 +621,17 @@ LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice( int64 total_memory, available_memory; CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory)); - int64 allocated_memory = available_memory; + int64 allocated_memory; double config_memory_fraction = options.config.gpu_options().per_process_gpu_memory_fraction(); if (config_memory_fraction == 0) { + allocated_memory = available_memory; const int64 min_system_memory = MinSystemMemory(available_memory); if (min_system_memory < allocated_memory) { allocated_memory -= min_system_memory; } } else { - allocated_memory *= config_memory_fraction; + allocated_memory = total_memory * config_memory_fraction; } Bytes allocated_bytes = static_cast<Bytes>(allocated_memory); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 4e4b6d3703..3267a200f0 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -264,7 +264,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { // partitions_ is immutable after RegisterPartitions() call // finishes. RunPartitions() can access partitions_ safely without - // acquring locks. + // acquiring locks. std::vector<Part> partitions_; mutable mutex mu_; diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index 0057db6967..e4e5610448 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS +#include <algorithm> #include <vector> #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -839,13 +840,13 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->allocate_output(0, input_shape, &in_backprop)); const int padding_rows = - (padding_ == VALID) - ? 0 - : (output_rows - 1) * stride_rows + filter_rows - input_rows; + (padding_ == VALID) ? 0 + : std::max<int>(0, (output_rows - 1) * stride_rows + + filter_rows - input_rows); const int padding_cols = - (padding_ == VALID) - ? 0 - : (output_cols - 1) * stride_cols + filter_cols - input_cols; + (padding_ == VALID) ? 0 + : std::max<int>(0, (output_cols - 1) * stride_cols + + filter_cols - input_cols); // TODO(keveman): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -889,6 +890,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, ", n=", n, ", k=", k)); } + return; } @@ -1058,6 +1060,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel { "cuDNN Backward Data function launch failure : input shape(", input_shape.DebugString(), ") filter shape(", filter_shape.DebugString(), ")")); + return; } if (rows_odd || cols_odd) { @@ -1148,13 +1151,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { context->allocate_output(0, filter_shape, &filter_backprop)); const int padding_rows = - (padding_ == VALID) - ? 0 - : (output_rows - 1) * stride_rows + filter_rows - input_rows; + (padding_ == VALID) ? 0 + : std::max<int>(0, (output_rows - 1) * stride_rows + + filter_rows - input_rows); const int padding_cols = - (padding_ == VALID) - ? 0 - : (output_cols - 1) * stride_cols + filter_cols - input_cols; + (padding_ == VALID) ? 0 + : std::max<int>(0, (output_cols - 1) * stride_cols + + filter_cols - input_cols); // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -1387,6 +1390,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { "cuDNN Backward Filter function launch failure : input shape(", input_shape.DebugString(), ") filter shape(", filter_shape.DebugString(), ")")); + return; } auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index de8d8e784d..4eb67268c2 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -312,6 +312,7 @@ struct LaunchConvOp<GPUDevice, T> { ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, ", n=", n, ", k=", k)); } + return; } int padding_rows = 0; diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.h b/tensorflow/core/kernels/cudnn_pooling_gpu.h index d1982299e2..59ae794737 100644 --- a/tensorflow/core/kernels/cudnn_pooling_gpu.h +++ b/tensorflow/core/kernels/cudnn_pooling_gpu.h @@ -18,6 +18,8 @@ limitations under the License. #ifndef TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_ #define TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_ +#include <array> + #include "tensorflow/core/framework/op_kernel.h" #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc index ff9db97fe9..e745143339 100644 --- a/tensorflow/core/kernels/cwise_op_abs.cc +++ b/tensorflow/core/kernels/cwise_op_abs.cc @@ -19,8 +19,7 @@ namespace tensorflow { REGISTER5(UnaryOp, CPU, "Abs", functor::abs, float, Eigen::half, double, int32, int64); #if !defined(__ANDROID__) -REGISTER_KERNEL_BUILDER(Name("ComplexAbs").Device(DEVICE_CPU), - UnaryOp<CPUDevice, functor::abs<complex64>>); +REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128); #endif #if GOOGLE_CUDA REGISTER4(UnaryOp, GPU, "Abs", functor::abs, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_add.cc b/tensorflow/core/kernels/cwise_op_add.cc index 1457b74e6f..4aa3761ffe 100644 --- a/tensorflow/core/kernels/cwise_op_add.cc +++ b/tensorflow/core/kernels/cwise_op_add.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER9(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32, - int64, int8, int16, complex64, string); +REGISTER10(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32, + int64, int8, int16, complex64, complex128, string); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_complex.cc b/tensorflow/core/kernels/cwise_op_complex.cc index d7e4638ff1..f8796f72d8 100644 --- a/tensorflow/core/kernels/cwise_op_complex.cc +++ b/tensorflow/core/kernels/cwise_op_complex.cc @@ -16,10 +16,20 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_CPU), - BinaryOp<CPUDevice, functor::make_complex<float>>); +#define REGISTER_COMPLEX(D, R, C) \ + REGISTER_KERNEL_BUILDER(Name("Complex") \ + .Device(DEVICE_##D) \ + .TypeConstraint<R>("T") \ + .TypeConstraint<C>("Tout"), \ + BinaryOp<D##Device, functor::make_complex<R>>); + +REGISTER_COMPLEX(CPU, float, complex64); +REGISTER_COMPLEX(CPU, double, complex128); + #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_GPU), - BinaryOp<GPUDevice, functor::make_complex<float>>); +REGISTER_COMPLEX(GPU, float, complex64); +REGISTER_COMPLEX(GPU, double, complex128); #endif + +#undef REGISTER_COMPLEX } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_conj.cc b/tensorflow/core/kernels/cwise_op_conj.cc index 37ab71dd65..e4cb6aabe3 100644 --- a/tensorflow/core/kernels/cwise_op_conj.cc +++ b/tensorflow/core/kernels/cwise_op_conj.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_CPU), - UnaryOp<CPUDevice, functor::conj<complex64>>); + +REGISTER2(UnaryOp, CPU, "Conj", functor::conj, complex64, complex128); #if GOOGLE_CUDA // REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_GPU), // UnaryOp<GPUDevice, functor::conj<complex64>>); diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc index 6958fa22b8..cd7dd976db 100644 --- a/tensorflow/core/kernels/cwise_op_cos.cc +++ b/tensorflow/core/kernels/cwise_op_cos.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Cos", functor::cos, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Cos", functor::cos, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Cos", functor::cos, float, Eigen::half, double); #endif diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc index c4899473e6..617d663e8d 100644 --- a/tensorflow/core/kernels/cwise_op_div.cc +++ b/tensorflow/core/kernels/cwise_op_div.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(BinaryOp, CPU, "Div", functor::div, float, Eigen::half, double, - complex64); +REGISTER5(BinaryOp, CPU, "Div", functor::div, float, Eigen::half, double, + complex64, complex128); REGISTER4(BinaryOp, CPU, "Div", functor::safe_div, uint8, int16, int32, int64); #if GOOGLE_CUDA REGISTER6(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8, diff --git a/tensorflow/core/kernels/cwise_op_equal_to.cc b/tensorflow/core/kernels/cwise_op_equal_to.cc index c1d5b2f4ed..d9cd413534 100644 --- a/tensorflow/core/kernels/cwise_op_equal_to.cc +++ b/tensorflow/core/kernels/cwise_op_equal_to.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER11(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, - double, uint8, int8, int16, int32, int64, complex64, string, bool); +REGISTER12(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, + double, uint8, int8, int16, int32, int64, complex64, complex128, + string, bool); #if GOOGLE_CUDA REGISTER8(BinaryOp, GPU, "Equal", functor::equal_to, float, Eigen::half, double, uint8, int8, int16, int64, bool); diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc index 2d7df89149..4191018174 100644 --- a/tensorflow/core/kernels/cwise_op_exp.cc +++ b/tensorflow/core/kernels/cwise_op_exp.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Exp", functor::exp, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Exp", functor::exp, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Exp", functor::exp, float, Eigen::half, double); #endif diff --git a/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc index cb6fafaad3..c7433c2de9 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY1(make_complex, float); +DEFINE_BINARY2(make_complex, float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc index c183425254..0d2c299e75 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc @@ -19,8 +19,8 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY9(equal_to, float, Eigen::half, double, uint8, int8, int16, int64, - complex64, bool); +DEFINE_BINARY10(equal_to, float, Eigen::half, double, uint8, int8, int16, int64, + complex64, complex128, bool); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc index 58cdccfb30..f43aee786c 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_UNARY1(get_imag, complex64); +DEFINE_UNARY2(get_imag, complex64, complex128); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc index 171c1f9e40..ecd91cfd13 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc @@ -19,8 +19,8 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY9(not_equal_to, float, Eigen::half, double, uint8, int8, int16, - int64, complex64, bool); +DEFINE_BINARY10(not_equal_to, float, Eigen::half, double, uint8, int8, int16, + int64, complex64, complex128, bool); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc index 6ad3505422..f8cbb33703 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_UNARY1(get_real, complex64); +DEFINE_UNARY2(get_real, complex64, complex128); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc index f29d07e2db..6867a76c98 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc @@ -77,6 +77,7 @@ SELECT_FUNCTOR(double); SELECT_FUNCTOR(int32); SELECT_FUNCTOR(int64); SELECT_FUNCTOR(complex64); +SELECT_FUNCTOR(complex128); #undef SELECT_FUNCTOR diff --git a/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc new file mode 100644 index 0000000000..3bbc4ec39a --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc @@ -0,0 +1,27 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY2(zeta, float, double); +DEFINE_BINARY2(polygamma, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_imag.cc b/tensorflow/core/kernels/cwise_op_imag.cc index 7366f7342f..18bf8d83fe 100644 --- a/tensorflow/core/kernels/cwise_op_imag.cc +++ b/tensorflow/core/kernels/cwise_op_imag.cc @@ -16,10 +16,20 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_CPU), - UnaryOp<CPUDevice, functor::get_imag<complex64>>); +#define REGISTER_COMPLEX(D, R, C) \ + REGISTER_KERNEL_BUILDER(Name("Imag") \ + .Device(DEVICE_##D) \ + .TypeConstraint<C>("T") \ + .TypeConstraint<R>("Tout"), \ + UnaryOp<D##Device, functor::get_imag<C>>); + +REGISTER_COMPLEX(CPU, float, complex64); +REGISTER_COMPLEX(CPU, double, complex128); + #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_GPU), - UnaryOp<GPUDevice, functor::get_imag<complex64>>); +REGISTER_COMPLEX(GPU, float, complex64); +REGISTER_COMPLEX(GPU, double, complex128); #endif + +#undef REGISTER_COMPLEX } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_inverse.cc b/tensorflow/core/kernels/cwise_op_inverse.cc index 05834996c0..168ab268b4 100644 --- a/tensorflow/core/kernels/cwise_op_inverse.cc +++ b/tensorflow/core/kernels/cwise_op_inverse.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Inv", functor::inverse, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Inv", functor::inverse, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER4(UnaryOp, GPU, "Inv", functor::inverse, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc index ab6a1f9778..587b4f9eca 100644 --- a/tensorflow/core/kernels/cwise_op_log.cc +++ b/tensorflow/core/kernels/cwise_op_log.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double); #endif diff --git a/tensorflow/core/kernels/cwise_op_mul.cc b/tensorflow/core/kernels/cwise_op_mul.cc index 395cea5d7f..e1783950a7 100644 --- a/tensorflow/core/kernels/cwise_op_mul.cc +++ b/tensorflow/core/kernels/cwise_op_mul.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER9(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, uint8, - int8, int16, int32, int64, complex64); +REGISTER10(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, uint8, + int8, int16, int32, int64, complex64, complex128); #if GOOGLE_CUDA REGISTER7(BinaryOp, GPU, "Mul", functor::mul, float, Eigen::half, double, uint8, int8, int16, int64); diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg.cc index 2b672285b3..7c580c57e1 100644 --- a/tensorflow/core/kernels/cwise_op_neg.cc +++ b/tensorflow/core/kernels/cwise_op_neg.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32, - complex64, int64); +REGISTER7(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32, + complex64, int64, complex128); #if GOOGLE_CUDA REGISTER4(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to.cc b/tensorflow/core/kernels/cwise_op_not_equal_to.cc index da4a40ad9d..a4fbd8a280 100644 --- a/tensorflow/core/kernels/cwise_op_not_equal_to.cc +++ b/tensorflow/core/kernels/cwise_op_not_equal_to.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER11(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half, - double, uint8, int8, int16, int32, int64, complex64, string, bool); +REGISTER12(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half, + double, uint8, int8, int16, int32, int64, complex64, complex128, + string, bool); #if GOOGLE_CUDA REGISTER8(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8, int8, int16, int64, bool); diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc index 8bb71c03d8..3c136d7e2c 100644 --- a/tensorflow/core/kernels/cwise_op_pow.cc +++ b/tensorflow/core/kernels/cwise_op_pow.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32, - int64, complex64); +REGISTER7(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32, + int64, complex64, complex128); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_real.cc b/tensorflow/core/kernels/cwise_op_real.cc index f8619b6c69..214fb96291 100644 --- a/tensorflow/core/kernels/cwise_op_real.cc +++ b/tensorflow/core/kernels/cwise_op_real.cc @@ -16,10 +16,21 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_CPU), - UnaryOp<CPUDevice, functor::get_real<complex64>>); + +#define REGISTER_COMPLEX(D, R, C) \ + REGISTER_KERNEL_BUILDER(Name("Real") \ + .Device(DEVICE_##D) \ + .TypeConstraint<C>("T") \ + .TypeConstraint<R>("Tout"), \ + UnaryOp<D##Device, functor::get_real<C>>); + +REGISTER_COMPLEX(CPU, float, complex64); +REGISTER_COMPLEX(CPU, double, complex128); + #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_GPU), - UnaryOp<GPUDevice, functor::get_real<complex64>>); +REGISTER_COMPLEX(GPU, float, complex64); +REGISTER_COMPLEX(GPU, double, complex128); #endif + +#undef REGISTER_COMPLEX } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc index ff3e8d778f..ff6ae3bd2f 100644 --- a/tensorflow/core/kernels/cwise_op_rsqrt.cc +++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double); #endif diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index f9df4ad94e..4308bfd253 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -133,6 +133,7 @@ REGISTER_SELECT_GPU(double); REGISTER_SELECT_GPU(int32); REGISTER_SELECT_GPU(int64); REGISTER_SELECT_GPU(complex64); +REGISTER_SELECT_GPU(complex128); #undef REGISTER_SELECT_GPU diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc index 574866b3f0..3b51a293c7 100644 --- a/tensorflow/core/kernels/cwise_op_sigmoid.cc +++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double); diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc index 7107970332..1d17a4a066 100644 --- a/tensorflow/core/kernels/cwise_op_sign.cc +++ b/tensorflow/core/kernels/cwise_op_sign.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64, - complex64, Eigen::half); +REGISTER7(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64, + complex64, Eigen::half, complex128); #if GOOGLE_CUDA REGISTER4(UnaryOp, GPU, "Sign", functor::sign, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc index 123a251c62..85ceab95d0 100644 --- a/tensorflow/core/kernels/cwise_op_sin.cc +++ b/tensorflow/core/kernels/cwise_op_sin.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Sin", functor::sin, float, Eigen::half, double); #endif diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc index daaa4c55c6..652e12851b 100644 --- a/tensorflow/core/kernels/cwise_op_sqrt.cc +++ b/tensorflow/core/kernels/cwise_op_sqrt.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double); #endif diff --git a/tensorflow/core/kernels/cwise_op_square.cc b/tensorflow/core/kernels/cwise_op_square.cc index 80c8423cf9..f03d2f11d4 100644 --- a/tensorflow/core/kernels/cwise_op_square.cc +++ b/tensorflow/core/kernels/cwise_op_square.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(UnaryOp, CPU, "Square", functor::square, float, Eigen::half, double, - int32, complex64, int64); +REGISTER7(UnaryOp, CPU, "Square", functor::square, float, Eigen::half, double, + int32, complex64, complex128, int64); #if GOOGLE_CUDA REGISTER4(UnaryOp, GPU, "Square", functor::square, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc index 245f13ad68..ac744aeb2b 100644 --- a/tensorflow/core/kernels/cwise_op_sub.cc +++ b/tensorflow/core/kernels/cwise_op_sub.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32, - int64, complex64); +REGISTER7(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32, + int64, complex64, complex128); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc index a42b6fe6e4..edb89f799c 100644 --- a/tensorflow/core/kernels/cwise_op_tanh.cc +++ b/tensorflow/core/kernels/cwise_op_tanh.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double, - complex64); +REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double, + complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double); #endif diff --git a/tensorflow/core/kernels/cwise_op_zeta.cc b/tensorflow/core/kernels/cwise_op_zeta.cc new file mode 100644 index 0000000000..a960475396 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_zeta.cc @@ -0,0 +1,21 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(BinaryOp, CPU, "Zeta", functor::zeta, float, double); +REGISTER2(BinaryOp, CPU, "Polygamma", functor::polygamma, float, double); +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 71d65f806b..7d45fb8511 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -480,6 +480,12 @@ template <typename T> struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; template <typename T> +struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {}; + +template <typename T> +struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {}; + +template <typename T> struct squared_difference : base<T, Eigen::internal::scalar_compose_op< T, Eigen::internal::scalar_square_op<T>, diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index f86f206cb1..05fb4afb63 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -386,6 +386,9 @@ struct UnaryFunctor<CPUDevice, Functor> { REGISTER(OP, D, N, F, T0) #define REGISTER11(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) \ REGISTER(OP, D, N, F, T0) +#define REGISTER12(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, \ + T11) \ + REGISTER(OP, D, N, F, T0) #else // !defined(__ANDROID_TYPES_SLIM__) #define REGISTER2(OP, D, N, F, T0, T1) \ REGISTER(OP, D, N, F, T0) \ @@ -417,6 +420,10 @@ struct UnaryFunctor<CPUDevice, Functor> { #define REGISTER11(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) \ REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \ REGISTER6(OP, D, N, F, T5, T6, T7, T8, T9, T10) +#define REGISTER12(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, \ + T11) \ + REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \ + REGISTER6(OP, D, N, F, T6, T7, T8, T9, T10, T11) #endif // defined(__ANDROID_TYPES_SLIM__) } // end namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h index d91d0faa86..6b23fb5785 100644 --- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h +++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h @@ -34,6 +34,7 @@ namespace functor { typedef Eigen::GpuDevice GPUDevice; typedef std::complex<float> complex64; +typedef std::complex<double> complex128; // Partial specialization of UnaryFunctor<Device=GPUDevice, Functor>. template <typename Functor> @@ -149,6 +150,9 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS, has_errors> { #define DEFINE_BINARY9(F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \ DEFINE_BINARY4(F, T0, T1, T2, T3); \ DEFINE_BINARY5(F, T4, T5, T6, T7, T8) +#define DEFINE_BINARY10(F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) \ + DEFINE_BINARY5(F, T0, T1, T2, T3, T4); \ + DEFINE_BINARY5(F, T5, T6, T7, T8, T9) } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/diag_op.cc b/tensorflow/core/kernels/diag_op.cc index df738f00cb..252fc338b9 100644 --- a/tensorflow/core/kernels/diag_op.cc +++ b/tensorflow/core/kernels/diag_op.cc @@ -122,6 +122,7 @@ REGISTER_DIAGOP(double); REGISTER_DIAGOP(float); REGISTER_DIAGOP(int32); REGISTER_DIAGOP(int64); +REGISTER_DIAGOP(complex64); #undef REGISTER_DIAGOP @@ -188,6 +189,7 @@ REGISTER_DIAGPARTOP(double); REGISTER_DIAGPARTOP(float); REGISTER_DIAGPARTOP(int32); REGISTER_DIAGPARTOP(int64); +REGISTER_DIAGPARTOP(complex64); #undef REGISTER_DIAGPARTOP diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc index b1953adeae..b4d14e8c62 100644 --- a/tensorflow/core/kernels/edit_distance_op.cc +++ b/tensorflow/core/kernels/edit_distance_op.cc @@ -179,7 +179,7 @@ class EditDistanceOp : public OpKernel { if (g_truth == g_hypothesis) { auto loc = std::inner_product(g_truth.begin(), g_truth.end(), - output_strides.begin(), 0); + output_strides.begin(), int64{0}); output_t(loc) = gtl::LevenshteinDistance<T>(truth_seq, hypothesis_seq, cmp); if (normalize_) output_t(loc) /= truth_seq.size(); @@ -188,13 +188,13 @@ class EditDistanceOp : public OpKernel { ++truth_iter; } else if (g_truth > g_hypothesis) { // missing truth @ this hypothesis auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), - output_strides.begin(), 0); + output_strides.begin(), int64{0}); output_t(loc) = hypothesis_seq.size(); if (normalize_) output_t(loc) /= 0.0; ++hypothesis_iter; } else { // missing hypothesis @ this truth auto loc = std::inner_product(g_truth.begin(), g_truth.end(), - output_strides.begin(), 0); + output_strides.begin(), int64{0}); output_t(loc) = (normalize_) ? 1.0 : truth_seq.size(); ++truth_iter; } @@ -204,7 +204,7 @@ class EditDistanceOp : public OpKernel { std::vector<int64> g_hypothesis = hypothesis_j.group(); auto hypothesis_seq = hypothesis_j.values<T>(); auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), - output_strides.begin(), 0); + output_strides.begin(), int64{0}); output_t(loc) = hypothesis_seq.size(); if (normalize_) output_t(loc) /= 0.0; ++hypothesis_iter; @@ -214,7 +214,7 @@ class EditDistanceOp : public OpKernel { std::vector<int64> g_truth = truth_i.group(); auto truth_seq = truth_i.values<T>(); auto loc = std::inner_product(g_truth.begin(), g_truth.end(), - output_strides.begin(), 0); + output_strides.begin(), int64{0}); output_t(loc) = (normalize_) ? 1.0 : truth_seq.size(); ++truth_iter; } diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 08b699c8fa..f6420aaad8 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -205,8 +205,9 @@ struct MatMulFunctor<CPUDevice, T> { REGISTER_CPU(float); REGISTER_CPU(double); REGISTER_CPU(int32); -REGISTER_CPU(complex64); REGISTER_CPU(Eigen::half); +REGISTER_CPU(complex64); +REGISTER_CPU(complex128); #if GOOGLE_CUDA REGISTER_GPU(float); // REGISTER_GPU(double); diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index 6886165f98..1ba130692a 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -371,6 +371,7 @@ Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, HANDLE_TYPE(DT_QINT32); HANDLE_TYPE(DT_QINT16); HANDLE_TYPE(DT_QUINT16); + HANDLE_TYPE(DT_COMPLEX128); #undef HANDLE_TYPE return errors::Unimplemented("CopySliceToElement Unhandled data type: ", parent.dtype()); @@ -399,6 +400,7 @@ Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, HANDLE_TYPE(DT_QINT32); HANDLE_TYPE(DT_QINT16); HANDLE_TYPE(DT_QUINT16); + HANDLE_TYPE(DT_COMPLEX128); #undef HANDLE_TYPE return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", element.dtype()); diff --git a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc index 69694fbd86..8ff724f5eb 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc @@ -90,6 +90,7 @@ DEFINE_FOR_ALL_REDUCERS(double); #undef DEFINE_FOR_ALL_REDUCERS DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::SumReducer<complex64>); +DEFINE_FOR_TYPE_AND_R(complex128, Eigen::internal::SumReducer<complex128>); DEFINE_FOR_TYPE_AND_R(bool, Eigen::internal::AndReducer); DEFINE_FOR_TYPE_AND_R(bool, Eigen::internal::OrReducer); #undef DEFINE_FOR_TYPE_AND_R diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc index 9661f6a523..72dd718754 100644 --- a/tensorflow/core/kernels/reduction_ops_sum.cc +++ b/tensorflow/core/kernels/reduction_ops_sum.cc @@ -22,14 +22,12 @@ namespace tensorflow { Name("Sum").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ ReductionOp<CPUDevice, type, Eigen::internal::SumReducer<type>>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); -#undef REGISTER_CPU_KERNELS - // NOTE: We should have mean(complex64,int32), too. But that needs to // change Eigen::internal::MeanReducer to cast int to complex<float>. // We don't see immediate need of mean(complex64,int32) anyway. -REGISTER_KERNEL_BUILDER( - Name("Sum").Device(DEVICE_CPU).TypeConstraint<complex64>("T"), - ReductionOp<CPUDevice, complex64, Eigen::internal::SumReducer<complex64>>); +REGISTER_CPU_KERNELS(complex64); +REGISTER_CPU_KERNELS(complex128); +#undef REGISTER_CPU_KERNELS #if GOOGLE_CUDA @@ -43,15 +41,10 @@ REGISTER_KERNEL_BUILDER( REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); +REGISTER_GPU_KERNELS(complex64); +REGISTER_GPU_KERNELS(complex128); #undef REGISTER_GPU_KERNELS -REGISTER_KERNEL_BUILDER( - Name("Sum") - .Device(DEVICE_GPU) - .TypeConstraint<complex64>("T") - .HostMemory("reduction_indices"), - ReductionOp<GPUDevice, complex64, Eigen::internal::SumReducer<complex64>>); - // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 9fba08d6b1..e60f11b246 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -171,6 +171,7 @@ REGISTER_CPU(float); REGISTER_CPU(double); REGISTER_CPU(int32); REGISTER_CPU(complex64); +REGISTER_CPU(complex128); #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc index ca4268c78d..0c20aa1f01 100644 --- a/tensorflow/core/kernels/transpose_functor_cpu.cc +++ b/tensorflow/core/kernels/transpose_functor_cpu.cc @@ -99,6 +99,10 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in, internal::Transpose<Device, uint64>(d, in, perm, out); break; + case DT_COMPLEX128: + internal::Transpose<Device, complex128>(d, in, perm, out); + break; + case DT_STRING: internal::Transpose<Device, string>(d, in, perm, out); break; diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index 3d8746fd7c..703254ce90 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -108,6 +108,10 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") .Device(DEVICE_GPU) .TypeConstraint<float>("T"), SoftmaxXentWithLogitsOp<GPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") + .Device(DEVICE_GPU) + .TypeConstraint<double>("T"), + SoftmaxXentWithLogitsOp<GPUDevice, double>); #endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/xent_op_gpu.cu.cc b/tensorflow/core/kernels/xent_op_gpu.cu.cc index 18b36bfcf1..1b707be007 100644 --- a/tensorflow/core/kernels/xent_op_gpu.cu.cc +++ b/tensorflow/core/kernels/xent_op_gpu.cu.cc @@ -42,9 +42,10 @@ struct XentFunctor<GPUDevice, T> { }; } // end namespace functor -// Instantiate the GPU implementation for half and float. +// Instantiate the GPU implementation for half, float and double. template struct functor::XentFunctor<GPUDevice, Eigen::half>; template struct functor::XentFunctor<GPUDevice, float>; +template struct functor::XentFunctor<GPUDevice, double>; } // end namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index fc4d2e6f13..fe0fa34337 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -167,7 +167,7 @@ y: a tensor of the same shape and type as x but filled with zeros. REGISTER_OP("Diag") .Input("diagonal: T") .Output("output: T") - .Attr("T: {float, double, int32, int64}") + .Attr("T: {float, double, int32, int64, complex64}") .Doc(R"doc( Returns a diagonal tensor with a given diagonal values. @@ -196,7 +196,7 @@ diagonal: Rank k tensor where k is at most 3. REGISTER_OP("DiagPart") .Input("input: T") .Output("diagonal: T") - .Attr("T: {float, double, int32, int64}") + .Attr("T: {float, double, int32, int64, complex64}") .Doc(R"doc( Returns the diagonal part of the tensor. @@ -660,14 +660,14 @@ For example: ```prettyprint # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] # tensor 't' has shape [9] -reshape(t, [3, 3]) ==> [[1, 2, 3] - [4, 5, 6] +reshape(t, [3, 3]) ==> [[1, 2, 3], + [4, 5, 6], [7, 8, 9]] -# tensor 't' is [[[1, 1], [2, 2]] +# tensor 't' is [[[1, 1], [2, 2]], # [[3, 3], [4, 4]]] # tensor 't' has shape [2, 2, 2] -reshape(t, [2, 4]) ==> [[1, 1, 2, 2] +reshape(t, [2, 4]) ==> [[1, 1, 2, 2], [3, 3, 4, 4]] # tensor 't' is [[[1, 1, 1], @@ -679,9 +679,22 @@ reshape(t, [2, 4]) ==> [[1, 1, 2, 2] # tensor 't' has shape [3, 2, 3] # pass '[-1]' to flatten 't' reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] -# -1 can also be used with higher dimensional shapes + +# -1 can also be used to infer the shape + +# -1 is inferred to be 9: reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], [4, 4, 4, 5, 5, 5, 6, 6, 6]] +# -1 is inferred to be 2: +reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 6, 6, 6]] +# -1 is inferred to be 3: +reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1], + [2, 2, 2], + [3, 3, 3]], + [[4, 4, 4], + [5, 5, 5], + [6, 6, 6]]] # tensor 't' is [7] # shape `[]` reshapes to a scalar @@ -1394,10 +1407,10 @@ The output tensor has shape `[4, 1, 1, 3]` and value: (3) For the following input of shape `[1, 4, 4, 1]` and block_size of 2: ```prettyprint -x = [[[1], [2], [3], [4]], - [[5], [6], [7], [8]], - [[9], [10], [11], [12]], - [[13], [14], [15], [16]]] +x = [[[[1], [2], [3], [4]], + [[5], [6], [7], [8]], + [[9], [10], [11], [12]], + [[13], [14], [15], [16]]]] ``` The output tensor has shape `[4, 2, 2, 1]` and value: @@ -1591,10 +1604,10 @@ This operation, for block_size of 2, will return the following tensor of shape Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: ```prettyprint -x = [[ [1], [2], [5], [6]], - [ [3], [4], [7], [8]], - [ [9], [10], [13], [14]], - [ [11], [12], [15], [16]]] +x = [[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]] ``` the operator will return the following tensor of shape `[1 2 2 4]`: diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index d290580077..7b4b28159d 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -284,7 +284,7 @@ REGISTER_OP_GRADIENT("Sub", SubGrad); Status MulGrad(const AttrSlice& attrs, FunctionDef* g) { DataType T; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); - if (T == DT_COMPLEX64) { + if (T == DT_COMPLEX64 || T == DT_COMPLEX128) { return GradForBinaryCwise( g, { {{"cy"}, "Conj", {"y"}, {}, {"dz"}}, @@ -543,7 +543,7 @@ Status MatMulGradCommon(const string& opname, const string& attr_adj_x, FunctionDef* g) { DataType T; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); - if (T == DT_COMPLEX64) { + if (T == DT_COMPLEX64 || T == DT_COMPLEX128) { return errors::Unimplemented( "MatMul gradient for complex is not supported yet."); } diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 861ed74b1f..fdb490df9e 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -37,7 +37,7 @@ REGISTER_OP("BatchMatMul") .Input("x: T") .Input("y: T") .Output("output: T") - .Attr("T: {half, float, double, int32, complex64}") + .Attr("T: {half, float, double, int32, complex64, complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") .Doc(R"doc( @@ -111,15 +111,17 @@ an output element, this operation computes \\(y = |x|\\). )doc"); REGISTER_OP("ComplexAbs") - .Input("x: complex64") - .Output("y: float") + .Input("x: T") + .Output("y: Tout") + .Attr("T: {complex64, complex128} = DT_COMPLEX64") + .Attr("Tout: {float, double} = DT_FLOAT") .Doc(R"doc( Computes the complex absolute value of a tensor. Given a tensor `x` of complex numbers, this operation returns a tensor of type -`float` that is the absolute value of each element in `x`. All elements in `x` -must be complex numbers of the form \\(a + bj\\). The absolute value is -computed as \\( \sqrt{a^2 + b^2}\\). +`float` or `double` that is the absolute value of each element in `x`. All +elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute +value is computed as \\( \sqrt{a^2 + b^2}\\). For example: @@ -132,7 +134,7 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229] // Declares cwise unary operations signature: 't -> 't #define UNARY() \ Input("x: T").Output("y: T").Attr( \ - "T: {half, float, double, int32, complex64, int64}") + "T: {half, float, double, int32, int64, complex64, complex128}") REGISTER_OP("Neg") .UNARY() @@ -262,7 +264,7 @@ Returns which elements of x are finite. REGISTER_OP("Sign") .Input("x: T") .Output("y: T") - .Attr("T: {half, float, double, int32, int64, complex64}") + .Attr("T: {half, float, double, int32, int64, complex64, complex128}") .Doc(R"doc( Returns an element-wise indication of the sign of a number. @@ -291,11 +293,11 @@ Returns element-wise smallest integer in not less than x. #define BINARY_MORE() \ Input("x: T").Input("y: T").Output("z: T").Attr( \ - "T: {half, float, double, uint8, int8, int16, int32, int64, complex64}") + "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, complex128}") #define BINARY_FEWER() \ Input("x: T").Input("y: T").Output("z: T").Attr( \ - "T: {half, float, double, int32, complex64, int64}") + "T: {half, float, double, int32, int64, complex64, complex128}") // TODO(mrry): Restore `SetIsCommutative()` for non-string types. REGISTER_OP("Add") @@ -304,7 +306,7 @@ REGISTER_OP("Add") .Output("z: T") .Attr( "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " - "string}") + "complex128, string}") .Doc(R"doc( Returns x + y element-wise. @@ -373,7 +375,7 @@ REGISTER_OP("Pow") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {half, float, double, int32, complex64, int64}") + .Attr("T: {half, float, double, int32, int64, complex64, complex128}") .Doc(R"doc( Computes the power of one value to another. @@ -433,6 +435,37 @@ Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete Gamma function. )doc"); +REGISTER_OP("Zeta") + .Input("x: T") + .Input("q: T") + .Output("z: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Compute the Hurwitz zeta function \\(\zeta(x, q)\\). + +The Hurwitz zeta function is defined as: + +``` +\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x} +``` +)doc"); + +REGISTER_OP("Polygamma") + .Input("a: T") + .Input("x: T") + .Output("z: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Compute the polygamma function \\(\psi^{(n)}(x)\\). + +The polygamma function is defined as: + +``` +\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x) +``` +where \\(\psi(x)\\) is the digamma function. +)doc"); + // -------------------------------------------------------------------------- // Declares cwise binary comparison operations signature: 't, 't -> bool, @@ -471,7 +504,7 @@ Returns the truth value of (x >= y) element-wise. #define EQUALITY_COMPARISON() \ Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \ "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " \ - "quint8, qint8, qint32, string, bool}") + "quint8, qint8, qint32, string, bool, complex128}") REGISTER_OP("Equal") .EQUALITY_COMPARISON() @@ -577,7 +610,7 @@ REGISTER_OP("MatMul") .Output("product: T") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") - .Attr("T: {half, float, double, int32, complex64}") + .Attr("T: {half, float, double, int32, complex64, complex128}") .Doc(R"doc( Multiply the matrix "a" by the matrix "b". @@ -1141,9 +1174,11 @@ output: 1-D. The generated values. )doc"); REGISTER_OP("Complex") - .Input("real: float") - .Input("imag: float") - .Output("output: complex64") + .Input("real: T") + .Input("imag: T") + .Output("out: Tout") + .Attr("T: {float, double} = DT_FLOAT") + .Attr("Tout: {complex64, complex128} = DT_COMPLEX64") .Doc(R"doc( Converts two real numbers to a complex number. @@ -1163,7 +1198,12 @@ tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] ``` )doc"); -REGISTER_OP("Real").Input("input: complex64").Output("output: float").Doc(R"doc( +REGISTER_OP("Real") + .Input("input: T") + .Output("output: Tout") + .Attr("T: {complex64, complex128} = DT_COMPLEX64") + .Attr("Tout: {float, double} = DT_FLOAT") + .Doc(R"doc( Returns the real part of a complex number. Given a tensor `input` of complex numbers, this operation returns a tensor of @@ -1179,7 +1219,12 @@ tf.real(input) ==> [-2.25, 3.25] ``` )doc"); -REGISTER_OP("Imag").Input("input: complex64").Output("output: float").Doc(R"doc( +REGISTER_OP("Imag") + .Input("input: T") + .Output("output: Tout") + .Attr("T: {complex64, complex128} = DT_COMPLEX64") + .Attr("Tout: {float, double} = DT_FLOAT") + .Doc(R"doc( Returns the imaginary part of a complex number. Given a tensor `input` of complex numbers, this operation returns a tensor of @@ -1196,8 +1241,9 @@ tf.imag(input) ==> [4.75, 5.75] )doc"); REGISTER_OP("Conj") - .Input("input: complex64") - .Output("output: complex64") + .Input("input: T") + .Output("output: T") + .Attr("T: {complex64, complex128} = DT_COMPLEX64") .Doc(R"doc( Returns the complex conjugate of a complex number. diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index d22df658c8..34e38b97d7 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -3594,6 +3594,7 @@ op { type: DT_DOUBLE type: DT_INT32 type: DT_INT64 + type: DT_COMPLEX64 } } } @@ -3621,6 +3622,7 @@ op { type: DT_DOUBLE type: DT_INT32 type: DT_INT64 + type: DT_COMPLEX64 } } } diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 8469c3823b..f357735fd6 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -154,7 +154,7 @@ message RunGraphRequest { // A unique ID to distinguish different runs of the same graph. // - // The master generates a global unique `step_id` to dinstinguish + // The master generates a global unique `step_id` to distinguish // different runs of the graph computation. Subgraphs communicate // (e.g., send/recv ops) with each other using `step_id` to // distinguish tensors generated by different runs. diff --git a/tensorflow/examples/skflow/README.md b/tensorflow/examples/skflow/README.md index 8bb2739e6c..756351c0b6 100644 --- a/tensorflow/examples/skflow/README.md +++ b/tensorflow/examples/skflow/README.md @@ -7,21 +7,28 @@ known Scikit Learn API. To run these examples, you need to have `scikit learn` library installed (`sudo pip install sklearn`). Some examples use the `pandas` library for data processing (`sudo pip install pandas`). +## Basics + * [Deep Neural Network Regression with Boston Data](boston.py) * [Convolutional Neural Networks with Digits Data](digits.py) * [Deep Neural Network Classification with Iris Data](iris.py) -* [Grid search and Deep Neural Network Classification](iris_gridsearch_cv.py) -* [Deep Neural Network with Customized Decay Function](iris_custom_decay_dnn.py) +* [Deep Neural Network Autoencoder with Iris Data](dnn_autoencoder_iris.py) * [Building A Custom Model](iris_custom_model.py) * [Accessing Weights and Biases in A Custom Model](mnist_weights.py) -* [Building A Custom Model Using Multiple GPUs](multiple_gpu.py) -* [Building A Model Using Different GPU Configurations](iris_config_addon.py) -* [Using skflow with Pipeline](iris_with_pipeline.py) +* [Building A Model Using Different GPU Configurations](iris_run_config.py) * [Example of Saving and Restoring Models](iris_save_restore.py) * [Multi-output Deep Neural Network regression](multioutput_regression.py) + + +## Techniques + * [Improving Performance Using Early Stopping with Iris Data](iris_val_based_early_stopping.py) +* [Using skflow with Pipeline](iris_with_pipeline.py) +* [Building A Custom Model Using Multiple GPUs](multiple_gpu.py) +* [Grid search and Deep Neural Network Classification](iris_gridsearch_cv.py) +* [Deep Neural Network with Customized Decay Function](iris_custom_decay_dnn.py) * [Out-of-core Data Classification Using Dask](out_of_core_data_classification.py) - +* [Handling Large HDF5 Dataset](hdf5_classification.py) ## Image classification diff --git a/tensorflow/examples/skflow/boston.py b/tensorflow/examples/skflow/boston.py index 2349baf72e..bf2066770c 100644 --- a/tensorflow/examples/skflow/boston.py +++ b/tensorflow/examples/skflow/boston.py @@ -40,6 +40,6 @@ regressor = skflow.TensorFlowDNNRegressor(hidden_units=[10, 10], regressor.fit(X_train, y_train) # Predict and score -score = metrics.mean_squared_error(regressor.predict(scaler.fit_transform(X_test)), y_test) +score = metrics.mean_squared_error(regressor.predict(scaler.transform(X_test)), y_test) print('MSE: {0:f}'.format(score)) diff --git a/tensorflow/examples/skflow/digits.py b/tensorflow/examples/skflow/digits.py index f61df6c572..b3c684b7df 100644 --- a/tensorflow/examples/skflow/digits.py +++ b/tensorflow/examples/skflow/digits.py @@ -18,8 +18,8 @@ from __future__ import print_function from sklearn import datasets, cross_validation, metrics import tensorflow as tf -from tensorflow.contrib import skflow -from tensorflow.contrib.skflow import monitors +from tensorflow.contrib import learn +from tensorflow.contrib.learn import monitors # Load dataset @@ -45,13 +45,13 @@ X_train, X_val, y_train, y_val = cross_validation.train_test_split(X_train, def conv_model(X, y): X = tf.expand_dims(X, 3) - features = tf.reduce_max(skflow.ops.conv2d(X, 12, [3, 3]), [1, 2]) + features = tf.reduce_max(learn.ops.conv2d(X, 12, [3, 3]), [1, 2]) features = tf.reshape(features, [-1, 12]) - return skflow.models.logistic_regression(features, y) + return learn.models.logistic_regression(features, y) val_monitor = monitors.ValidationMonitor(X_val, y_val, every_n_steps=50) # Create a classifier, train and predict. -classifier = skflow.TensorFlowEstimator(model_fn=conv_model, n_classes=10, +classifier = learn.TensorFlowEstimator(model_fn=conv_model, n_classes=10, steps=1000, learning_rate=0.05, batch_size=128) classifier.fit(X_train, y_train, val_monitor) diff --git a/tensorflow/examples/skflow/dnn_autoencoder_iris.py b/tensorflow/examples/skflow/dnn_autoencoder_iris.py new file mode 100644 index 0000000000..c4383ae608 --- /dev/null +++ b/tensorflow/examples/skflow/dnn_autoencoder_iris.py @@ -0,0 +1,35 @@ +# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import tensorflow as tf +from tensorflow.contrib import learn +from tensorflow.contrib.learn import datasets + +# Load Iris Data +iris = datasets.load_iris() + +# Initialize a deep neural network autoencoder +# You can also add noise and add dropout if needed +# Details see TensorFlowDNNAutoencoder documentation. +autoencoder = learn.TensorFlowDNNAutoencoder(hidden_units=[10, 20]) + +# Fit with Iris data +transformed = autoencoder.fit_transform(iris.data) + +print(transformed) diff --git a/tensorflow/examples/skflow/hdf5_classification.py b/tensorflow/examples/skflow/hdf5_classification.py new file mode 100644 index 0000000000..0a4a7fd731 --- /dev/null +++ b/tensorflow/examples/skflow/hdf5_classification.py @@ -0,0 +1,49 @@ +# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from sklearn import metrics, cross_validation +from tensorflow.contrib import learn +import h5py + +# Load dataset. +iris = learn.datasets.load_dataset('iris') +X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, + test_size=0.2, random_state=42) + +# Note that we are saving and load iris data as h5 format as a simple demonstration here. +h5f = h5py.File('test_hdf5.h5', 'w') +h5f.create_dataset('X_train', data=X_train) +h5f.create_dataset('X_test', data=X_test) +h5f.create_dataset('y_train', data=y_train) +h5f.create_dataset('y_test', data=y_test) +h5f.close() + +h5f = h5py.File('test_hdf5.h5', 'r') +X_train = h5f['X_train'] +X_test = h5f['X_test'] +y_train = h5f['y_train'] +y_test = h5f['y_test'] + +# Build 3 layer DNN with 10, 20, 10 units respectively. +classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], + n_classes=3, steps=200) + +# Fit and predict. +classifier.fit(X_train, y_train) +score = metrics.accuracy_score(y_test, classifier.predict(X_test)) +print('Accuracy: {0:f}'.format(score)) + diff --git a/tensorflow/examples/skflow/iris.py b/tensorflow/examples/skflow/iris.py index 3ea1ef1f4c..c6c566b10f 100644 --- a/tensorflow/examples/skflow/iris.py +++ b/tensorflow/examples/skflow/iris.py @@ -17,15 +17,15 @@ from __future__ import print_function from sklearn import metrics, cross_validation -from tensorflow.contrib import skflow +from tensorflow.contrib import learn # Load dataset. -iris = skflow.datasets.load_dataset('iris') +iris = learn.datasets.load_dataset('iris') X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, test_size=0.2, random_state=42) # Build 3 layer DNN with 10, 20, 10 units respectively. -classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], +classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200) # Fit and predict. diff --git a/tensorflow/examples/skflow/iris_custom_model.py b/tensorflow/examples/skflow/iris_custom_model.py index f142a7db8c..11c7bc88d6 100644 --- a/tensorflow/examples/skflow/iris_custom_model.py +++ b/tensorflow/examples/skflow/iris_custom_model.py @@ -16,7 +16,7 @@ from __future__ import division from __future__ import print_function from sklearn import datasets, metrics, cross_validation -from tensorflow.contrib import skflow +from tensorflow.contrib import learn iris = datasets.load_iris() X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, @@ -24,10 +24,10 @@ X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, def my_model(X, y): """This is DNN with 10, 20, 10 hidden layers, and dropout of 0.1 probability.""" - layers = skflow.ops.dnn(X, [10, 20, 10], dropout=0.1) - return skflow.models.logistic_regression(layers, y) + layers = learn.ops.dnn(X, [10, 20, 10], dropout=0.1) + return learn.models.logistic_regression(layers, y) -classifier = skflow.TensorFlowEstimator(model_fn=my_model, n_classes=3, +classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3, steps=1000) classifier.fit(X_train, y_train) score = metrics.accuracy_score(y_test, classifier.predict(X_test)) diff --git a/tensorflow/examples/skflow/iris_run_config.py b/tensorflow/examples/skflow/iris_run_config.py index c85fcec224..4f057f817d 100644 --- a/tensorflow/examples/skflow/iris_run_config.py +++ b/tensorflow/examples/skflow/iris_run_config.py @@ -17,7 +17,7 @@ from __future__ import print_function from sklearn import datasets, metrics, cross_validation -from tensorflow.contrib import skflow +from tensorflow.contrib import learn # Load dataset. @@ -27,10 +27,10 @@ X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, # You can define you configurations by providing a RunConfig object to # estimator to control session configurations, e.g. num_cores and gpu_memory_fraction -run_config = skflow.estimators.RunConfig(num_cores=3, gpu_memory_fraction=0.6) +run_config = learn.estimators.RunConfig(num_cores=3, gpu_memory_fraction=0.6) # Build 3 layer DNN with 10, 20, 10 units respecitvely. -classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], +classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200, config=run_config) # Fit and predict. diff --git a/tensorflow/examples/skflow/iris_save_restore.py b/tensorflow/examples/skflow/iris_save_restore.py index f9613c96fe..d29237a26f 100644 --- a/tensorflow/examples/skflow/iris_save_restore.py +++ b/tensorflow/examples/skflow/iris_save_restore.py @@ -18,13 +18,13 @@ from __future__ import print_function import shutil from sklearn import datasets, metrics, cross_validation -from tensorflow.contrib import skflow +from tensorflow.contrib import learn iris = datasets.load_iris() X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, test_size=0.2, random_state=42) -classifier = skflow.TensorFlowLinearClassifier(n_classes=3) +classifier = learn.TensorFlowLinearClassifier(n_classes=3) classifier.fit(X_train, y_train) score = metrics.accuracy_score(y_test, classifier.predict(X_test)) print('Accuracy: {0:f}'.format(score)) @@ -40,6 +40,6 @@ classifier.save('/tmp/skflow_examples/iris_custom_model') classifier = None ## Restore everything -new_classifier = skflow.TensorFlowEstimator.restore('/tmp/skflow_examples/iris_custom_model') +new_classifier = learn.TensorFlowEstimator.restore('/tmp/skflow_examples/iris_custom_model') score = metrics.accuracy_score(y_test, new_classifier.predict(X_test)) print('Accuracy: {0:f}'.format(score)) diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py index 74e02561f3..3ded960bac 100644 --- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py +++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py @@ -18,7 +18,7 @@ from __future__ import print_function from sklearn import datasets, metrics from sklearn.cross_validation import train_test_split -from tensorflow.contrib import skflow +from tensorflow.contrib import learn iris = datasets.load_iris() @@ -29,17 +29,17 @@ X_train, X_test, y_train, y_test = train_test_split(iris.data, X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42) -val_monitor = skflow.monitors.ValidationMonitor(X_val, y_val, - early_stopping_rounds=200) +val_monitor = learn.monitors.ValidationMonitor(X_val, y_val, + early_stopping_rounds=200) # classifier with early stopping on training data -classifier1 = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], +classifier1 = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=2000) classifier1.fit(X_train, y_train, logdir='/tmp/iris_model/') score1 = metrics.accuracy_score(y_test, classifier1.predict(X_test)) # classifier with early stopping on validation data -classifier2 = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], +classifier2 = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=2000) classifier2.fit(X_train, y_train, val_monitor, logdir='/tmp/iris_model_val/') score2 = metrics.accuracy_score(y_test, classifier2.predict(X_test)) diff --git a/tensorflow/examples/skflow/iris_with_pipeline.py b/tensorflow/examples/skflow/iris_with_pipeline.py index f6408f84a8..3ba5739250 100644 --- a/tensorflow/examples/skflow/iris_with_pipeline.py +++ b/tensorflow/examples/skflow/iris_with_pipeline.py @@ -20,7 +20,7 @@ from sklearn.datasets import load_iris from sklearn import cross_validation from sklearn.preprocessing import StandardScaler from sklearn.metrics import accuracy_score -from tensorflow.contrib import skflow +from tensorflow.contrib import learn iris = load_iris() X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, @@ -30,7 +30,7 @@ X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, scaler = StandardScaler() # DNN classifier -DNNclassifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200) +DNNclassifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200) pipeline = Pipeline([('scaler', scaler), ('DNNclassifier', DNNclassifier)]) diff --git a/tensorflow/examples/skflow/language_model.py b/tensorflow/examples/skflow/language_model.py index 439d6f3198..ccc1b04c5b 100644 --- a/tensorflow/examples/skflow/language_model.py +++ b/tensorflow/examples/skflow/language_model.py @@ -22,7 +22,7 @@ import math import numpy as np import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Training data @@ -53,7 +53,7 @@ def unpack_xy(iter_obj): return (item[0] for item in X), (item[1] for item in y) -byte_processor = skflow.preprocessing.ByteProcessor( +byte_processor = learn.preprocessing.ByteProcessor( max_document_length=MAX_DOC_LENGTH) data = training_data(CORPUS_FILENAME) @@ -68,31 +68,31 @@ HIDDEN_SIZE = 10 def seq_autoencoder(X, y): """Sequence auto-encoder with RNN.""" - inputs = skflow.ops.one_hot_matrix(X, 256) - in_X, in_y, out_y = skflow.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH) + inputs = learn.ops.one_hot_matrix(X, 256) + in_X, in_y, out_y = learn.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH) encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE) decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256) - decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell) - return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding) + decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell) + return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding) def get_language_model(hidden_size): """Returns a language model with given hidden size.""" def language_model(X, y): - inputs = skflow.ops.one_hot_matrix(X, 256) - inputs = skflow.ops.split_squeeze(1, MAX_DOC_LENGTH, inputs) - target = skflow.ops.split_squeeze(1, MAX_DOC_LENGTH, y) + inputs = learn.ops.one_hot_matrix(X, 256) + inputs = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, inputs) + target = learn.ops.split_squeeze(1, MAX_DOC_LENGTH, y) encoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(hidden_size),256) output, _ = tf.nn.rnn(encoder_cell, inputs, dtype=tf.float32) - return skflow.ops.sequence_classifier(output, target) + return learn.ops.sequence_classifier(output, target) return language_model ### Training model. -estimator = skflow.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE), +estimator = learn.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE), n_classes=256, optimizer='Adam', learning_rate=0.01, steps=1000, batch_size=64, continue_training=True) diff --git a/tensorflow/examples/skflow/mnist.py b/tensorflow/examples/skflow/mnist.py index 504a5ae9f8..082ecb2f83 100644 --- a/tensorflow/examples/skflow/mnist.py +++ b/tensorflow/examples/skflow/mnist.py @@ -24,15 +24,15 @@ from __future__ import print_function from sklearn import metrics import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Download and load MNIST data. -mnist = skflow.datasets.load_dataset('mnist') +mnist = learn.datasets.load_dataset('mnist') ### Linear classifier. -classifier = skflow.TensorFlowLinearClassifier( +classifier = learn.TensorFlowLinearClassifier( n_classes=10, batch_size=100, steps=1000, learning_rate=0.01) classifier.fit(mnist.train.images, mnist.train.labels) score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images)) @@ -50,22 +50,22 @@ def conv_model(X, y): X = tf.reshape(X, [-1, 28, 28, 1]) # first conv layer will compute 32 features for each 5x5 patch with tf.variable_scope('conv_layer1'): - h_conv1 = skflow.ops.conv2d(X, n_filters=32, filter_shape=[5, 5], + h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5], bias=True, activation=tf.nn.relu) h_pool1 = max_pool_2x2(h_conv1) # second conv layer will compute 64 features for each 5x5 patch with tf.variable_scope('conv_layer2'): - h_conv2 = skflow.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5], + h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5], bias=True, activation=tf.nn.relu) h_pool2 = max_pool_2x2(h_conv2) # reshape tensor into a batch of vectors h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) # densely connected layer with 1024 neurons - h_fc1 = skflow.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5) - return skflow.models.logistic_regression(h_fc1, y) + h_fc1 = learn.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5) + return learn.models.logistic_regression(h_fc1, y) # Training and predicting -classifier = skflow.TensorFlowEstimator( +classifier = learn.TensorFlowEstimator( model_fn=conv_model, n_classes=10, batch_size=100, steps=20000, learning_rate=0.001) classifier.fit(mnist.train.images, mnist.train.labels) diff --git a/tensorflow/examples/skflow/mnist_weights.py b/tensorflow/examples/skflow/mnist_weights.py index 7d4acd24af..b0c2ea583e 100644 --- a/tensorflow/examples/skflow/mnist_weights.py +++ b/tensorflow/examples/skflow/mnist_weights.py @@ -24,15 +24,15 @@ from __future__ import print_function from sklearn import metrics import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Download and load MNIST data. -mnist = skflow.datasets.load_dataset('mnist') +mnist = learn.datasets.load_dataset('mnist') ### Linear classifier. -classifier = skflow.TensorFlowLinearClassifier( +classifier = learn.TensorFlowLinearClassifier( n_classes=10, batch_size=100, steps=1000, learning_rate=0.01) classifier.fit(mnist.train.images, mnist.train.labels) score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images)) @@ -50,22 +50,22 @@ def conv_model(X, y): X = tf.reshape(X, [-1, 28, 28, 1]) # first conv layer will compute 32 features for each 5x5 patch with tf.variable_scope('conv_layer1'): - h_conv1 = skflow.ops.conv2d(X, n_filters=32, filter_shape=[5, 5], + h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5], bias=True, activation=tf.nn.relu) h_pool1 = max_pool_2x2(h_conv1) # second conv layer will compute 64 features for each 5x5 patch with tf.variable_scope('conv_layer2'): - h_conv2 = skflow.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5], + h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5], bias=True, activation=tf.nn.relu) h_pool2 = max_pool_2x2(h_conv2) # reshape tensor into a batch of vectors h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) # densely connected layer with 1024 neurons - h_fc1 = skflow.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5) - return skflow.models.logistic_regression(h_fc1, y) + h_fc1 = learn.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5) + return learn.models.logistic_regression(h_fc1, y) # Training and predicting -classifier = skflow.TensorFlowEstimator( +classifier = learn.TensorFlowEstimator( model_fn=conv_model, n_classes=10, batch_size=100, steps=20000, learning_rate=0.001) classifier.fit(mnist.train.images, mnist.train.labels) diff --git a/tensorflow/examples/skflow/multioutput_regression.py b/tensorflow/examples/skflow/multioutput_regression.py index a4300a5508..c0ddf1cf30 100644 --- a/tensorflow/examples/skflow/multioutput_regression.py +++ b/tensorflow/examples/skflow/multioutput_regression.py @@ -26,7 +26,7 @@ import matplotlib.pyplot as plt from sklearn import datasets from sklearn.metrics import mean_squared_error -from tensorflow.contrib import skflow +from tensorflow.contrib import learn # Create random dataset. rng = np.random.RandomState(1) @@ -38,11 +38,11 @@ regressors = [] options = [[2], [10, 10], [20, 20]] for hidden_units in options: def tanh_dnn(X, y): - features = skflow.ops.dnn(X, hidden_units=hidden_units, - activation=skflow.tf.tanh) - return skflow.models.linear_regression(features, y) + features = learn.ops.dnn(X, hidden_units=hidden_units, + activation=learn.tf.tanh) + return learn.models.linear_regression(features, y) - regressor = skflow.TensorFlowEstimator(model_fn=tanh_dnn, n_classes=0, + regressor = learn.TensorFlowEstimator(model_fn=tanh_dnn, n_classes=0, steps=500, learning_rate=0.1, batch_size=100) regressor.fit(X, y) score = mean_squared_error(regressor.predict(X), y) diff --git a/tensorflow/examples/skflow/multiple_gpu.py b/tensorflow/examples/skflow/multiple_gpu.py index 4afa667e6d..1168184a38 100644 --- a/tensorflow/examples/skflow/multiple_gpu.py +++ b/tensorflow/examples/skflow/multiple_gpu.py @@ -17,7 +17,7 @@ from __future__ import print_function from sklearn import datasets, metrics, cross_validation import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn iris = datasets.load_iris() X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, @@ -31,11 +31,11 @@ def my_model(X, y): CUDNN 6.5 V2 from NVIDIA need to be installed beforehand. """ with tf.device('/gpu:1'): - layers = skflow.ops.dnn(X, [10, 20, 10], dropout=0.5) + layers = learn.ops.dnn(X, [10, 20, 10], dropout=0.5) with tf.device('/gpu:2'): - return skflow.models.logistic_regression(layers, y) + return learn.models.logistic_regression(layers, y) -classifier = skflow.TensorFlowEstimator(model_fn=my_model, n_classes=3) +classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3) classifier.fit(X_train, y_train) score = metrics.accuracy_score(y_test, classifier.predict(X_test)) print('Accuracy: {0:f}'.format(score)) diff --git a/tensorflow/examples/skflow/neural_translation.py b/tensorflow/examples/skflow/neural_translation.py index d0c8c33b93..7832767145 100644 --- a/tensorflow/examples/skflow/neural_translation.py +++ b/tensorflow/examples/skflow/neural_translation.py @@ -22,7 +22,7 @@ import os import numpy as np import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn # Get training data @@ -88,15 +88,15 @@ MAX_DOCUMENT_LENGTH = 30 HIDDEN_SIZE = 100 def translate_model(X, y): - byte_list = skflow.ops.one_hot_matrix(X, 256) - in_X, in_y, out_y = skflow.ops.seq2seq_inputs( + byte_list = learn.ops.one_hot_matrix(X, 256) + in_X, in_y, out_y = learn.ops.seq2seq_inputs( byte_list, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH) cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256) - decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, cell) - return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding) + decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, cell) + return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding) -vocab_processor = skflow.preprocessing.ByteProcessor( +vocab_processor = learn.preprocessing.ByteProcessor( max_document_length=MAX_DOCUMENT_LENGTH) x_iter = vocab_processor.transform(X_train) @@ -107,9 +107,9 @@ ygold = list(y_test)[:20] PATH = '/tmp/tf_examples/ntm/' if os.path.exists(PATH): - translator = skflow.TensorFlowEstimator.restore(PATH) + translator = learn.TensorFlowEstimator.restore(PATH) else: - translator = skflow.TensorFlowEstimator(model_fn=translate_model, + translator = learn.TensorFlowEstimator(model_fn=translate_model, n_classes=256, optimizer='Adam', learning_rate=0.01, batch_size=128, continue_training=True) diff --git a/tensorflow/examples/skflow/neural_translation_word.py b/tensorflow/examples/skflow/neural_translation_word.py index e917c97a9b..90c73f0ba5 100644 --- a/tensorflow/examples/skflow/neural_translation_word.py +++ b/tensorflow/examples/skflow/neural_translation_word.py @@ -25,7 +25,7 @@ import random import numpy as np import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn # Get training data @@ -93,9 +93,9 @@ X_test, y_test = Xy(read_iterator('test.data')) MAX_DOCUMENT_LENGTH = 10 if not (os.path.exists('en.vocab') and os.path.exists('fr.vocab')): - X_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH, + X_vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH, min_frequency=5) - y_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH, + y_vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH, min_frequency=5) Xtrainff, ytrainff = Xy(read_iterator('train.data')) print('Fitting dictionary for English...') @@ -126,24 +126,24 @@ HIDDEN_SIZE = 20 EMBEDDING_SIZE = 20 def translate_model(X, y): - word_vectors = skflow.ops.categorical_variable(X, n_classes=n_en_words, + word_vectors = learn.ops.categorical_variable(X, n_classes=n_en_words, embedding_size=EMBEDDING_SIZE, name='words') - in_X, in_y, out_y = skflow.ops.seq2seq_inputs( + in_X, in_y, out_y = learn.ops.seq2seq_inputs( word_vectors, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH) encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE) decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper( tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), n_fr_words) - decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, + decoding, _, sampling_decoding, _ = learn.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell=decoder_cell) - return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding) + return learn.ops.sequence_classifier(decoding, out_y, sampling_decoding) PATH = '/tmp/tf_examples/ntm_words/' if os.path.exists(os.path.join(PATH, 'graph.pbtxt')): - translator = skflow.TensorFlowEstimator.restore(PATH) + translator = learn.TensorFlowEstimator.restore(PATH) else: - translator = skflow.TensorFlowEstimator(model_fn=translate_model, + translator = learn.TensorFlowEstimator(model_fn=translate_model, n_classes=n_fr_words, optimizer='Adam', learning_rate=0.01, batch_size=128, continue_training=True, steps=100) diff --git a/tensorflow/examples/skflow/out_of_core_data_classification.py b/tensorflow/examples/skflow/out_of_core_data_classification.py index 9c09f436c6..6328941e6d 100644 --- a/tensorflow/examples/skflow/out_of_core_data_classification.py +++ b/tensorflow/examples/skflow/out_of_core_data_classification.py @@ -20,7 +20,7 @@ from sklearn import datasets, metrics, cross_validation import pandas as pd import dask.dataframe as dd -from tensorflow.contrib import skflow +from tensorflow.contrib import learn # Sometimes when your dataset is too large to hold in the memory # you may want to load it into a out-of-core dataframe as provided by dask library @@ -41,7 +41,7 @@ X_train, y_train, X_test, y_test = [pd.DataFrame(data) for data in [X_train, y_t X_train, y_train, X_test, y_test = [dd.from_pandas(data, npartitions=2) for data in [X_train, y_train, X_test, y_test]] # Initialize a TensorFlow linear classifier -classifier = skflow.TensorFlowLinearClassifier(n_classes=3) +classifier = learn.TensorFlowLinearClassifier(n_classes=3) # Fit the model using training set classifier.fit(X_train, y_train) diff --git a/tensorflow/examples/skflow/resnet.py b/tensorflow/examples/skflow/resnet.py index e4dab3fd25..f1f39568d4 100644 --- a/tensorflow/examples/skflow/resnet.py +++ b/tensorflow/examples/skflow/resnet.py @@ -30,7 +30,7 @@ from math import sqrt from sklearn import metrics import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data -from tensorflow.contrib import skflow +from tensorflow.contrib import learn def res_net(x, y, activation=tf.nn.relu): @@ -62,7 +62,7 @@ def res_net(x, y, activation=tf.nn.relu): # First convolution expands to 64 channels with tf.variable_scope('conv_layer1'): - net = skflow.ops.conv2d(x, 64, [7, 7], batch_norm=True, + net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True, activation=activation, bias=False) # Max pool @@ -71,7 +71,7 @@ def res_net(x, y, activation=tf.nn.relu): # First chain of resnets with tf.variable_scope('conv_layer2'): - net = skflow.ops.conv2d(net, blocks[0].num_filters, + net = learn.ops.conv2d(net, blocks[0].num_filters, [1, 1], [1, 1, 1, 1], padding='VALID', bias=True) @@ -83,7 +83,7 @@ def res_net(x, y, activation=tf.nn.relu): # 1x1 convolution responsible for reducing dimension with tf.variable_scope(name + '/conv_in'): - conv = skflow.ops.conv2d(net, block.bottleneck_size, + conv = learn.ops.conv2d(net, block.bottleneck_size, [1, 1], [1, 1, 1, 1], padding='VALID', activation=activation, @@ -91,7 +91,7 @@ def res_net(x, y, activation=tf.nn.relu): bias=False) with tf.variable_scope(name + '/conv_bottleneck'): - conv = skflow.ops.conv2d(conv, block.bottleneck_size, + conv = learn.ops.conv2d(conv, block.bottleneck_size, [3, 3], [1, 1, 1, 1], padding='SAME', activation=activation, @@ -100,7 +100,7 @@ def res_net(x, y, activation=tf.nn.relu): # 1x1 convolution responsible for restoring dimension with tf.variable_scope(name + '/conv_out'): - conv = skflow.ops.conv2d(conv, block.num_filters, + conv = learn.ops.conv2d(conv, block.num_filters, [1, 1], [1, 1, 1, 1], padding='VALID', activation=activation, @@ -115,7 +115,7 @@ def res_net(x, y, activation=tf.nn.relu): # upscale to the next block size next_block = blocks[block_i + 1] with tf.variable_scope('block_%d/conv_upscale' % block_i): - net = skflow.ops.conv2d(net, next_block.num_filters, + net = learn.ops.conv2d(net, next_block.num_filters, [1, 1], [1, 1, 1, 1], bias=False, padding='SAME') @@ -130,7 +130,7 @@ def res_net(x, y, activation=tf.nn.relu): net_shape = net.get_shape().as_list() net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]]) - return skflow.models.logistic_regression(net, y) + return learn.models.logistic_regression(net, y) # Download and load MNIST data. @@ -138,10 +138,10 @@ mnist = input_data.read_data_sets('MNIST_data') # Restore model if graph is saved into a folder. if os.path.exists("models/resnet/graph.pbtxt"): - classifier = skflow.TensorFlowEstimator.restore("models/resnet/") + classifier = learn.TensorFlowEstimator.restore("models/resnet/") else: # Create a new resnet classifier. - classifier = skflow.TensorFlowEstimator( + classifier = learn.TensorFlowEstimator( model_fn=res_net, n_classes=10, batch_size=100, steps=100, learning_rate=0.001, continue_training=True) diff --git a/tensorflow/examples/skflow/text_classification.py b/tensorflow/examples/skflow/text_classification.py index 190d4a1464..08ef507f17 100644 --- a/tensorflow/examples/skflow/text_classification.py +++ b/tensorflow/examples/skflow/text_classification.py @@ -20,12 +20,12 @@ from sklearn import metrics import pandas import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Training data # Downloads, unpacks and reads DBpedia dataset. -dbpedia = skflow.datasets.load_dataset('dbpedia') +dbpedia = learn.datasets.load_dataset('dbpedia') X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target) X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target) @@ -33,7 +33,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t MAX_DOCUMENT_LENGTH = 10 -vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) +vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) X_train = np.array(list(vocab_processor.fit_transform(X_train))) X_test = np.array(list(vocab_processor.transform(X_test))) @@ -45,10 +45,10 @@ print('Total words: %d' % n_words) EMBEDDING_SIZE = 50 def average_model(X, y): - word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words, + word_vectors = learn.ops.categorical_variable(X, n_classes=n_words, embedding_size=EMBEDDING_SIZE, name='words') features = tf.reduce_max(word_vectors, reduction_indices=1) - return skflow.models.logistic_regression(features, y) + return learn.models.logistic_regression(features, y) def rnn_model(X, y): """Recurrent neural network model to predict from sequence of words @@ -57,11 +57,11 @@ def rnn_model(X, y): # This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then # maps word indexes of the sequence into [batch_size, sequence_length, # EMBEDDING_SIZE]. - word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words, + word_vectors = learn.ops.categorical_variable(X, n_classes=n_words, embedding_size=EMBEDDING_SIZE, name='words') # Split into list of embedding per word, while removing doc length dim. # word_list results to be a list of tensors [batch_size, EMBEDDING_SIZE]. - word_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors) + word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors) # Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE. cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE) # Create an unrolled Recurrent Neural Networks to length of @@ -70,9 +70,9 @@ def rnn_model(X, y): # Given encoding of RNN, take encoding of last step (e.g hidden size of the # neural network of last step) and pass it as features for logistic # regression over output classes. - return skflow.models.logistic_regression(encoding, y) + return learn.models.logistic_regression(encoding, y) -classifier = skflow.TensorFlowEstimator(model_fn=rnn_model, n_classes=15, +classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15, steps=1000, optimizer='Adam', learning_rate=0.01, continue_training=True) # Continuously train for 1000 steps & predict on test set. diff --git a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py index 9eb7dcfb97..d4824d52c6 100644 --- a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py +++ b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py @@ -20,12 +20,12 @@ from sklearn import metrics import pandas import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Training data # Downloads, unpacks and reads DBpedia dataset. -dbpedia = skflow.datasets.load_dataset('dbpedia') +dbpedia = learn.datasets.load_dataset('dbpedia') X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target) X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target) @@ -33,7 +33,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t MAX_DOCUMENT_LENGTH = 10 -vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) +vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) X_train = np.array(list(vocab_processor.fit_transform(X_train))) X_test = np.array(list(vocab_processor.transform(X_test))) @@ -50,15 +50,15 @@ def input_op_fn(X): # This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then # maps word indexes of the sequence into [batch_size, sequence_length, # EMBEDDING_SIZE]. - word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words, + word_vectors = learn.ops.categorical_variable(X, n_classes=n_words, embedding_size=EMBEDDING_SIZE, name='words') # Split into list of embedding per word, while removing doc length dim. # word_list results to be a list of tensors [batch_size, EMBEDDING_SIZE]. - word_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors) + word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors) return word_list # Single direction GRU with a single layer -classifier = skflow.TensorFlowRNNClassifier(rnn_size=EMBEDDING_SIZE, +classifier = learn.TensorFlowRNNClassifier(rnn_size=EMBEDDING_SIZE, n_classes=15, cell_type='gru', input_op_fn=input_op_fn, num_layers=1, bidirectional=False, sequence_length=None, steps=1000, optimizer='Adam', learning_rate=0.01, continue_training=True) diff --git a/tensorflow/examples/skflow/text_classification_character_cnn.py b/tensorflow/examples/skflow/text_classification_character_cnn.py index 71dec5b4ee..f447ddb3e7 100644 --- a/tensorflow/examples/skflow/text_classification_character_cnn.py +++ b/tensorflow/examples/skflow/text_classification_character_cnn.py @@ -32,12 +32,12 @@ from sklearn import metrics import pandas import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Training data # Downloads, unpacks and reads DBpedia dataset. -dbpedia = skflow.datasets.load_dataset('dbpedia') +dbpedia = learn.datasets.load_dataset('dbpedia') X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target) X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target) @@ -45,7 +45,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t MAX_DOCUMENT_LENGTH = 100 -char_processor = skflow.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH) +char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH) X_train = np.array(list(char_processor.fit_transform(X_train))) X_test = np.array(list(char_processor.transform(X_test))) @@ -59,11 +59,11 @@ POOLING_STRIDE = 2 def char_cnn_model(X, y): """Character level convolutional neural network model to predict classes.""" - byte_list = tf.reshape(skflow.ops.one_hot_matrix(X, 256), + byte_list = tf.reshape(learn.ops.one_hot_matrix(X, 256), [-1, MAX_DOCUMENT_LENGTH, 256, 1]) with tf.variable_scope('CNN_Layer1'): # Apply Convolution filtering on input sequence. - conv1 = skflow.ops.conv2d(byte_list, N_FILTERS, FILTER_SHAPE1, padding='VALID') + conv1 = learn.ops.conv2d(byte_list, N_FILTERS, FILTER_SHAPE1, padding='VALID') # Add a RELU for non linearity. conv1 = tf.nn.relu(conv1) # Max pooling across output of Convlution+Relu. @@ -73,14 +73,14 @@ def char_cnn_model(X, y): pool1 = tf.transpose(pool1, [0, 1, 3, 2]) with tf.variable_scope('CNN_Layer2'): # Second level of convolution filtering. - conv2 = skflow.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2, + conv2 = learn.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2, padding='VALID') # Max across each filter to get useful features for classification. pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1]) # Apply regular WX + B and classification. - return skflow.models.logistic_regression(pool2, y) + return learn.models.logistic_regression(pool2, y) -classifier = skflow.TensorFlowEstimator(model_fn=char_cnn_model, n_classes=15, +classifier = learn.TensorFlowEstimator(model_fn=char_cnn_model, n_classes=15, steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True) # Continuously train for 1000 steps & predict on test set. diff --git a/tensorflow/examples/skflow/text_classification_character_rnn.py b/tensorflow/examples/skflow/text_classification_character_rnn.py index af1e37641b..6281850af4 100644 --- a/tensorflow/examples/skflow/text_classification_character_rnn.py +++ b/tensorflow/examples/skflow/text_classification_character_rnn.py @@ -32,12 +32,12 @@ from sklearn import metrics import pandas import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Training data # Downloads, unpacks and reads DBpedia dataset. -dbpedia = skflow.datasets.load_dataset('dbpedia') +dbpedia = learn.datasets.load_dataset('dbpedia') X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target) X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target) @@ -45,7 +45,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t MAX_DOCUMENT_LENGTH = 100 -char_processor = skflow.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH) +char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH) X_train = np.array(list(char_processor.fit_transform(X_train))) X_test = np.array(list(char_processor.transform(X_test))) @@ -54,13 +54,13 @@ X_test = np.array(list(char_processor.transform(X_test))) HIDDEN_SIZE = 20 def char_rnn_model(X, y): - byte_list = skflow.ops.one_hot_matrix(X, 256) - byte_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, byte_list) + byte_list = learn.ops.one_hot_matrix(X, 256) + byte_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, byte_list) cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE) _, encoding = tf.nn.rnn(cell, byte_list, dtype=tf.float32) - return skflow.models.logistic_regression(encoding, y) + return learn.models.logistic_regression(encoding, y) -classifier = skflow.TensorFlowEstimator(model_fn=char_rnn_model, n_classes=15, +classifier = learn.TensorFlowEstimator(model_fn=char_rnn_model, n_classes=15, steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True) # Continuously train for 1000 steps & predict on test set. diff --git a/tensorflow/examples/skflow/text_classification_cnn.py b/tensorflow/examples/skflow/text_classification_cnn.py index c42a12819e..de14695e0d 100644 --- a/tensorflow/examples/skflow/text_classification_cnn.py +++ b/tensorflow/examples/skflow/text_classification_cnn.py @@ -20,12 +20,12 @@ from sklearn import metrics import pandas import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Training data # Downloads, unpacks and reads DBpedia dataset. -dbpedia = skflow.datasets.load_dataset('dbpedia') +dbpedia = learn.datasets.load_dataset('dbpedia') X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target) X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target) @@ -33,7 +33,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t MAX_DOCUMENT_LENGTH = 100 -vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) +vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) X_train = np.array(list(vocab_processor.fit_transform(X_train))) X_test = np.array(list(vocab_processor.transform(X_test))) @@ -57,12 +57,12 @@ def cnn_model(X, y): # This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then # maps word indexes of the sequence into [batch_size, sequence_length, # EMBEDDING_SIZE]. - word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words, + word_vectors = learn.ops.categorical_variable(X, n_classes=n_words, embedding_size=EMBEDDING_SIZE, name='words') word_vectors = tf.expand_dims(word_vectors, 3) with tf.variable_scope('CNN_Layer1'): # Apply Convolution filtering on input sequence. - conv1 = skflow.ops.conv2d(word_vectors, N_FILTERS, FILTER_SHAPE1, padding='VALID') + conv1 = learn.ops.conv2d(word_vectors, N_FILTERS, FILTER_SHAPE1, padding='VALID') # Add a RELU for non linearity. conv1 = tf.nn.relu(conv1) # Max pooling across output of Convlution+Relu. @@ -72,15 +72,15 @@ def cnn_model(X, y): pool1 = tf.transpose(pool1, [0, 1, 3, 2]) with tf.variable_scope('CNN_Layer2'): # Second level of convolution filtering. - conv2 = skflow.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2, + conv2 = learn.ops.conv2d(pool1, N_FILTERS, FILTER_SHAPE2, padding='VALID') # Max across each filter to get useful features for classification. pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1]) # Apply regular WX + B and classification. - return skflow.models.logistic_regression(pool2, y) + return learn.models.logistic_regression(pool2, y) -classifier = skflow.TensorFlowEstimator(model_fn=cnn_model, n_classes=15, +classifier = learn.TensorFlowEstimator(model_fn=cnn_model, n_classes=15, steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True) # Continuously train for 1000 steps & predict on test set. diff --git a/tensorflow/examples/skflow/text_classification_save_restore.py b/tensorflow/examples/skflow/text_classification_save_restore.py index 71551762b9..3fcde2b308 100644 --- a/tensorflow/examples/skflow/text_classification_save_restore.py +++ b/tensorflow/examples/skflow/text_classification_save_restore.py @@ -21,12 +21,12 @@ from sklearn import metrics import pandas import tensorflow as tf -from tensorflow.contrib import skflow +from tensorflow.contrib import learn ### Training data # Downloads, unpacks and reads DBpedia dataset. -dbpedia = skflow.datasets.load_dataset('dbpedia') +dbpedia = learn.datasets.load_dataset('dbpedia') X_train, y_train = pandas.DataFrame(dbpedia.train.data)[1], pandas.Series(dbpedia.train.target) X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.test.target) @@ -34,7 +34,7 @@ X_test, y_test = pandas.DataFrame(dbpedia.test.data)[1], pandas.Series(dbpedia.t MAX_DOCUMENT_LENGTH = 10 -vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) +vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH) X_train = np.array(list(vocab_processor.fit_transform(X_train))) X_test = np.array(list(vocab_processor.transform(X_test))) @@ -46,10 +46,10 @@ print('Total words: %d' % n_words) EMBEDDING_SIZE = 50 def average_model(X, y): - word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words, + word_vectors = learn.ops.categorical_variable(X, n_classes=n_words, embedding_size=EMBEDDING_SIZE, name='words') features = tf.reduce_max(word_vectors, reduction_indices=1) - return skflow.models.logistic_regression(features, y) + return learn.models.logistic_regression(features, y) def rnn_model(X, y): """Recurrent neural network model to predict from sequence of words @@ -58,11 +58,11 @@ def rnn_model(X, y): # This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then # maps word indexes of the sequence into [batch_size, sequence_length, # EMBEDDING_SIZE]. - word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words, + word_vectors = learn.ops.categorical_variable(X, n_classes=n_words, embedding_size=EMBEDDING_SIZE, name='words') # Split into list of embedding per word, while removing doc length dim. # word_list results to be a list of tensors [batch_size, EMBEDDING_SIZE]. - word_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors) + word_list = learn.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors) # Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE. cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE) # Create an unrolled Recurrent Neural Networks to length of @@ -71,13 +71,13 @@ def rnn_model(X, y): # Given encoding of RNN, take encoding of last step (e.g hidden size of the # neural network of last step) and pass it as features for logistic # regression over output classes. - return skflow.models.logistic_regression(encoding, y) + return learn.models.logistic_regression(encoding, y) model_path = '/tmp/skflow_examples/text_classification' if os.path.exists(model_path): - classifier = skflow.TensorFlowEstimator.restore(model_path) + classifier = learn.TensorFlowEstimator.restore(model_path) else: - classifier = skflow.TensorFlowEstimator(model_fn=rnn_model, n_classes=15, + classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=15, steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True) # Continuously train for 1000 steps diff --git a/tensorflow/examples/udacity/1_notmnist.ipynb b/tensorflow/examples/udacity/1_notmnist.ipynb index 2265445815..32ece419b0 100644 --- a/tensorflow/examples/udacity/1_notmnist.ipynb +++ b/tensorflow/examples/udacity/1_notmnist.ipynb @@ -110,11 +110,31 @@ }, "source": [ "url = 'http://commondatastorage.googleapis.com/books1000/'\n", + "last_percent_reported = None\n", "\n", + "def download_progress_hook(count, blockSize, totalSize):\n", + " \"\"\"A hook to report the progress of a download. This is mostly intended for users with\n", + " slow internet connections. Reports every 1% change in download progress.\n", + " \"\"\"\n", + " global last_percent_reported\n", + " percent = int(count * blockSize * 100 / totalSize)\n", + "\n", + " if last_percent_reported != percent:\n", + " if percent % 5 == 0:\n", + " sys.stdout.write(\"%s%%\" % percent)\n", + " sys.stdout.flush()\n", + " else:\n", + " sys.stdout.write(\".\")\n", + " sys.stdout.flush()\n", + " \n", + " last_percent_reported = percent\n", + " \n", "def maybe_download(filename, expected_bytes, force=False):\n", " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", " if force or not os.path.exists(filename):\n", - " filename, _ = urlretrieve(url + filename, filename)\n", + " print('Attempting to download:', filename) \n", + " filename, _ = urlretrieve(url + filename, filename, reporthook=download_progress_hook)\n", + " print('\\nDownload Complete!')\n", " statinfo = os.stat(filename)\n", " if statinfo.st_size == expected_bytes:\n", " print('Found and verified', filename)\n", diff --git a/tensorflow/examples/udacity/5_word2vec.ipynb b/tensorflow/examples/udacity/5_word2vec.ipynb index 62dbec4e11..f932f62e28 100644 --- a/tensorflow/examples/udacity/5_word2vec.ipynb +++ b/tensorflow/examples/udacity/5_word2vec.ipynb @@ -446,6 +446,11 @@ " train_labels, num_sampled, vocabulary_size))\n", "\n", " # Optimizer.\n", + " # Note: The optimizer will optimize the softmax_weights AND the embeddings.\n", + " # This is because the embeddings are defined as a variable quantity and the\n", + " # optimizer's `minimize` method will by default modify all variable quantities \n", + " # that contribute to the tensor it is passed.\n", + " # See docs on `tf.train.Optimizer.minimize()` for more details.\n", " optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)\n", " \n", " # Compute the similarity between minibatch examples and all embeddings.\n", diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md index 008174f9d6..1aaf39bd50 100644 --- a/tensorflow/g3doc/api_docs/python/constant_op.md +++ b/tensorflow/g3doc/api_docs/python/constant_op.md @@ -60,7 +60,7 @@ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]] * <b>`tensor`</b>: A `Tensor`. * <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`. + `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`. * <b>`name`</b>: A name for the operation (optional). @@ -119,7 +119,7 @@ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]] * <b>`tensor`</b>: A `Tensor`. * <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`. + `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64` or `complex128`. * <b>`name`</b>: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md index 6a139fb6d3..da6262a6f7 100644 --- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md +++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md @@ -417,7 +417,7 @@ Returns the truth value of (x == y) element-wise. ##### Args: -* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `quint8`, `qint8`, `qint32`, `string`, `bool`. +* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `complex128`, `quint8`, `qint8`, `qint32`, `string`. * <b>`y`</b>: A `Tensor`. Must have the same type as `x`. * <b>`name`</b>: A name for the operation (optional). @@ -435,7 +435,7 @@ Returns the truth value of (x != y) element-wise. ##### Args: -* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `quint8`, `qint8`, `qint32`, `string`, `bool`. +* <b>`x`</b>: A `Tensor`. Must be one of the following types: `half`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`, `complex128`, `quint8`, `qint8`, `qint32`, `string`. * <b>`y`</b>: A `Tensor`. Must have the same type as `x`. * <b>`name`</b>: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md index 394e351ad7..94eb6a6717 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md @@ -23,7 +23,7 @@ tf.diag(diagonal) ==> [[1, 0, 0, 0] ##### Args: -* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`. +* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`. Rank k tensor where k is at most 3. * <b>`name`</b>: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md index 182aa5264b..249eb80e50 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md @@ -24,7 +24,7 @@ tf.diag_part(input) ==> [1, 2, 3, 4] ##### Args: -* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`. +* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`. Rank k tensor where k is 2, 4, or 6. * <b>`name`</b>: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md index 5ee3dcead4..7fac1aad70 100644 --- a/tensorflow/g3doc/api_docs/python/math_ops.md +++ b/tensorflow/g3doc/api_docs/python/math_ops.md @@ -931,7 +931,7 @@ tf.diag(diagonal) ==> [[1, 0, 0, 0] ##### Args: -* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`. +* <b>`diagonal`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`. Rank k tensor where k is at most 3. * <b>`name`</b>: A name for the operation (optional). @@ -968,7 +968,7 @@ tf.diag_part(input) ==> [1, 2, 3, 4] ##### Args: -* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`. +* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`. Rank k tensor where k is 2, 4, or 6. * <b>`name`</b>: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md index 4eabd326f0..8a40d364e1 100644 --- a/tensorflow/g3doc/api_docs/python/nn.md +++ b/tensorflow/g3doc/api_docs/python/nn.md @@ -163,7 +163,7 @@ case where both types are quantized. * <b>`value`</b>: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`, - `int16`, `int8`, or `complex64`. + `int16`, `int8`, `complex64` or `complex128`. * <b>`bias`</b>: A 1-D `Tensor` with size matching the last dimension of `value`. Must be the same type as `value` unless `value` is a quantized type, in which case a different quantized type may be used. @@ -186,7 +186,7 @@ Specifically, `y = 1 / (1 + exp(-x))`. ##### Args: -* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`, +* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`, or `qint32`. * <b>`name`</b>: A name for the operation (optional). @@ -205,7 +205,7 @@ Computes hyperbolic tangent of `x` element-wise. ##### Args: -* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`, +* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`, or `qint32`. * <b>`name`</b>: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index bf270ddfd4..c2aa1c01fa 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -948,7 +948,7 @@ saver.restore(...checkpoint filename...) Creates a new ExponentialMovingAverage object. -The `Apply()` method has to be called to create shadow variables and add +The `apply()` method has to be called to create shadow variables and add ops to maintain moving averages. The optional `num_updates` parameter allows one to tweak the decay rate @@ -965,7 +965,7 @@ move faster. If passed, the actual decay rate used is: * <b>`decay`</b>: Float. The decay to use. * <b>`num_updates`</b>: Optional count of number of updates applied to variables. * <b>`name`</b>: String. Optional prefix name to use for the name of ops added in - `Apply()`. + `apply()`. - - - diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 3a9b883bc4..50309a5df9 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -59,11 +59,11 @@ $ sudo easy_install pip Install TensorFlow: ```bash -# Ubuntu/Linux 64-bit, CPU only: +# Ubuntu/Linux 64-bit, CPU only, Python 2.7: $ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For -# other versions, see "Install from sources" below. +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4. +# For other versions, see "Install from sources" below. $ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl # Mac OS X, CPU only: @@ -74,11 +74,11 @@ $ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tenso For python3: ```bash -# Ubuntu/Linux 64-bit, CPU only: +# Ubuntu/Linux 64-bit, CPU only, Python 3.4: $ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For -# other versions, see "Install from sources" below. +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4. +# For other versions, see "Install from sources" below. $ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl # Mac OS X, CPU only: @@ -134,11 +134,11 @@ $ source ~/tensorflow/bin/activate # If using bash $ source ~/tensorflow/bin/activate.csh # If using csh (tensorflow)$ # Your prompt should change -# Ubuntu/Linux 64-bit, CPU only: +# Ubuntu/Linux 64-bit, CPU only, Python 2.7: (tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For -# other versions, see "Install from sources" below. +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4. +# For other versions, see "Install from sources" below. (tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl # Mac OS X, CPU only: @@ -152,11 +152,11 @@ $ source ~/tensorflow/bin/activate # If using bash $ source ~/tensorflow/bin/activate.csh # If using csh (tensorflow)$ # Your prompt should change -# Ubuntu/Linux 64-bit, CPU only: +# Ubuntu/Linux 64-bit, CPU only, Python 3.4: (tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For -# other versions, see "Install from sources" below. +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4. +# For other versions, see "Install from sources" below. (tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl # Mac OS X, CPU only: @@ -225,15 +225,15 @@ Use the `--ignore-installed` flag to prevent errors about `easy_install`. $ source activate tensorflow (tensorflow)$ # Your prompt should change -# Ubuntu/Linux 64-bit, CPU only: -(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp27-none-linux_x86_64.whl +# Ubuntu/Linux 64-bit, CPU only, Python 2.7: +(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For -# other versions, see "Install from sources" below. -(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0rc0-cp27-none-linux_x86_64.whl +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4. +# For other versions, see "Install from sources" below. +(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl # Mac OS X, CPU only: -(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0rc0-py2-none-any.whl +(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0-py2-none-any.whl ``` and again for Python 3: @@ -242,15 +242,15 @@ and again for Python 3: $ source activate tensorflow (tensorflow)$ # Your prompt should change -# Ubuntu/Linux 64-bit, CPU only: -(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp34-cp34m-linux_x86_64.whl +# Ubuntu/Linux 64-bit, CPU only, Python 3.4: +(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled. Requires CUDA toolkit 7.5 and CuDNN v4. For -# other versions, see "Install from sources" below. -(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0rc0-cp34-cp34m-linux_x86_64.whl +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4. +# For other versions, see "Install from sources" below. +(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl # Mac OS X, CPU only: -(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0rc0-py3-none-any.whl +(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0-py3-none-any.whl ``` With the conda environment activated, you can now @@ -309,9 +309,13 @@ After Docker is installed, launch a Docker container with the TensorFlow binary image as follows. ```bash -$ docker run -it gcr.io/tensorflow/tensorflow +$ docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow ``` +The option `-p 8888:8888` is used to publish the Docker container᾿s internal port to the host machine, in this case to ensure Jupyter notebook connection. + +The format of the port mapping `hostPort:containerPort`. You can speficy any valid port number for the host port but has to be `8888` for the container port portion. + If you're using a container with GPU support, some additional flags must be passed to expose the GPU device to the container. For the default config, we include a @@ -437,7 +441,10 @@ binary path. #### Install other dependencies ```bash -$ sudo apt-get install python-numpy swig python-dev +# For Python 2.7: +$ sudo apt-get install python-numpy swig python-dev python-wheel +# For Python 3.x: +$ sudo apt-get install python3-numpy swig python3-dev python3-wheel ``` #### Configure the installation @@ -627,7 +634,7 @@ $ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_pack $ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg # The name of the .whl file will depend on your platform. -$ pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-linux_x86_64.whl +$ sudo pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl ``` ## Setting up TensorFlow for Development diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md index 7b067ebb12..c1516ea487 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/index.md +++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md @@ -815,7 +815,7 @@ expressions: ```c++ REGISTER_OP("BuiltInTypesExample") .Input("integers: int32") - .Input("complex_numbers: scomplex64"); + .Input("complex_numbers: complex64"); ``` * `<attr-type>`, where `<attr-type>` is the name of an [Attr](#attrs) with type diff --git a/tensorflow/g3doc/how_tos/distributed/index.md b/tensorflow/g3doc/how_tos/distributed/index.md index 2adb7b3eb0..88f334cb53 100644 --- a/tensorflow/g3doc/how_tos/distributed/index.md +++ b/tensorflow/g3doc/how_tos/distributed/index.md @@ -21,7 +21,7 @@ $ python ``` The -[`tf.train.Server.create_local_server()`](../../api_docs/train.md#Server.create_local_server) +[`tf.train.Server.create_local_server()`](../../api_docs/python/train.md#Server.create_local_server) method creates a single-process cluster, with an in-process server. ## Create a cluster @@ -110,7 +110,7 @@ which you'd like to see support, please raise a ## Specifying distributed devices in your model To place operations on a particular process, you can use the same -[`tf.device()`](https://www.tensorflow.org/versions/master/api_docs/python/framework.html#device) +[`tf.device()`](../../api_docs/python/framework.md#device) function that is used to specify whether ops run on the CPU or GPU. For example: ```python @@ -158,7 +158,7 @@ simplify the work of specifying a replicated model. Possible approaches include: for each `/job:worker` task, typically in the same process as the worker task. Each client builds a similar graph containing the parameters (pinned to `/job:ps` as before using - [`tf.train.replica_device_setter()`](../../api_docs/train.md#replica_device_setter) + [`tf.train.replica_device_setter()`](../../api_docs/python/train.md#replica_device_setter) to map them deterministically to the same tasks); and a single copy of the compute-intensive part of the model, pinned to the local task in `/job:worker`. @@ -199,7 +199,7 @@ FLAGS = tf.app.flags.FLAGS def main(_): ps_hosts = FLAGS.ps_hosts.split(",") - worker_hosts = FLAGS.worker_hosts(",") + worker_hosts = FLAGS.worker_hosts.split(",") # Create a cluster from the parameter server and worker hosts. cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) diff --git a/tensorflow/g3doc/how_tos/threading_and_queues/index.md b/tensorflow/g3doc/how_tos/threading_and_queues/index.md index 9d1fbddb68..ec17518da0 100644 --- a/tensorflow/g3doc/how_tos/threading_and_queues/index.md +++ b/tensorflow/g3doc/how_tos/threading_and_queues/index.md @@ -167,10 +167,10 @@ try: break sess.run(train_op) except Exception, e: - # Report exceptions to the coordinator. - coord.request_stop(e) - -# Terminate as usual. It is innocuous to request stop twice. -coord.request_stop() -coord.join(threads) + # Report exceptions to the coordinator. + coord.request_stop(e) +finally: + # Terminate as usual. It is innocuous to request stop twice. + coord.request_stop() + coord.join(threads) ``` diff --git a/tensorflow/g3doc/resources/dims_types.md b/tensorflow/g3doc/resources/dims_types.md index 8e55e609a0..76110190d9 100644 --- a/tensorflow/g3doc/resources/dims_types.md +++ b/tensorflow/g3doc/resources/dims_types.md @@ -62,6 +62,7 @@ Data type | Python type | Description `DT_STRING` | `tf.string` | Variable length byte arrays. Each element of a Tensor is a byte array. `DT_BOOL` | `tf.bool` | Boolean. `DT_COMPLEX64` | `tf.complex64` | Complex number made of two 32 bits floating points: real and imaginary parts. +`DT_COMPLEX128` | `tf.complex128` | Complex number made of two 64 bits floating points: real and imaginary parts. `DT_QINT8` | `tf.qint8` | 8 bits signed integer used in quantized Ops. `DT_QINT32` | `tf.qint32` | 32 bits signed integer used in quantized Ops. `DT_QUINT8` | `tf.quint8` | 8 bits unsigned integer used in quantized Ops. diff --git a/tensorflow/g3doc/tutorials/mandelbrot/index.md b/tensorflow/g3doc/tutorials/mandelbrot/index.md index 6b1c070791..8009e32d84 100755 --- a/tensorflow/g3doc/tutorials/mandelbrot/index.md +++ b/tensorflow/g3doc/tutorials/mandelbrot/index.md @@ -20,9 +20,8 @@ import numpy as np # Imports for visualization import PIL.Image -from cStringIO import StringIO -from IPython.display import clear_output, Image, display -import scipy.ndimage as nd +from io import BytesIO +from IPython.display import Image, display ``` Now we'll define a function to actually display the image once we have @@ -39,7 +38,7 @@ def DisplayFractal(a, fmt='jpeg'): img[a==a.max()] = 0 a = img a = np.uint8(np.clip(a, 0, 255)) - f = StringIO() + f = BytesIO() PIL.Image.fromarray(a).save(f, fmt) display(Image(data=f.getvalue())) ``` diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 5bef94e789..ef025917b6 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -39,6 +39,10 @@ import traceback # the mode to RTLD_GLOBAL to make the symbols visible, so libraries such # as the ones implementing custom ops can have access to tensorflow # framework's symbols. +# one catch is that numpy *must* be imported before the call to +# setdlopenflags(), or there is a risk that later c modules will segfault +# when importing numpy (gh-2034). +import numpy as np _default_dlopen_flags = sys.getdlopenflags() sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_GLOBAL) from tensorflow.python import pywrap_tensorflow diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index d15d3562dd..74e91b826a 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -465,6 +465,21 @@ class Conv2DTest(tf.test.TestCase): data_format=data_format, use_gpu=use_gpu) + def testConv2DStrideTwoFilterOneSameBackpropInput(self): + expected_output = [1.0, 0.0, 2.0, 0.0, + 0.0, 0.0, 0.0, 0.0, + 3.0, 0.0, 4.0, 0.0, + 0.0, 0.0, 0.0, 0.0] + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropInput(input_sizes=[1, 4, 4, 1], + filter_sizes=[1, 1, 1, 1], + output_sizes=[1, 2, 2, 1], + strides=[2, 2], + padding="SAME", + expected=expected_output, + data_format=data_format, + use_gpu=use_gpu) + # Testing for backprops def _RunAndVerifyBackpropFilter(self, input_sizes, filter_sizes, output_sizes, strides, padding, expected, data_format, @@ -568,6 +583,18 @@ class Conv2DTest(tf.test.TestCase): data_format=data_format, use_gpu=use_gpu) + def testConv2DStrideTwoFilterOneSameBackpropFilter(self): + expected_output = [78.] + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropFilter(input_sizes=[1, 4, 4, 1], + filter_sizes=[1, 1, 1, 1], + output_sizes=[1, 2, 2, 1], + strides=[2, 2], + padding="SAME", + expected=expected_output, + data_format=data_format, + use_gpu=use_gpu) + # Gradient checkers def ConstructAndTestGradient(self, batch, input_rows, input_cols, filter_rows, filter_cols, in_depth, out_depth, stride_rows, diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 8ce81f1b14..596390bf42 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -64,12 +64,8 @@ class UnaryOpTest(tf.test.TestCase): else: self.assertAllClose(np_ans, tf_cpu) - # TODO(ebrevdo): consider adding polygamma function - if tf_func in (tf.digamma,): - return # Return early - - if x.dtype == np.complex64 and tf_func in ( - tf.sign, tf.sqrt, tf.rsqrt, tf.log): + if (x.dtype in (np.complex64, np.complex128) and + tf_func in (tf.sign, tf.sqrt, tf.rsqrt, tf.log)): return # Return early if x.dtype == np.float16: @@ -89,7 +85,7 @@ class UnaryOpTest(tf.test.TestCase): x_init_value=xf) jacob_n = jacob_n.astype(np.float16) self.assertAllClose(jacob_t, jacob_n, rtol=5e-3, atol=5e-3) - elif x.dtype == np.float32 or x.dtype == np.complex64: + elif x.dtype in (np.float32, np.complex64): s = list(np.shape(x)) jacob_t, jacob_n = tf.test.compute_gradient(inx, s, @@ -97,7 +93,7 @@ class UnaryOpTest(tf.test.TestCase): s, x_init_value=x) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) - elif x.dtype == np.float64: + elif x.dtype in (np.float64, np.complex128): s = list(np.shape(x)) jacob_t, jacob_n = tf.test.compute_gradient(inx, s, @@ -290,6 +286,30 @@ class UnaryOpTest(tf.test.TestCase): return x / np.abs(x) self._compareCpu(y, complex_sign, tf.sign) + def testComplex128Basic(self): + x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype( + np.complex128) + y = x + 0.5 # no zeros + self._compareCpu(x, np.abs, tf.abs) + self._compareCpu(x, np.abs, _ABS) + self._compareCpu(x, np.negative, tf.neg) + self._compareCpu(x, np.negative, _NEG) + self._compareCpu(y, self._inv, tf.inv) + self._compareCpu(x, np.square, tf.square) + self._compareCpu(x, np.sqrt, tf.sqrt) + self._compareCpu(y, self._rsqrt, tf.rsqrt) + self._compareCpu(x, np.exp, tf.exp) + self._compareCpu(y, np.log, tf.log) + self._compareCpu(x, np.tanh, tf.tanh) + self._compareCpu(x, self._sigmoid, tf.sigmoid) + self._compareCpu(x, np.sin, tf.sin) + self._compareCpu(x, np.cos, tf.cos) + + # Numpy uses an incorrect definition of sign; use the right one instead. + def complex_sign(x): + return x / np.abs(x) + self._compareCpu(y, complex_sign, tf.sign) + class BinaryOpTest(tf.test.TestCase): @@ -397,10 +417,10 @@ class BinaryOpTest(tf.test.TestCase): def _compareBoth(self, x, y, np_func, tf_func): self._compareCpu(x, y, np_func, tf_func) if x.dtype in (np.float16, np.float32, np.float64): - if tf_func not in (_FLOORDIV, tf.floordiv, tf.igamma, tf.igammac): + if tf_func not in (_FLOORDIV, tf.floordiv, tf.igamma, tf.igammac, tf.zeta, tf.polygamma): self._compareGradientX(x, y, np_func, tf_func) self._compareGradientY(x, y, np_func, tf_func) - if tf_func in (tf.igamma, tf.igammac): + if tf_func in (tf.igamma, tf.igammac, tf.zeta, tf.polygamma): # These methods only support gradients in the second parameter self._compareGradientY(x, y, np_func, tf_func) self._compareGpu(x, y, np_func, tf_func) @@ -424,6 +444,10 @@ class BinaryOpTest(tf.test.TestCase): x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32) self._compareBoth(a_pos_small, x_pos_small, special.gammainc, tf.igamma) self._compareBoth(a_pos_small, x_pos_small, special.gammaincc, tf.igammac) + # Need x > 1 + self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta, tf.zeta) + n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32) + self._compareBoth(n_small, x_pos_small, special.polygamma, tf.polygamma) except ImportError as e: tf.logging.warn("Cannot test special functions: %s" % str(e)) @@ -520,6 +544,20 @@ class BinaryOpTest(tf.test.TestCase): self._compareCpu(x, y, np.multiply, _MUL) self._compareCpu(x, y + 0.1, np.true_divide, _TRUEDIV) + def testComplex128Basic(self): + x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype( + np.complex128) + y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype( + np.complex128) + self._compareCpu(x, y, np.add, tf.add) + self._compareCpu(x, y, np.subtract, tf.sub) + self._compareCpu(x, y, np.multiply, tf.mul) + self._compareCpu(x, y + 0.1, np.true_divide, tf.truediv) + self._compareCpu(x, y, np.add, _ADD) + self._compareCpu(x, y, np.subtract, _SUB) + self._compareCpu(x, y, np.multiply, _MUL) + self._compareCpu(x, y + 0.1, np.true_divide, _TRUEDIV) + def testStringComparison(self): x = np.array([["abc", "bh"], ["c", ""]]) y = np.array([["abc", "bh"], ["def", "hi"]]) @@ -571,11 +609,13 @@ class BinaryOpTest(tf.test.TestCase): np.float64, np.int32, np.int64, - np.complex64 + np.complex64, + np.complex128, ] for dtype in dtypes: for (np_func, tf_func) in funcs: - if dtype == np.complex64 and tf_func in (_FLOORDIV, tf.floordiv): + if (dtype in (np.complex64, np.complex128) and + tf_func in (_FLOORDIV, tf.floordiv)): continue # floordiv makes no sense for complex numbers self._compareBCast(xs, ys, dtype, np_func, tf_func) self._compareBCast(ys, xs, dtype, np_func, tf_func) @@ -827,18 +867,18 @@ class ComparisonOpTest(tf.test.TestCase): for t in dtypes: for x in data: for y in data: - self.assertEqual(self._compare(tf.less, x, y, t), - x < y) - self.assertEqual(self._compare(tf.less_equal, x, y, t), - x <= y) - self.assertEqual(self._compare(tf.greater, x, y, t), - x > y) - self.assertEqual(self._compare(tf.greater_equal, x, y, t), - x >= y) - self.assertEqual(self._compare(tf.equal, x, y, t), - x == y) - self.assertEqual(self._compare(tf.not_equal, x, y, t), - x != y) + self.assertEqual(self._compare(tf.less, x, y, t), x < y) + self.assertEqual(self._compare(tf.less_equal, x, y, t), x <= y) + self.assertEqual(self._compare(tf.greater, x, y, t), x > y) + self.assertEqual(self._compare(tf.greater_equal, x, y, t), x >= y) + self.assertEqual(self._compare(tf.equal, x, y, t), x == y) + self.assertEqual(self._compare(tf.not_equal, x, y, t), x != y) + data = [-1, 0, 1, -1j, 1j, 1 + 1j, 1 - 1j] + for t in [np.complex64, np.complex128]: + for x in data: + for y in data: + self.assertEqual(self._compare(tf.equal, x, y, t), x == y) + self.assertEqual(self._compare(tf.not_equal, x, y, t), x != y) def _compareCpu(self, x, y, np_func, tf_func): np_ans = np_func(x, y) @@ -872,10 +912,9 @@ class ComparisonOpTest(tf.test.TestCase): self._compareBoth(xt, yt, np.equal, tf.equal) self._compareBoth(xt, yt, np.not_equal, tf.not_equal) # TODO(zhifengc): complex64 doesn't work on GPU yet. - self._compareCpu(x.astype(np.complex64), y.astype(np.complex64), - np.equal, tf.equal) - self._compareCpu(x.astype(np.complex64), y.astype(np.complex64), - np.not_equal, tf.not_equal) + for t in [np.complex64, np.complex128]: + self._compareCpu(x.astype(t), y.astype(t), np.equal, tf.equal) + self._compareCpu(x.astype(t), y.astype(t), np.not_equal, tf.not_equal) def _compareBCast(self, xs, ys, dtype, np_func, tf_func): x = np.linspace(-15, 15, np.prod(xs)).astype(dtype).reshape(xs) @@ -1115,7 +1154,7 @@ class SelectOpTest(tf.test.TestCase): x = np.random.rand(1, 3, 2) * 100 y = np.random.rand(1, 3, 2) * 100 for t in [np.float16, np.float32, np.float64, np.int32, np.int64, - np.complex64]: + np.complex64, np.complex128]: xt = x.astype(t) yt = y.astype(t) self._compare(c, xt, yt, use_gpu=False) @@ -1242,7 +1281,7 @@ class BatchSelectOpTest(tf.test.TestCase): x = np.random.rand(16, 2, 8) * 100 y = np.random.rand(16, 2, 8) * 100 for t in [np.float16, np.float32, np.float64, np.int32, np.int64, - np.complex64]: + np.complex64, np.complex128]: xt = x.astype(t) yt = y.astype(t) self._compare(c, xt, yt, use_gpu=False) @@ -1273,7 +1312,7 @@ class BatchSelectOpTest(tf.test.TestCase): x = np.random.rand(16, 3, 2) * 100 y = np.random.rand(16, 3, 2) * 100 for t in [np.float16, np.float32, np.float64, np.int32, np.int64, - np.complex64]: + np.complex64, np.complex128]: xt = x.astype(t) yt = y.astype(t) with self.assertRaises(ValueError): @@ -1394,6 +1433,7 @@ class MathOpsOverloadTest(tf.test.TestCase): tf.int32, tf.int64, tf.complex64, + tf.complex128, ] funcs = [ (np.add, _ADD), @@ -1405,7 +1445,7 @@ class MathOpsOverloadTest(tf.test.TestCase): ] for dtype in dtypes: for np_func, tf_func in funcs: - if dtype == tf.complex64 and tf_func == _FLOORDIV: + if dtype in (tf.complex64, tf.complex128) and tf_func == _FLOORDIV: continue # floordiv makes no sense for complex self._compareBinary(10, 5, dtype, np_func, tf_func) # Mod only works for int32 and int64. @@ -1537,10 +1577,17 @@ class ComplexMakeRealImagTest(tf.test.TestCase): self.assertShapeEqual(np_real, tf_real) self.assertShapeEqual(np_imag, tf_imag) - def testRealImag(self): + def testRealImag64(self): real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32) imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32) - cplx = real + (1j) * imag + cplx = real + 1j * imag + self._compareRealImag(cplx, use_gpu=False) + self._compareRealImag(cplx, use_gpu=True) + + def testRealImag128(self): + real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64) + imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64) + cplx = real + 1j * imag self._compareRealImag(cplx, use_gpu=False) self._compareRealImag(cplx, use_gpu=True) @@ -1553,10 +1600,17 @@ class ComplexMakeRealImagTest(tf.test.TestCase): self.assertAllEqual(np_ans, tf_ans) self.assertShapeEqual(np_ans, tf_conj) - def testConj(self): + def testConj64(self): real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32) imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32) - cplx = real + (1j) * imag + cplx = real + 1j * imag + self._compareConj(cplx, use_gpu=False) + self._compareConj(cplx, use_gpu=True) + + def testConj128(self): + real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64) + imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64) + cplx = real + 1j * imag self._compareConj(cplx, use_gpu=False) self._compareConj(cplx, use_gpu=True) @@ -1585,8 +1639,12 @@ class ComplexMakeRealImagTest(tf.test.TestCase): self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon) def testGradient(self): + # complex64 data = np.arange(1, 2, 0.10).reshape([5, 2]).astype(np.float32) self._compareGradient(data) + # complex128 + data = np.arange(1, 2, 0.10).reshape([5, 2]).astype(np.float64) + self._compareGradient(data) def _compareMulGradient(self, data): # data is a float matrix of shape [n, 4]. data[:, 0], data[:, 1], diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py index 04be325dfd..20a7760fd2 100644 --- a/tensorflow/python/kernel_tests/diag_op_test.py +++ b/tensorflow/python/kernel_tests/diag_op_test.py @@ -165,6 +165,14 @@ class DiagTest(tf.test.TestCase): self.diagOp(x, np.float32, expected_ans) self.diagOp(x, np.float64, expected_ans) + def testRankOneComplexTensor(self): + x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype = np.complex64) + expected_ans = np.array( + [[1.1 + 1.1j, 0 + 0j, 0 + 0j], + [0 + 0j, 2.2 + 2.2j, 0 + 0j], + [0 + 0j, 0 + 0j, 3.3 + 3.3j]], dtype = np.complex64) + self.diagOp(x, np.complex64, expected_ans) + def testRankTwoIntTensor(self): x = np.array([[1, 2, 3], [4, 5, 6]]) expected_ans = np.array( @@ -189,6 +197,19 @@ class DiagTest(tf.test.TestCase): self.diagOp(x, np.float32, expected_ans) self.diagOp(x, np.float64, expected_ans) + def testRankTwoComplexTensor(self): + x = np.array([[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], + [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], dtype = np.complex64) + expected_ans = np.array( + [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], + [[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], + [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], + [[[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]], + [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]], + [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]], + dtype = np.complex64) + self.diagOp(x, np.complex64, expected_ans) + def testRankThreeFloatTensor(self): x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]]) @@ -204,6 +225,30 @@ class DiagTest(tf.test.TestCase): self.diagOp(x, np.float32, expected_ans) self.diagOp(x, np.float64, expected_ans) + def testRankThreeComplexTensor(self): + x = np.array([[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]], + [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]], + dtype = np.complex64) + expected_ans = np.array( + [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], + [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]], + [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], + [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]], + [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], + [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]], + [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], + [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]]], + [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], + [[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]]], + [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], + [[0 + 0j, 6.6 + 6.6j], [0 + 0j, 0 + 0j]]]], + [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], + [[0 + 0j, 0 + 0j], [7.7 + 7.7j, 0 + 0j]]], + [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], + [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]], + dtype = np.complex64) + self.diagOp(x, np.complex64, expected_ans) + class DiagPartOpTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/gradient_checker.py b/tensorflow/python/kernel_tests/gradient_checker.py index 935446b1b5..5ca529cc73 100644 --- a/tensorflow/python/kernel_tests/gradient_checker.py +++ b/tensorflow/python/kernel_tests/gradient_checker.py @@ -188,7 +188,7 @@ def _compute_gradient(x, """Computes the theoretical and numerical jacobian.""" t = dtypes.as_dtype(x.dtype) allowed_types = [dtypes.float16, dtypes.float32, dtypes.float64, - dtypes.complex64] + dtypes.complex64, dtypes.complex128] assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name t2 = dtypes.as_dtype(y.dtype) assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 3d39af8d85..0e7f2efe61 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -68,10 +68,14 @@ class MatMulTest(tf.test.TestCase): self.assertAllEqual(np_ans.shape, tf_ans.shape) def _randMatrix(self, rows, cols, dtype): - if dtype is np.complex64: - real = self._randMatrix(rows, cols, np.float32) - imag = self._randMatrix(rows, cols, np.float32) - return real + np.complex(0, 1) * imag + if dtype in (np.complex64, np.complex128): + if dtype == np.complex64: + float_dtype = np.float32 + else: + float_dtype = np.float64 + real = self._randMatrix(rows, cols, float_dtype) + imag = self._randMatrix(rows, cols, float_dtype) + return real + 1j * imag else: return np.random.uniform(low=1.0, high=100.0, size=rows * cols).reshape( [rows, cols]).astype(dtype) @@ -106,11 +110,16 @@ class MatMulTest(tf.test.TestCase): y = np.arange(1., 3.).reshape([1, 2]).astype(np.int32) self._testCpuMatmul(x, y) - def testSComplexBasic(self): + def testComplex64Basic(self): x = np.arange(1., 5.).reshape([4, 1]).astype(np.complex64) y = np.arange(1., 3.).reshape([1, 2]).astype(np.complex64) self._testCpuMatmul(x, y) + def testComplex128Basic(self): + x = np.arange(1., 5.).reshape([4, 1]).astype(np.complex128) + y = np.arange(1., 3.).reshape([1, 2]).astype(np.complex128) + self._testCpuMatmul(x, y) + # Tests testing random sized matrices. def testFloatRandom(self): for _ in range(10): @@ -145,13 +154,20 @@ class MatMulTest(tf.test.TestCase): y = self._randMatrix(k, m, np.int32) self._testCpuMatmul(x, y) - def testSComplexRandom(self): + def testComplex64Random(self): for _ in range(10): n, k, m = np.random.randint(1, 100, size=3) x = self._randMatrix(n, k, np.complex64) y = self._randMatrix(k, m, np.complex64) self._testCpuMatmul(x, y) + def testComplex128Random(self): + for _ in range(10): + n, k, m = np.random.randint(1, 100, size=3) + x = self._randMatrix(n, k, np.complex128) + y = self._randMatrix(k, m, np.complex128) + self._testCpuMatmul(x, y) + # Test the cases that transpose the matrices before multiplying. # NOTE(keveman): The cases where only one of the inputs is # transposed are covered by tf.matmul's gradient function. diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py index d3fecf4f44..bdf0f5778f 100644 --- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py @@ -1259,7 +1259,7 @@ class PaddingFIFOQueueTest(tf.test.TestCase): def testDtypes(self): with self.test_session() as sess: dtypes = [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, - tf.int64, tf.bool, tf.complex64] + tf.int64, tf.bool, tf.complex64, tf.complex128] shape = (32, 4, 128) q = tf.PaddingFIFOQueue(32, dtypes, [shape[1:]] * len(dtypes)) diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py index b3b25bf031..cf0bcbd50f 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py @@ -79,6 +79,7 @@ class SparseTensorDenseMatMulTest(tf.test.TestCase): self._testBasic(np.float32) self._testBasic(np.float64) self._testBasic(np.complex64) + self._testBasic(np.complex128) # Tests setting one dimension to be a high value. def _testLarge(self, np_dtype): @@ -102,6 +103,7 @@ class SparseTensorDenseMatMulTest(tf.test.TestCase): self._testLarge(np.float32) self._testLarge(np.float64) self._testLarge(np.complex64) + self._testLarge(np.complex128) # Tests random sized matrices. def testFloatRandom(self): diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index e03c54c953..0389f27622 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -77,6 +77,7 @@ class TensorArrayCPUTest(tf.test.TestCase): self._testTensorArrayWritePack(tf.int32) self._testTensorArrayWritePack(tf.int64) self._testTensorArrayWritePack(tf.complex64) + self._testTensorArrayWritePack(tf.complex128) self._testTensorArrayWritePack(tf.string) def _testTensorArrayWriteConcat(self, tf_dtype): @@ -111,6 +112,7 @@ class TensorArrayCPUTest(tf.test.TestCase): self._testTensorArrayWriteConcat(tf.int32) self._testTensorArrayWriteConcat(tf.int64) self._testTensorArrayWriteConcat(tf.complex64) + self._testTensorArrayWriteConcat(tf.complex128) self._testTensorArrayWriteConcat(tf.string) def testTensorArrayUnpackWrongMajorSizeFails(self): @@ -173,6 +175,7 @@ class TensorArrayCPUTest(tf.test.TestCase): self._testTensorArrayUnpackRead(tf.int32) self._testTensorArrayUnpackRead(tf.int64) self._testTensorArrayUnpackRead(tf.complex64) + self._testTensorArrayUnpackRead(tf.complex128) self._testTensorArrayUnpackRead(tf.string) def _testTensorArraySplitRead(self, tf_dtype): @@ -231,6 +234,7 @@ class TensorArrayCPUTest(tf.test.TestCase): self._testTensorArraySplitRead(tf.int32) self._testTensorArraySplitRead(tf.int64) self._testTensorArraySplitRead(tf.complex64) + self._testTensorArraySplitRead(tf.complex128) self._testTensorArraySplitRead(tf.string) def testTensorGradArrayWriteRead(self): @@ -468,8 +472,9 @@ class TensorArrayCPUTest(tf.test.TestCase): wb1_grad.flow.eval() def testTensorArrayWriteGradientAddMultipleAdds(self): - for dtype in [tf.int32, tf.int64, tf.float32, tf.float64, tf.complex64]: - self._testTensorArrayWriteGradientAddMultipleAdds(dtype) + for dtype in (tf.int32, tf.int64, tf.float32, + tf.float64, tf.complex64, tf.complex128): + self._testTensorArrayWriteGradientAddMultipleAdds(dtype) def testMultiTensorArray(self): with self.test_session(use_gpu=self._use_gpu): @@ -540,7 +545,8 @@ class TensorArrayCPUTest(tf.test.TestCase): self.assertAllEqual(c(-2.0), grad_vals[1]) def testTensorArrayGradientWriteRead(self): - for dtype in (np.float32, np.float64, np.int32, np.int64, np.complex64): + for dtype in (np.float32, np.float64, np.int32, + np.int64, np.complex64, np.complex128): self._testTensorArrayGradientWriteReadType(dtype) def testTensorArrayGradientWritePackConcatAndRead(self): diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index b2cad3ee05..6ea8ae3280 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -130,10 +130,9 @@ class XentTest(tf.test.TestCase): np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32)) def testDouble(self): - self._testXent( + self._testAll( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64), - np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64), - use_gpu=False) + np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64)) def testGradient(self): with self.test_session(): diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ff340d7f87..9ac80960b2 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -516,7 +516,7 @@ def boolean_mask(tensor, mask, name="boolean_mask"): ```python # 2-D example - a = [[1, 2], [3, 4], [5, 6]] + tensor = [[1, 2], [3, 4], [5, 6]] mask = [True, False, True] boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]] ``` diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index ab71e28cab..009ed8653a 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -287,8 +287,11 @@ def _LgammaGrad(op, grad): @ops.RegisterGradient("Digamma") -def _DigammaGrad(op, grad): # pylint: disable=unused-argument - raise NotImplementedError("grad(Digamma) == Polygamma(1) is not implemented") +def _DigammaGrad(op, grad): + """Compute gradient of the digamma function with respect to its argument.""" + x = op.inputs[0] + with ops.control_dependencies([grad.op]): + return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x) @ops.RegisterGradient("Igamma") @@ -314,6 +317,40 @@ def _IgammacGrad(op, grad): return [-1 * g if g is not None else None for g in _IgammaGrad(op, grad)] +@ops.RegisterGradient("Zeta") +def _ZetaGrad(op, grad): + """Returns gradient of zeta(x, q) with respect to x and q.""" + # TODO(tillahoffmann): Add derivative with respect to x + x = op.inputs[0] + q = op.inputs[1] + # Broadcast gradients + sx = array_ops.shape(x) + sq = array_ops.shape(q) + unused_rx, rq = gen_array_ops._broadcast_gradient_args(sx, sq) + # Evaluate gradient + with ops.control_dependencies([grad.op]): + partial_q = -x * math_ops.zeta(x + 1, q) + return (None, + array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq)) + + +@ops.RegisterGradient("Polygamma") +def _PolygammaGrad(op, grad): + """Returns gradient of psi(n, x) with respect to n and x.""" + # TODO(tillahoffmann): Add derivative with respect to n + n = op.inputs[0] + x = op.inputs[1] + # Broadcast gradients + sn = array_ops.shape(n) + sx = array_ops.shape(x) + unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx) + # Evaluate gradient + with ops.control_dependencies([grad.op]): + partial_x = math_ops.polygamma(n + 1, x) + return (None, + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) + + @ops.RegisterGradient("Sigmoid") def _SigmoidGrad(op, grad): """Returns grad * sigmoid(x) * (1 - sigmoid(x)).""" diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py index 39b7d1193f..04ee61d640 100644 --- a/tensorflow/python/ops/math_grad_test.py +++ b/tensorflow/python/ops/math_grad_test.py @@ -65,7 +65,7 @@ class AbsOpTest(tf.test.TestCase): def _testGrad(self, shape, dtype=None, max_error=None, bias=None, sigma=None): np.random.seed(7) - if dtype == tf.complex64: + if dtype in (tf.complex64, tf.complex128): value = tf.complex(self._biasedRandN(shape, bias=bias, sigma=sigma), self._biasedRandN(shape, bias=bias, sigma=sigma)) else: @@ -74,7 +74,7 @@ class AbsOpTest(tf.test.TestCase): for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu): - if dtype == tf.complex64: + if dtype in (tf.complex64, tf.complex128): output = tf.complex_abs(value) else: output = tf.abs(value) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index c622d83490..3075c8b9ef 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -58,6 +58,8 @@ mathematical functions to your graph. @@squared_difference @@igamma @@igammac +@@zeta +@@polygamma ## Matrix Math Functions @@ -240,11 +242,36 @@ def abs(x, name=None): """ with ops.op_scope([x], name, "Abs") as name: x = ops.convert_to_tensor(x, name="x") - if x.dtype == dtypes.complex64: - return gen_math_ops.complex_abs(x, name=name) + if x.dtype in (dtypes.complex64, dtypes.complex128): + return gen_math_ops.complex_abs(x, Tout=x.dtype.real_dtype, name=name) return gen_math_ops._abs(x, name=name) +def complex_abs(x, name=None): + r"""Computes the complex absolute value of a tensor. + + Given a tensor `x` of complex numbers, this operation returns a tensor of type + `float` or `double` that is the absolute value of each element in `x`. All + elements in `x` must be complex numbers of the form \\(a + bj\\). The + absolute value is computed as \\( \sqrt{a^2 + b^2}\\). + + For example: + + ``` + # tensor 'x' is [[-2.25 + 4.75j], [-3.25 + 5.75j]] + tf.complex_abs(x) ==> [5.25594902, 6.60492229] + ``` + + Args: + x: A `Tensor` of type `complex64` or `complex128`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `float32` or `float64`. + """ + return gen_math_ops.complex_abs(x, Tout=x.dtype.real_dtype, name=name) + + def scalar_mul(scalar, x): """Multiplies a scalar times a `Tensor` or `IndexedSlices` object. @@ -302,29 +329,94 @@ def complex(real, imag, name=None): Given a tensor `real` representing the real part of a complex number, and a tensor `imag` representing the imaginary part of a complex number, this - operation computes complex numbers elementwise of the form \\\\(a + bj\\\\), - where *a* represents the `real` part and *b* represents the `imag` part. + operation returns complex numbers elementwise of the form \\(a + bj\\), where + *a* represents the `real` part and *b* represents the `imag` part. - The input tensors `real` and `imag` must be the same shape. + The input tensors `real` and `imag` must have the same shape. For example: ``` # tensor 'real' is [2.25, 3.25] # tensor `imag` is [4.75, 5.75] - tf.complex(real, imag) ==> [[2.25 + 4.74j], [3.25 + 5.75j]] + tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] ``` Args: - real: A `Tensor` of type `float`. - imag: A `Tensor` of type `float`. + real: A `Tensor`. Must be one of the following types: `float32`, `float64`. + imag: A `Tensor`. Must have the same type as `real`. name: A name for the operation (optional). Returns: - A `Tensor` of type `complex64`. + A `Tensor` of type `complex64` or `complex128`. """ + real = ops.convert_to_tensor(real, name="real") + imag = ops.convert_to_tensor(imag, name="imag") with ops.op_scope([real, imag], name, "Complex") as name: - return gen_math_ops._complex(real, imag, name=name) + input_types = (real.dtype, imag.dtype) + if input_types == (dtypes.float64, dtypes.float64): + Tout = dtypes.complex128 + elif input_types == (dtypes.float32, dtypes.float32): + Tout = dtypes.complex64 + else: + raise TypeError("Types of real and imag don't match: " + "{} {}".format(real.dtype.name, imag.dtype.name)) + return gen_math_ops._complex(real, imag, Tout=Tout, name=name) + + +def real(input, name=None): + """Returns the real part of a complex number. + + Given a tensor `input` of complex numbers, this operation returns a tensor of + type `float` or `double` that is the real part of each element in `input`. + All elements in `input` must be complex numbers of the form \\(a + bj\\), + where *a* is the real part returned by this operation and *b* is the + imaginary part. + + For example: + + ``` + # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + tf.real(input) ==> [-2.25, 3.25] + ``` + + Args: + input: A `Tensor`. Must be one of the following types: `complex64`, + `complex128`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `float` or `double`. + """ + with ops.op_scope([input], name, "Real") as name: + return gen_math_ops.real(input, Tout=input.dtype.real_dtype, name=name) + + +def imag(input, name=None): + """Returns the imaginary part of a complex number. + + Given a tensor `input` of complex numbers, this operation returns a tensor of + type `float` or `double` that is the imaginary part of each element in + `input`. All elements in `input` must be complex numbers of the form \\(a + + bj\\), where *a* is the real part and *b* is the imaginary part returned by + this operation. + + For example: + + ``` + # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + tf.imag(input) ==> [4.75, 5.75] + ``` + + Args: + input: A `Tensor`. Must be one of the following types: `complex64`, `complex128`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `float` or `double`. + """ + with ops.op_scope([input], name, "Imag") as name: + return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name) def round(x, name=None): @@ -559,6 +651,7 @@ _TRUEDIV_TABLE = { dtypes.float32: None, dtypes.float64: None, dtypes.complex64: None, + dtypes.complex128: None, } @@ -1379,6 +1472,8 @@ ops.RegisterShape("BatchIFFT3D")(common_shapes.unchanged_shape) @ops.RegisterShape("GreaterEqual") @ops.RegisterShape("Igamma") @ops.RegisterShape("Igammac") +@ops.RegisterShape("Zeta") +@ops.RegisterShape("Polygamma") @ops.RegisterShape("Less") @ops.RegisterShape("LessEqual") @ops.RegisterShape("LogicalAnd") diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 6555096eae..662405d924 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -68,15 +68,15 @@ def atrous_conv2d(value, filters, rate, padding, name=None): the amount of computation. For a description of atrous convolution and how it can be used for dense - feature extraction, please see: (Semantic Image Segmentation with Deep - Convolutional Nets and Fully Connected CRFs)[http://arxiv.org/abs/1412.7062]. - The same operation is investigated further in (Multi-Scale Context Aggregation - by Dilated Convolutions)[http://arxiv.org/abs/1511.07122]. Previous works + feature extraction, please see: [Semantic Image Segmentation with Deep + Convolutional Nets and Fully Connected CRFs](http://arxiv.org/abs/1412.7062). + The same operation is investigated further in [Multi-Scale Context Aggregation + by Dilated Convolutions](http://arxiv.org/abs/1511.07122). Previous works that effectively use atrous convolution in different ways are, among others, - (OverFeat: Integrated Recognition, Localization and Detection using - Convolutional Networks) [http://arxiv.org/abs/1312.6229] and (Fast Image - Scanning with Deep Max-Pooling Convolutional Neural Networks) - [http://arxiv.org/abs/1302.1700]. Atrous convolution is also closely related + [OverFeat: Integrated Recognition, Localization and Detection using + Convolutional Networks](http://arxiv.org/abs/1312.6229) and [Fast Image + Scanning with Deep Max-Pooling Convolutional Neural Networks] + (http://arxiv.org/abs/1302.1700). Atrous convolution is also closely related to the so-called noble identities in multi-rate signal processing. There are many different ways to implement atrous convolution (see the refs @@ -227,8 +227,8 @@ def conv2d_transpose(value, name=None): """The transpose of `conv2d`. - This operation is sometimes called "deconvolution" after (Deconvolutional - Networks)[http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf], but is + This operation is sometimes called "deconvolution" after [Deconvolutional + Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is actually the transpose (gradient) of `conv2d` rather than an actual deconvolution. diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py index c6fbea1933..bfd0758883 100644 --- a/tensorflow/python/ops/rnn_cell.py +++ b/tensorflow/python/ops/rnn_cell.py @@ -120,10 +120,11 @@ class RNNCell(object): class BasicRNNCell(RNNCell): """The most basic RNN cell.""" - def __init__(self, num_units, input_size=None): + def __init__(self, num_units, input_size=None, activation=tanh): if input_size is not None: logging.warn("%s: The input_size parameter is deprecated." % self) self._num_units = num_units + self._activation = activation @property def state_size(self): @@ -134,19 +135,20 @@ class BasicRNNCell(RNNCell): return self._num_units def __call__(self, inputs, state, scope=None): - """Most basic RNN: output = new_state = tanh(W * input + U * state + B).""" + """Most basic RNN: output = new_state = activation(W * input + U * state + B).""" with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell" - output = tanh(_linear([inputs, state], self._num_units, True)) + output = self._activation(_linear([inputs, state], self._num_units, True)) return output, output class GRUCell(RNNCell): """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" - def __init__(self, num_units, input_size=None): + def __init__(self, num_units, input_size=None, activation=tanh): if input_size is not None: logging.warn("%s: The input_size parameter is deprecated." % self) self._num_units = num_units + self._activation = activation @property def state_size(self): @@ -165,7 +167,8 @@ class GRUCell(RNNCell): 2 * self._num_units, True, 1.0)) r, u = sigmoid(r), sigmoid(u) with vs.variable_scope("Candidate"): - c = tanh(_linear([inputs, r * state], self._num_units, True)) + c = self._activation(_linear([inputs, r * state], + self._num_units, True)) new_h = u * state + (1 - u) * c return new_h, new_h @@ -185,7 +188,7 @@ class BasicLSTMCell(RNNCell): """ def __init__(self, num_units, forget_bias=1.0, input_size=None, - state_is_tuple=False): + state_is_tuple=False, activation=tanh): """Initialize the basic LSTM cell. Args: @@ -195,6 +198,7 @@ class BasicLSTMCell(RNNCell): state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. By default (False), they are concatenated along the column axis. This default behavior will soon be deprecated. + activation: Activation function of the inner states. """ if not state_is_tuple: logging.warn( @@ -205,6 +209,7 @@ class BasicLSTMCell(RNNCell): self._num_units = num_units self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple + self._activation = activation @property def state_size(self): @@ -228,8 +233,9 @@ class BasicLSTMCell(RNNCell): # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split(1, 4, concat) - new_c = c * sigmoid(f + self._forget_bias) + sigmoid(i) * tanh(j) - new_h = tanh(new_c) * sigmoid(o) + new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * + self._activation(j)) + new_h = self._activation(new_c) * sigmoid(o) if self._state_is_tuple: new_state = (new_c, new_h) @@ -300,7 +306,8 @@ class LSTMCell(RNNCell): use_peepholes=False, cell_clip=None, initializer=None, num_proj=None, num_unit_shards=1, num_proj_shards=1, - forget_bias=1.0, state_is_tuple=False): + forget_bias=1.0, state_is_tuple=False, + activation=tanh): """Initialize the parameters for an LSTM cell. Args: @@ -323,6 +330,7 @@ class LSTMCell(RNNCell): state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. By default (False), they are concatenated along the column axis. This default behavior will soon be deprecated. + activation: Activation function of the inner states. """ if not state_is_tuple: logging.warn( @@ -339,6 +347,7 @@ class LSTMCell(RNNCell): self._num_proj_shards = num_proj_shards self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple + self._activation = activation if num_proj: self._state_size = ( @@ -420,9 +429,10 @@ class LSTMCell(RNNCell): if self._use_peepholes: c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + - sigmoid(i + w_i_diag * c_prev) * tanh(j)) + sigmoid(i + w_i_diag * c_prev) * self._activation(j)) else: - c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) + c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * + self._activation(j)) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type @@ -430,9 +440,9 @@ class LSTMCell(RNNCell): # pylint: enable=invalid-unary-operand-type if self._use_peepholes: - m = sigmoid(o + w_o_diag * c) * tanh(c) + m = sigmoid(o + w_o_diag * c) * self._activation(c) else: - m = sigmoid(o) * tanh(c) + m = sigmoid(o) * self._activation(c) if self._num_proj is not None: concat_w_proj = _get_concat_variable( diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 1e75bac23c..a3f4d536d8 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -210,7 +210,7 @@ class ExponentialMovingAverage(object): def __init__(self, decay, num_updates=None, name="ExponentialMovingAverage"): """Creates a new ExponentialMovingAverage object. - The `Apply()` method has to be called to create shadow variables and add + The `apply()` method has to be called to create shadow variables and add ops to maintain moving averages. The optional `num_updates` parameter allows one to tweak the decay rate @@ -225,7 +225,7 @@ class ExponentialMovingAverage(object): decay: Float. The decay to use. num_updates: Optional count of number of updates applied to variables. name: String. Optional prefix name to use for the name of ops added in - `Apply()`. + `apply()`. """ self._decay = decay self._num_updates = num_updates diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md index fba6c3144a..be54f97b9e 100644 --- a/tensorflow/tools/docker/README.md +++ b/tensorflow/tools/docker/README.md @@ -40,10 +40,14 @@ accomplished via $ export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | xargs -I{} echo '-v {}:{}') $ export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') - $ docker run -it -p 8888:8888 $CUDA_SO $DEVICES gcr.io/tensorflow/tensorflow-devel-gpu + $ docker run -it -p 8888:8888 $CUDA_SO $DEVICES gcr.io/tensorflow/tensorflow-gpu Alternately, you can use the `docker_run_gpu.sh` script in this directory. +In order to set Jupyter Notebook to require a password, the `-e PASSWORD=pass` option must be provided + + $ docker run -it -p 8888:8888 $CUDA_SO $DEVICES -e PASSWORD=pass gcr.io/tensorflow/tensorflow-gpu + ## Rebuilding the containers Just pick the dockerfile corresponding to the container you want to build, and run; diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py index 9bf6f98ce3..bcc4234459 100644 --- a/tensorflow/tools/docker/jupyter_notebook_config.py +++ b/tensorflow/tools/docker/jupyter_notebook_config.py @@ -12,8 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import os +from IPython.lib import passwd c.NotebookApp.ip = '*' c.NotebookApp.port = 8888 c.NotebookApp.open_browser = False c.MultiKernelManager.default_kernel_name = 'python2' + +# sets a password if PASSWORD is set in the environment +if 'PASSWORD' in os.environ: + c.NotebookApp.password = passwd(os.environ['PASSWORD']) + del os.environ['PASSWORD'] diff --git a/third_party/gpus/cuda/platform.bzl b/third_party/gpus/cuda/platform.bzl index 20ab441bf4..06f3d0cff4 100644 --- a/third_party/gpus/cuda/platform.bzl +++ b/third_party/gpus/cuda/platform.bzl @@ -1,7 +1,5 @@ CUDA_VERSION = "" - CUDNN_VERSION = "" - PLATFORM = "" def cuda_sdk_version(): |