author    Martin Wicke <wicke@google.com>    2016-05-30 13:44:02 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-05-30 15:04:04 -0700
commit    d8c888d23f27a6072fe7e87da95ce8ea51bef3a0 (patch)
tree      0168bcd18ea785ee1ee57914c5c9d96d4c359d4c
parent    ab8b4b572daa87bf0c9d53ad34cd71d10dd49460 (diff)
API restructure to unify pieces and support a wider variety of models (such as multi-head, unsupervised, etc.).

Done:
- Estimator now uses (features, targets, mode) -> (predictions, loss, train_op).
- Moved optimize_loss into TensorFlowEstimator (it's now a "simple" Estimator).
- Removed the classification flag from Estimator.

Todo:
- Integrate BaseEstimator and Estimator.
- Add Classifier and Regressor subclasses.
- Make sure TensorFlowEstimator supports multi-feature models.
- Integrate TensorFlowDNN/Linear with the DNN/Linear Classifiers and Regressors.

Closes #2551.
Change: 123595769
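As a reference for the restructured contract, here is a minimal model_fn sketch under the new API (it mirrors linear_model_fn from estimator_test.py in this change; the function name is illustrative). A model function now receives a mode and builds its own train_op:

    import tensorflow as tf

    def my_model_fn(features, targets, mode):
        # Build predictions and loss (zero-initialized linear regression).
        predictions, loss = tf.contrib.learn.models.linear_regression_zero_init(
            features, targets)
        # The training op now lives inside the model function.
        train_op = tf.contrib.layers.optimize_loss(
            loss, tf.contrib.framework.get_global_step(),
            optimizer='Adagrad', learning_rate=0.1)
        return predictions, loss, train_op

    est = tf.contrib.learn.Estimator(model_fn=my_model_fn)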
-rw-r--r--  tensorflow/contrib/learn/BUILD                                                  |  12
-rw-r--r--  tensorflow/contrib/learn/__init__.py                                            |   4
-rw-r--r--  tensorflow/contrib/learn/python/learn/README.md                                 |   2
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/__init__.py                   |   2
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/_sklearn.py                   |   7
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/base.py                       | 102
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/classifier.py                 |  82
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/classifier_test.py            |  74
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn.py                        |  21
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py        |  18
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py                  | 160
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator_test.py             | 113
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/linear.py                     |  27
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/rnn.py                        |   8
-rw-r--r--  tensorflow/contrib/learn/python/learn/graph_actions.py                         |  16
-rw-r--r--  tensorflow/contrib/learn/python/learn/io/__init__.py                           |   4
-rw-r--r--  tensorflow/contrib/learn/python/learn/io/data_feeder.py                        |   6
-rw-r--r--  tensorflow/contrib/learn/python/learn/models.py                                |  10
-rw-r--r--  tensorflow/contrib/learn/python/learn/ops/conv_ops.py                          |   2
-rw-r--r--  tensorflow/contrib/learn/python/learn/preprocessing/tests/test_categorical.py |   2
-rw-r--r--  tensorflow/contrib/learn/python/learn/tests/test_io.py                         |   7
-rw-r--r--  tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py                  |  81
-rw-r--r--  tensorflow/examples/skflow/iris.py                                              |   5
-rw-r--r--  tensorflow/examples/skflow/iris_custom_decay_dnn.py                            |  10
24 files changed, 496 insertions(+), 279 deletions(-)
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 27e57fb0ec..5b57df9a6a 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -221,6 +221,18 @@ py_test(
)
py_test(
+ name = "classifier_test",
+ size = "small",
+ srcs = ["python/learn/estimators/classifier_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
name = "dnn_linear_combined_test",
size = "medium",
srcs = ["python/learn/estimators/dnn_linear_combined_test.py"],
diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py
index 2d178cf186..fec3183529 100644
--- a/tensorflow/contrib/learn/__init__.py
+++ b/tensorflow/contrib/learn/__init__.py
@@ -25,9 +25,13 @@ Train and evaluate TensorFlow models.
@@Estimator
@@ModeKeys
@@TensorFlowClassifier
+@@DNNClassifier
+@@DNNRegressor
@@TensorFlowDNNClassifier
@@TensorFlowDNNRegressor
@@TensorFlowEstimator
+@@LinearClassifier
+@@LinearRegressor
@@TensorFlowLinearClassifier
@@TensorFlowLinearRegressor
@@TensorFlowRNNClassifier
diff --git a/tensorflow/contrib/learn/python/learn/README.md b/tensorflow/contrib/learn/python/learn/README.md
index 2ab165f284..bde7bade5c 100644
--- a/tensorflow/contrib/learn/python/learn/README.md
+++ b/tensorflow/contrib/learn/python/learn/README.md
@@ -88,7 +88,7 @@ Example of 3 layer network with 10, 20 and 10 hidden units respectively:
from sklearn import datasets, metrics
iris = datasets.load_iris()
-classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
+classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
classifier.fit(iris.data, iris.target)
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
print("Accuracy: %f" % score)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index ef0798a4be..657d2276a5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -20,9 +20,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
+from tensorflow.contrib.learn.python.learn.estimators.classifier import Classifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier
diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
index 2a99f40d84..c08bada2d0 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
@@ -129,7 +129,7 @@ class _TransformerMixin():
"""Mixin class for all transformer estimators."""
-class _NotFittedError(ValueError, AttributeError):
+class NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.
This class inherits from both ValueError and AttributeError to help with
@@ -175,7 +175,7 @@ def _train_test_split(*args, **options):
train_size = 0.75
elif train_size is None:
train_size = 1 - test_size
- train_size *= args[0].shape[0]
+ train_size = int(train_size * args[0].shape[0])
np.random.seed(random_state)
indices = np.random.permutation(args[0].shape[0])
@@ -199,14 +199,13 @@ if TRY_IMPORT_SKLEARN:
try:
from sklearn.utils.validation import NotFittedError
except ImportError:
- NotFittedError = _NotFittedError
+ pass
else:
# Naive implementations of sklearn classes and functions.
BaseEstimator = _BaseEstimator
ClassifierMixin = _ClassifierMixin
RegressorMixin = _RegressorMixin
TransformerMixin = _TransformerMixin
- NotFittedError = _NotFittedError
accuracy_score = _accuracy_score
log_loss = None
mean_squared_error = _mean_squared_error
diff --git a/tensorflow/contrib/learn/python/learn/estimators/base.py b/tensorflow/contrib/learn/python/learn/estimators/base.py
index 952fa9ebfb..8e2bf70c90 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/base.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/base.py
@@ -22,13 +22,17 @@ from __future__ import print_function
import json
import os
+import types
+
+import six
from six import string_types
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.io.data_feeder import setup_train_data_feeder
-from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
@@ -54,28 +58,6 @@ def _copy_dir(dir_in, dir_out):
gfile.Copy(name_in, name_out, overwrite=True)
-def _new_tf_model_fn(model_fn, class_weight):
- """Backward compatibility way of adding class weight and IS_TRAINING.
-
- TODO(ipolosukhin): Remove this function after new layers are available.
- Specifically:
- * dropout and batch norm should work via update ops.
- * class weights should be retrieved from weights column or hparams.
-
- Args:
- model_fn: Core model function.
- class_weight: Class weight.
- Returns:
- Model function.
- """
- def _model_fn(features, targets, mode):
- ops.get_default_graph().add_to_collection('IS_TRAINING', mode == 'train')
- if class_weight is not None:
- constant_op.constant(class_weight, name='class_weight')
- return model_fn(features, targets)
- return _model_fn
-
-
class TensorFlowEstimator(estimator.Estimator):
"""Base class for all TensorFlow estimators.
@@ -122,12 +104,17 @@ class TensorFlowEstimator(estimator.Estimator):
continue_training=False,
config=None,
verbose=1):
+ self.class_weight = class_weight
+ self.learning_rate = learning_rate
+ self.clip_gradients = clip_gradients
+ if isinstance(optimizer, six.string_types):
+ if optimizer not in layers.OPTIMIZER_CLS_NAMES:
+ raise ValueError(
+ 'Optimizer name should be one of [%s], you provided %s.' %
+ (', '.join(layers.OPTIMIZER_CLS_NAMES), optimizer))
+ self.optimizer = optimizer
super(TensorFlowEstimator, self).__init__(
- model_fn=_new_tf_model_fn(model_fn, class_weight),
- classification=n_classes > 1,
- learning_rate=learning_rate,
- optimizer=optimizer,
- clip_gradients=clip_gradients,
+ model_fn=self._get_model_fn(model_fn),
config=config)
self.n_classes = n_classes
self.batch_size = batch_size
@@ -275,27 +262,6 @@ class TensorFlowEstimator(estimator.Estimator):
"""
return self._graph.get_tensor_by_name(name)
- def get_tensor_value(self, name):
- """Returns value of the tensor give by name.
-
- Args:
- name: string, name of the tensor.
-
- Returns:
- Numpy array - value of the tensor.
- """
- if name.endswith(':0'):
- name = name[:-2]
- return checkpoints.load_variable(self.model_dir, name)
-
- def get_variable_names(self):
- """Returns list of all variable names in this model.
-
- Returns:
- List of names.
- """
- return [name for name, _ in checkpoints.list_variables(self.model_dir)]
-
def save(self, path):
"""Saves checkpoints and graph to given path.
@@ -383,11 +349,47 @@ class TensorFlowEstimator(estimator.Estimator):
result._restore(path)
return result
+ def _get_model_fn(self, model_fn):
+ """Backward compatibility way of adding class weight and IS_TRAINING.
+
+ TODO(ipolosukhin): Remove this function after new layers are available.
+ Specifically:
+ * dropout and batch norm should work via update ops.
+ * class weights should be retrieved from weights column or hparams.
+
+ Args:
+ model_fn: Core model function.
+ Returns:
+ Model function.
+ """
+ def _model_fn(features, targets, mode):
+ """Backward-compatible model_fn."""
+ ops.get_default_graph().add_to_collection('IS_TRAINING', mode == 'train')
+ if self.class_weight is not None:
+ constant_op.constant(self.class_weight, name='class_weight')
+ predictions, loss = model_fn(features, targets)
+ if isinstance(self.learning_rate, types.FunctionType):
+ learning_rate = self.learning_rate(contrib_framework.get_global_step())
+ else:
+ learning_rate = self.learning_rate
+ if isinstance(self.optimizer, types.FunctionType):
+ optimizer = self.optimizer(learning_rate)
+ else:
+ optimizer = self.optimizer
+ train_op = layers.optimize_loss(
+ loss,
+ contrib_framework.get_global_step(),
+ learning_rate=learning_rate,
+ optimizer=optimizer,
+ clip_gradients=self.clip_gradients)
+ return predictions, loss, train_op
+ return _model_fn
+
class TensorFlowBaseTransformer(TensorFlowEstimator, _sklearn.TransformerMixin):
"""TensorFlow Base Transformer class."""
- def transform(self, X):
+ def transform(self, X): # pylint: disable=invalid-name
"""Transform X using trained transformer."""
return(super(TensorFlowBaseTransformer, self).predict(
X, axis=1, batch_size=None))
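To illustrate the backward-compatibility path above: TensorFlowEstimator now wraps a legacy `(features, targets) -> (predictions, loss)` model function and builds the train_op itself, resolving callable learning rates and optimizers along the way. A minimal sketch (function names are illustrative; the exp_decay pattern mirrors examples/skflow/iris_custom_decay_dnn.py later in this change):

    import tensorflow as tf

    def legacy_model_fn(features, targets):
        # Old-style model function: no mode argument, no train_op.
        return tf.contrib.learn.models.linear_regression_zero_init(
            features, targets)

    def exp_decay(global_step):
        # A callable learning rate; _get_model_fn calls it with the global step.
        return tf.train.exponential_decay(
            learning_rate=0.1, global_step=global_step,
            decay_steps=100, decay_rate=0.001)

    est = tf.contrib.learn.TensorFlowEstimator(
        model_fn=legacy_model_fn, n_classes=0, learning_rate=exp_decay)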
diff --git a/tensorflow/contrib/learn/python/learn/estimators/classifier.py b/tensorflow/contrib/learn/python/learn/estimators/classifier.py
new file mode 100644
index 0000000000..b68e8d2f0d
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/classifier.py
@@ -0,0 +1,82 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Classifier class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import metrics as metrics_lib
+from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+
+
+def _get_classifier_metrics(unused_n_classes):
+ return {
+ ('accuracy', 'classes'): metrics_lib.streaming_accuracy
+ }
+
+
+class Classifier(estimator.Estimator):
+ """Classifier single output Estimator.
+
+ Given logits generating function, provides class / probabilities heads and
+ functions to work with them.
+ """
+
+ CLASS_OUTPUT = 'classes'
+ PROBABILITY_OUTPUT = 'probabilities'
+
+ def __init__(self, model_fn, n_classes, model_dir=None, config=None):
+ """Constructor for Classifier.
+
+ Args:
+ model_fn: (features, targets, mode) -> (logits, loss, train_op)
+ n_classes: Number of classes
+ model_dir: Base directory for output data
+ config: Configuration object (optional)
+ """
+ self._n_classes = n_classes
+ self._logits_fn = model_fn
+ super(Classifier, self).__init__(model_fn=self._classifier_model,
+ model_dir=model_dir, config=config)
+
+ def evaluate(self, x=None, y=None, input_fn=None, batch_size=None,
+ steps=None, metrics=None):
+ metrics = metrics or _get_classifier_metrics(self._n_classes)
+ return super(Classifier, self).evaluate(x=x, y=y, input_fn=input_fn,
+ batch_size=batch_size,
+ steps=steps, metrics=metrics)
+
+ def predict(self, x=None, input_fn=None, batch_size=None):
+ return super(Classifier, self).predict(
+ x=x, input_fn=input_fn, batch_size=batch_size,
+ outputs=[self.CLASS_OUTPUT])[self.CLASS_OUTPUT]
+
+ def predict_proba(self, x=None, input_fn=None, batch_size=None):
+ return super(Classifier, self).predict(
+ x=x, input_fn=input_fn, batch_size=batch_size,
+ outputs=[self.PROBABILITY_OUTPUT])[self.PROBABILITY_OUTPUT]
+
+ def _classifier_model(self, features, targets, mode):
+ logits, loss, train_op = self._logits_fn(features, targets, mode)
+ return {
+ 'classes': math_ops.argmax(logits, len(logits.get_shape()) - 1),
+ 'probabilities': nn.softmax(logits)
+ }, loss, train_op
+
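Taken together, Classifier lets any logits-producing model_fn expose class and probability heads without writing them by hand. A compact usage sketch (names reuse classifier_test.py, which follows):

    import tensorflow as tf

    iris = tf.contrib.learn.datasets.load_iris()
    est = tf.contrib.learn.Classifier(model_fn=logistic_model_fn, n_classes=3)
    est.fit(iris.data, iris.target, steps=100)
    classes = est.predict(x=iris.data)       # argmax over the logits
    probs = est.predict_proba(x=iris.data)   # softmax over the logits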
diff --git a/tensorflow/contrib/learn/python/learn/estimators/classifier_test.py b/tensorflow/contrib/learn/python/learn/estimators/classifier_test.py
new file mode 100644
index 0000000000..a2efcd1f03
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/classifier_test.py
@@ -0,0 +1,74 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for Classifier."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python.learn.estimators import _sklearn
+
+
+def iris_input_fn():
+ iris = tf.contrib.learn.datasets.load_iris()
+ features = tf.cast(
+ tf.reshape(
+ tf.constant(iris.data), [-1, 4]), tf.float32)
+ target = tf.cast(
+ tf.reshape(
+ tf.constant(iris.target), [-1]), tf.int64)
+ return features, target
+
+
+def logistic_model_fn(features, target, unused_mode):
+ target = tf.one_hot(target, 3, 1, 0)
+ prediction, loss = tf.contrib.learn.models.logistic_regression_zero_init(
+ features, target)
+ train_op = tf.contrib.layers.optimize_loss(
+ loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
+ learning_rate=0.1)
+ return prediction, loss, train_op
+
+
+class ClassifierTest(tf.test.TestCase):
+
+ def testIrisAll(self):
+ iris = tf.contrib.learn.datasets.load_iris()
+ est = tf.contrib.learn.Classifier(model_fn=logistic_model_fn, n_classes=3)
+ est.fit(iris.data, iris.target, steps=100)
+ scores = est.evaluate(x=iris.data, y=iris.target)
+ predictions = est.predict(x=iris.data)
+ predictions_proba = est.predict_proba(x=iris.data)
+ self.assertEqual(predictions.shape[0], iris.target.shape[0])
+ self.assertAllClose(predictions, np.argmax(predictions_proba, axis=1))
+ other_score = _sklearn.accuracy_score(iris.target, predictions)
+ self.assertAllClose(other_score, scores['accuracy'])
+
+ def testIrisInputFn(self):
+ iris = tf.contrib.learn.datasets.load_iris()
+ est = tf.contrib.learn.Classifier(model_fn=logistic_model_fn, n_classes=3)
+ est.train(input_fn=iris_input_fn, steps=100)
+ _ = est.evaluate(input_fn=iris_input_fn, steps=1)
+ predictions = est.predict(x=iris.data)
+ self.assertEqual(predictions.shape[0], iris.target.shape[0])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index 33b6a4e230..cf952f6d8b 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -245,7 +245,7 @@ class TensorFlowDNNClassifier(TensorFlowEstimator, _sklearn.ClassifierMixin):
config=config,
verbose=verbose)
- def _model_fn(self, X, y):
+ def _model_fn(self, X, y): # pylint: disable=invalid-name
return models.get_dnn_model(self.hidden_units,
models.logistic_regression,
dropout=self.dropout)(X, y)
@@ -253,18 +253,19 @@ class TensorFlowDNNClassifier(TensorFlowEstimator, _sklearn.ClassifierMixin):
@property
def weights_(self):
"""Returns weights of the DNN weight layers."""
- return [self.get_tensor_value(w.name)
+ return [self.get_variable_value(w.name)
for w in self._graph.get_collection('dnn_weights')
- ] + [self.get_tensor_value('logistic_regression/weights')]
+ ] + [self.get_variable_value('logistic_regression/weights')]
@property
def bias_(self):
"""Returns bias of the DNN's bias layers."""
- return [self.get_tensor_value(b.name)
+ return [self.get_variable_value(b.name)
for b in self._graph.get_collection('dnn_biases')
- ] + [self.get_tensor_value('logistic_regression/bias')]
+ ] + [self.get_variable_value('logistic_regression/bias')]
+# TODO(ipolosukhin): Deprecate this class in favor of DNNRegressor.
class TensorFlowDNNRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
"""TensorFlow DNN Regressor model.
@@ -318,7 +319,7 @@ class TensorFlowDNNRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
config=config,
verbose=verbose)
- def _model_fn(self, X, y):
+ def _model_fn(self, X, y): # pylint: disable=invalid-name
return models.get_dnn_model(self.hidden_units,
models.linear_regression,
dropout=self.dropout)(X, y)
@@ -326,13 +327,13 @@ class TensorFlowDNNRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
@property
def weights_(self):
"""Returns weights of the DNN weight layers."""
- return [self.get_tensor_value(w.name)
+ return [self.get_variable_value(w.name)
for w in self._graph.get_collection('dnn_weights')
- ] + [self.get_tensor_value('linear_regression/weights')]
+ ] + [self.get_variable_value('linear_regression/weights')]
@property
def bias_(self):
"""Returns bias of the DNN's bias layers."""
- return [self.get_tensor_value(b.name)
+ return [self.get_variable_value(b.name)
for b in self._graph.get_collection('dnn_biases')
- ] + [self.get_tensor_value('linear_regression/bias')]
+ ] + [self.get_variable_value('linear_regression/bias')]
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index f380786ccb..36981bd26a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -142,6 +142,24 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
"""
return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
+ @property
+ def linear_weights_(self):
+ """Returns weights per feature of the linear part."""
+ all_variables = self.get_variable_names()
+ values = {}
+ for name in all_variables:
+ if (name.startswith("linear/") and name.rfind("/") == 6 and
+ name != "linear/bias_weight"):
+ values[name] = self.get_variable_value(name)
+ if len(values) == 1:
+ return values[values.keys()[0]]
+ return values
+
+ @property
+ def linear_bias_(self):
+ """Returns bias of the linear part."""
+ return self.get_variable_value("linear/bias_weight")
+
def _get_train_ops(self, features, targets):
"""See base class."""
global_step = contrib_variables.get_global_step()
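The new linear_weights_ / linear_bias_ properties make the fitted linear part inspectable. A hedged sketch of reading them back (it assumes LinearClassifier accepts n_classes the way the DNNClassifier examples elsewhere in this change do; linear.py below aliases weights_ / bias_ to these properties):

    import tensorflow as tf

    iris = tf.contrib.learn.datasets.load_iris()
    clf = tf.contrib.learn.LinearClassifier(n_classes=3)  # assumed signature
    clf.fit(iris.data, iris.target, steps=100)
    print(clf.weights_)  # per-feature weights: a dict, or a single array
    print(clf.bias_)     # value of the "linear/bias_weight" variable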
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 71641920ee..bf2a5232bd 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -24,20 +24,17 @@ import abc
import os
import tempfile
import time
-import types
import numpy as np
import six
from tensorflow.contrib import framework as contrib_framework
-from tensorflow.contrib import layers
-from tensorflow.contrib import losses
-from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
+from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
from tensorflow.contrib.learn.python.learn.graph_actions import infer
from tensorflow.contrib.learn.python.learn.graph_actions import train
@@ -53,16 +50,6 @@ from tensorflow.python.training import device_setter
from tensorflow.python.training import saver
-# Default metrics for evaluation.
-_EVAL_METRICS = {
- 'regression': {
- 'mean_squared_error': metrics_lib.streaming_mean_squared_error,
- },
- 'classification': {
- 'logistic': losses.sigmoid_cross_entropy,
- },}
-
-
class ModeKeys(object):
"""Standard names for model modes.
@@ -98,7 +85,6 @@ class BaseEstimator(sklearn.BaseEstimator):
* _get_train_ops
* _get_eval_ops
* _get_predict_ops
- It may override _get_default_metric_functions.
`Estimator` implemented below is a good example of how to use this class.
@@ -155,7 +141,7 @@ class BaseEstimator(sklearn.BaseEstimator):
inside the training loop.
Returns:
- Returns self.
+ Returns final loss.
"""
input_fn, feed_fn = _get_input_fn(x, y, batch_size)
return self._train_model(input_fn=input_fn,
@@ -202,7 +188,7 @@ class BaseEstimator(sklearn.BaseEstimator):
inside the training loop.
Returns:
- Returns self.
+ Returns final loss.
"""
input_fn, feed_fn = _get_input_fn(x, y, batch_size)
return self._train_model(input_fn=input_fn,
@@ -236,7 +222,7 @@ class BaseEstimator(sklearn.BaseEstimator):
different data sets, such as evaluate on training data vs test data.
Returns:
- Returns self.
+ Returns `dict` with evaluation results.
Raises:
ValueError: If x or y are not None while input_fn or feed_fn is not None.
@@ -252,18 +238,21 @@ class BaseEstimator(sklearn.BaseEstimator):
metrics=metrics,
name=name)
- def predict(self, x=None, input_fn=None, batch_size=None):
+ def predict(self, x=None, input_fn=None, batch_size=None, outputs=None):
"""Returns predictions for given features.
Args:
x: features.
input_fn: Input function. If set, x must be None.
batch_size: Override default batch size.
+ outputs: list of `str`, names of the outputs to predict.
+ If `None`, returns all outputs.
Returns:
Numpy array of predicted classes or regression values.
"""
- return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
+ return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size,
+ outputs=outputs)
def get_variable_value(self, name):
"""Returns value of the variable given by name.
@@ -374,6 +363,7 @@ class BaseEstimator(sklearn.BaseEstimator):
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True):
+ # TODO(wicke): This is a hack and needs to go.
if self._config.execution_mode not in ('all', 'train'):
return
@@ -462,12 +452,20 @@ class BaseEstimator(sklearn.BaseEstimator):
feed_fn=None,
metrics=None,
name=''):
+ # TODO(wicke): This is a hack and needs to go.
if self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset'):
return
+ # Check that model has been trained.
checkpoint_path = self._model_dir
+ latest_path = saver.latest_checkpoint(checkpoint_path)
+ if not latest_path:
+ raise NotFittedError("Couldn't find trained model at %s."
+ % checkpoint_path)
+ # Setup output directory.
eval_dir = os.path.join(self._model_dir, 'eval' if not name else
'eval_' + name)
+
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
@@ -492,21 +490,38 @@ class BaseEstimator(sklearn.BaseEstimator):
return result[0]
return result
- def _infer_model(self, x=None, input_fn=None, feed_fn=None, batch_size=None):
+ def _infer_model(self, x=None, input_fn=None, feed_fn=None, batch_size=None,
+ outputs=None):
# Converts inputs into tf.DataFrame / tf.Series.
batch_size = -1 if batch_size is None else batch_size
if x is not None:
input_fn, feed_fn = _get_predict_input_fn(x, None, batch_size)
+ # Check that model has been trained.
checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise NotFittedError("Couldn't find trained model at %s."
+ % self._model_dir)
+
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
contrib_framework.create_global_step(g)
features = self._get_features_from_input_fn(input_fn)
predictions = self._get_predict_ops(features)
+ # If predictions is single output - wrap it into dict, and remember to
+ # return not a dict.
return_dict = True
if not isinstance(predictions, dict):
predictions, return_dict = {'predictions': predictions}, False
+ # Filter what to run predictions on, if outputs provided.
+ if outputs:
+ existing_keys = predictions.keys()
+ predictions = {
+ key: value for key, value in predictions.items() if key in outputs
+ }
+ if not predictions:
+ raise ValueError('Expected to run at least one output from %s, '
+ 'provided %s.' % (existing_keys, outputs))
if feed_fn is None:
preds = infer(checkpoint_path, predictions)
else:
@@ -532,80 +547,17 @@ class Estimator(BaseEstimator):
Parameters:
model_fn: Model function, takes features and targets tensors or dicts of
tensors plus a mode, and returns predictions, loss and train_op tensors.
- E.g. `(features, targets) -> (predictions, loss)`.
+ E.g. `(features, targets, mode) -> (predictions, loss, train_op)`.
model_dir: Directory to save model parameters, graph and etc.
- classification: boolean, true if classification problem.
- learning_rate: learning rate for the model.
- optimizer: optimizer for the model, can be:
- string: name of optimizer, like 'SGD', 'Adam', 'Adagrad', 'Ftl',
- 'Momentum', 'RMSProp', 'Momentum').
- Full list in contrib/layers/optimizers.py
- class: sub-class of Optimizer
- (like tf.train.GradientDescentOptimizer).
- clip_gradients: clip_norm value for call to `clip_by_global_norm`. None
- denotes no gradient clipping.
config: Configuration object.
"""
def __init__(self,
model_fn=None,
model_dir=None,
- classification=True,
- learning_rate=0.1,
- optimizer='Adagrad',
- clip_gradients=None,
config=None):
super(Estimator, self).__init__(model_dir=model_dir, config=config)
-
self._model_fn = model_fn
- self._classification = classification
- if isinstance(optimizer, six.string_types):
- if optimizer not in layers.OPTIMIZER_CLS_NAMES:
- raise ValueError(
- 'Optimizer name should be one of [%s], you provided %s.' %
- (', '.join(layers.OPTIMIZER_CLS_NAMES), optimizer))
- self.optimizer = optimizer
- self.learning_rate = learning_rate
- self.clip_gradients = clip_gradients
-
- def predict(self, x=None, input_fn=None, axis=None, batch_size=None):
- """Returns predictions for given features.
-
- Args:
- x: features.
- input_fn: Input function. If set, x must be None.
- axis: Axis on which to argmax (for classification).
- Last axis is used by default.
- batch_size: Override default batch size.
-
- Returns:
- Numpy array of predicted classes or regression values.
- """
- predictions = self._infer_model(x=x,
- input_fn=input_fn,
- batch_size=batch_size)
- if self._classification:
- if isinstance(predictions, dict):
- for key in predictions:
- cur_axis = (len(predictions[key].shape) - 1) if axis is None else axis
- predictions[key] = np.argmax(predictions[key], axis=cur_axis)
- else:
- cur_axis = (len(predictions.shape) - 1) if axis is None else axis
- predictions = np.argmax(predictions, axis=cur_axis)
- return predictions
-
- def predict_proba(self, x=None, input_fn=None, batch_size=None):
- """Returns prediction probabilities for given features (classification).
-
- Args:
- x: features.
- input_fn: Input function. If set, x and y must be None.
- batch_size: Override default batch size.
-
- Returns:
- Numpy array of predicted probabilities.
- """
- return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.
@@ -621,26 +573,7 @@ class Estimator(BaseEstimator):
Returns:
Tuple of train `Operation` and loss `Tensor`.
"""
- _, loss = self._model_fn(features, targets, ModeKeys.TRAIN)
- # TODO(ipolosukhin): Move this to TensorFlowEstimator when
- # moving out training.
- if isinstance(self.learning_rate, types.FunctionType):
- learning_rate = self.learning_rate(contrib_framework.get_global_step())
- else:
- learning_rate = self.learning_rate
- if isinstance(self.optimizer, types.FunctionType):
- optimizer = self.optimizer(learning_rate)
- else:
- optimizer = self.optimizer
- train_op = layers.optimize_loss(
- loss,
- contrib_framework.get_global_step(),
- learning_rate=learning_rate,
- optimizer=optimizer,
- clip_gradients=self.clip_gradients)
- # Add update ops.
- train_op = control_flow_ops.group(
- train_op, *ops.get_collection('update_ops'))
+ _, loss, train_op = self._model_fn(features, targets, ModeKeys.TRAIN)
return train_op, loss
def _get_eval_ops(self, features, targets, metrics):
@@ -658,17 +591,19 @@ class Estimator(BaseEstimator):
Returns:
metrics: `dict` of `Tensor` objects.
"""
- predictions, loss = self._model_fn(features, targets, ModeKeys.EVAL)
+ predictions, loss, _ = self._model_fn(features, targets, ModeKeys.EVAL)
result = {'loss': loss}
- if metrics is None:
- metrics = _EVAL_METRICS[
- 'classification' if self._classification else 'regression']
+ metrics = metrics or {}
if isinstance(targets, dict) and len(targets) == 1:
# Unpack single target into just tensor.
targets = targets[targets.keys()[0]]
for name, metric in six.iteritems(metrics):
- # TODO(ipolosukhin): Add support for multi-head metrics.
- result[name] = metric(predictions, targets)
+ if isinstance(name, tuple):
+ # Multi-head metrics.
+ result[name[0]] = metric(predictions[name[1]], targets)
+ else:
+ # Single head metrics.
+ result[name] = metric(predictions, targets)
return result
def _get_predict_ops(self, features):
@@ -686,7 +621,7 @@ class Estimator(BaseEstimator):
"""
targets = tensor_signature.create_placeholders_from_signatures(
self._targets_info)
- predictions, _ = self._model_fn(features, targets, ModeKeys.INFER)
+ predictions, _, _ = self._model_fn(features, targets, ModeKeys.INFER)
return predictions
def _get_feature_ops_from_example(self, examples_batch):
@@ -705,3 +640,4 @@ class Estimator(BaseEstimator):
"""
raise NotImplementedError('_get_feature_ops_from_example not yet '
'implemented')
+
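Two API points from this file are worth spelling out: evaluate() now takes user-supplied metrics, with tuple keys selecting a prediction head for multi-head models, and predict() can restrict which outputs run. A sketch, reusing the dict-returning logistic_model_fn from estimator_test.py below:

    import tensorflow as tf

    iris = tf.contrib.learn.datasets.load_iris()
    est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn)
    est.fit(iris.data, iris.target, steps=100)

    # Tuple metric key: report streaming_accuracy of the 'class' head
    # under the name 'accuracy'.
    scores = est.evaluate(
        x=iris.data, y=iris.target,
        metrics={('accuracy', 'class'): tf.contrib.metrics.streaming_accuracy})

    # Run only the 'class' output instead of all heads.
    classes = est.predict(x=iris.data, outputs=['class'])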
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 4925cde328..1e320bbd7e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -25,7 +25,7 @@ import tempfile
import numpy as np
import tensorflow as tf
-from tensorflow.contrib.learn.python.learn.estimators._sklearn import mean_squared_error
+from tensorflow.contrib.learn.python.learn.estimators import _sklearn
def boston_input_fn():
@@ -46,7 +46,7 @@ def iris_input_fn():
tf.constant(iris.data), [-1, 4]), tf.float32)
target = tf.cast(
tf.reshape(
- tf.constant(iris.target), [-1, 1]), tf.int32)
+ tf.constant(iris.target), [-1]), tf.int32)
return features, target
@@ -63,11 +63,24 @@ def boston_eval_fn():
def linear_model_fn(features, target, unused_mode):
- return tf.contrib.learn.models.linear_regression_zero_init(features, target)
+ prediction, loss = (
+ tf.contrib.learn.models.linear_regression_zero_init(features, target)
+ )
+ train_op = tf.contrib.layers.optimize_loss(
+ loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
+ learning_rate=0.1)
+ return prediction, loss, train_op
def logistic_model_fn(features, target, unused_mode):
- return tf.contrib.learn.models.logistic_regression_zero_init(features, target)
+ target = tf.one_hot(target, 3, 1, 0)
+ prediction, loss = (
+ tf.contrib.learn.models.logistic_regression_zero_init(features, target)
+ )
+ train_op = tf.contrib.layers.optimize_loss(
+ loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
+ learning_rate=0.1)
+ return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op
class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
@@ -90,44 +103,101 @@ class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
class EstimatorTest(tf.test.TestCase):
- def testBostonAll(self):
+ def testUntrained(self):
+ boston = tf.contrib.learn.datasets.load_boston()
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
+ with self.assertRaises(tf.contrib.learn.NotFittedError):
+ _ = est.evaluate(
+ x=boston.data,
+ y=boston.target.astype(np.float32))
+ with self.assertRaises(tf.contrib.learn.NotFittedError):
+ est.predict(x=boston.data)
+
+ def testContinueTraining(self):
boston = tf.contrib.learn.datasets.load_boston()
+ output_dir = tempfile.mkdtemp()
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
- classification=False)
+ model_dir=output_dir)
+ est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=50)
+ scores = est.evaluate(
+ x=boston.data,
+ y=boston.target.astype(np.float32),
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ del est
+ # Create another estimator object with the same output dir.
+ est2 = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
+ model_dir=output_dir)
+
+ # Check we can evaluate and predict.
+ scores2 = est2.evaluate(
+ x=boston.data,
+ y=boston.target.astype(np.float32),
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ self.assertAllClose(scores2['MSE'],
+ scores['MSE'])
+ predictions = est2.predict(x=boston.data)
+ other_score = _sklearn.mean_squared_error(predictions, boston.target)
+ self.assertAllClose(other_score, scores['MSE'])
+
+ # Check we can keep training.
+ est2.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
+ scores3 = est2.evaluate(
+ x=boston.data,
+ y=boston.target.astype(np.float32),
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ self.assertLess(scores3['MSE'], scores['MSE'])
+
+ def testBostonAll(self):
+ boston = tf.contrib.learn.datasets.load_boston()
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
scores = est.evaluate(
x=boston.data,
- y=boston.target.astype(np.float32))
+ y=boston.target.astype(np.float32),
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
predictions = est.predict(x=boston.data)
- other_score = mean_squared_error(predictions, boston.target)
- self.assertAllClose(other_score, scores['mean_squared_error'])
+ other_score = _sklearn.mean_squared_error(predictions, boston.target)
+ self.assertAllClose(other_score, scores['MSE'])
def testIrisAll(self):
iris = tf.contrib.learn.datasets.load_iris()
- est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn,
- classification=True)
+ est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn)
+ est.fit(iris.data, iris.target, steps=100)
+ scores = est.evaluate(
+ x=iris.data,
+ y=iris.target,
+ metrics={('accuracy', 'class'): tf.contrib.metrics.streaming_accuracy})
+ predictions = est.predict(x=iris.data)
+ predictions_class = est.predict(x=iris.data, outputs=['class'])
+ self.assertEqual(predictions['class'].shape[0], iris.target.shape[0])
+ self.assertAllClose(predictions['class'], predictions_class['class'])
+ self.assertAllClose(predictions['class'], np.argmax(predictions['prob'],
+ axis=1))
+ other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
+ self.assertAllClose(other_score, scores['accuracy'])
+
+ def testIrisInputFn(self):
+ iris = tf.contrib.learn.datasets.load_iris()
+ est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn)
est.train(input_fn=iris_input_fn, steps=100)
_ = est.evaluate(input_fn=iris_input_fn, steps=1)
- predictions = est.predict(x=iris.data)
+ predictions = est.predict(x=iris.data)['class']
self.assertEqual(predictions.shape[0], iris.target.shape[0])
def testTrainInputFn(self):
- est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
- classification=False)
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
est.train(input_fn=boston_input_fn, steps=1)
_ = est.evaluate(input_fn=boston_eval_fn, steps=1)
def testPredict(self):
- est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
- classification=False)
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
boston = tf.contrib.learn.datasets.load_boston()
est.train(input_fn=boston_input_fn, steps=1)
output = est.predict(boston.data)
self.assertEqual(output.shape[0], boston.target.shape[0])
def testPredictFn(self):
- est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
- classification=False)
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
boston = tf.contrib.learn.datasets.load_boston()
est.train(input_fn=boston_input_fn, steps=1)
output = est.predict(input_fn=boston_input_fn)
@@ -136,16 +206,13 @@ class EstimatorTest(tf.test.TestCase):
def testWrongInput(self):
def other_input_fn():
return {'other': tf.constant([0, 0, 0])}, tf.constant([0, 0, 0])
- output_dir = tempfile.mkdtemp()
- est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
- classification=False, model_dir=output_dir)
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
est.train(input_fn=boston_input_fn, steps=1)
with self.assertRaises(ValueError):
est.train(input_fn=other_input_fn, steps=1)
def testMonitors(self):
- est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
- classification=False)
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
est.train(input_fn=boston_input_fn, steps=21,
monitors=[CheckCallsMonitor()])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 316f98f575..27543a51ea 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -96,6 +96,14 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
self._linear_feature_columns = layers.infer_real_valued_columns(features)
return super(LinearClassifier, self)._get_train_ops(features, targets)
+ @property
+ def weights_(self):
+ return self.linear_weights_
+
+ @property
+ def bias_(self):
+ return self.linear_bias_
+
class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
"""Linear regressor model.
@@ -163,8 +171,16 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
self._linear_feature_columns = layers.infer_real_valued_columns(features)
return super(LinearRegressor, self)._get_train_ops(features, targets)
+ @property
+ def weights_(self):
+ return self.linear_weights_
+
+ @property
+ def bias_(self):
+ return self.linear_bias_
-# TODO(ipolosukhin): Deprecate this class in favor of LinearClassifier.
+
+# TODO(ipolosukhin): Deprecate this class in favor of LinearRegressor.
class TensorFlowLinearRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
"""TensorFlow Linear Regression model."""
@@ -194,14 +210,15 @@ class TensorFlowLinearRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
@property
def weights_(self):
"""Returns weights of the linear regression."""
- return self.get_tensor_value('linear_regression/weights')
+ return self.get_variable_value('linear_regression/weights')
@property
def bias_(self):
"""Returns bias of the linear regression."""
- return self.get_tensor_value('linear_regression/bias')
+ return self.get_variable_value('linear_regression/bias')
+# TODO(ipolosukhin): Deprecate this class in favor of LinearClassifier.
class TensorFlowLinearClassifier(TensorFlowEstimator, _sklearn.ClassifierMixin):
"""TensorFlow Linear Classifier model."""
@@ -233,12 +250,12 @@ class TensorFlowLinearClassifier(TensorFlowEstimator, _sklearn.ClassifierMixin):
@property
def weights_(self):
"""Returns weights of the linear classifier."""
- return self.get_tensor_value('logistic_regression/weights')
+ return self.get_variable_value('logistic_regression/weights')
@property
def bias_(self):
"""Returns weights of the linear classifier."""
- return self.get_tensor_value('logistic_regression/bias')
+ return self.get_variable_value('logistic_regression/bias')
TensorFlowRegressor = TensorFlowLinearRegressor
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn.py b/tensorflow/contrib/learn/python/learn/estimators/rnn.py
index 0d0eba9476..73626d84cf 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/rnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn.py
@@ -117,12 +117,12 @@ class TensorFlowRNNClassifier(TensorFlowEstimator, _sklearn.ClassifierMixin):
@property
def bias_(self):
"""Returns bias of the rnn layer."""
- return self.get_tensor_value('logistic_regression/bias')
+ return self.get_variable_value('logistic_regression/bias')
@property
def weights_(self):
"""Returns weights of the rnn layer."""
- return self.get_tensor_value('logistic_regression/weights')
+ return self.get_variable_value('logistic_regression/weights')
class TensorFlowRNNRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
@@ -208,9 +208,9 @@ class TensorFlowRNNRegressor(TensorFlowEstimator, _sklearn.RegressorMixin):
@property
def bias_(self):
"""Returns bias of the rnn layer."""
- return self.get_tensor_value('linear_regression/bias')
+ return self.get_variable_value('linear_regression/bias')
@property
def weights_(self):
"""Returns weights of the rnn layer."""
- return self.get_tensor_value('linear_regression/weights')
+ return self.get_variable_value('linear_regression/weights')
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index e7de58cf2e..f9d37a3c10 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -454,8 +454,10 @@ def evaluate(graph,
# one existing already or it's a string.
existing_tags = [tensor_util.constant_value(summary.op.inputs[0])
for summary in ops.get_collection(ops.GraphKeys.SUMMARIES)]
+ existing_tags = [name.tolist() if isinstance(name, np.ndarray) else name
+ for name in existing_tags]
for key, value in eval_dict.items():
- if key in existing_tags:
+ if key.encode() in existing_tags:
continue
if isinstance(value, ops.Tensor):
summaries.summarize_tensor(value, tag=key)
@@ -490,33 +492,37 @@ def evaluate(graph,
eval_results = None
# TODO(amodei): Fix this to run through the eval set exactly once.
step = 0
+ eval_step = None
+ feed_dict = None
logging.info('Eval steps [%d,%s) for training step %d.', step,
'inf' if max_steps is None
else str(max_steps), current_global_step)
try:
try:
while (max_steps is None) or (step < max_steps):
+ step += 1
start_time = time.time()
feed_dict = feed_fn() if feed_fn is not None else None
- eval_results = None
if update_op is not None:
session.run(update_op, feed_dict=feed_dict)
else:
eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+ eval_step = step
# TODO(wicke): We should assert that the global step hasn't changed.
- step += 1
if step % log_every_steps == 0:
- if eval_results is None:
+ if eval_step is None or step != eval_step:
eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+ eval_step = step
duration = time.time() - start_time
logging.info('Results after %d steps (%.3f sec/batch): %s.',
step, float(duration),
', '.join('%s = %s' % (k, v)
for k, v in eval_results.items()))
finally:
- if eval_results is None:
+ if eval_results is None or step != eval_step:
eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+ eval_step = step
# Stop queue runners.
coord.request_stop()
coord.join(threads, stop_grace_period_secs=120)
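For context on the tag comparison above: tensor_util.constant_value returns existing summary tags as numpy bytes, while eval_dict keys are Python strings, so a direct `in` test silently fails. A small illustration of the mismatch (assuming Python 3 str/bytes semantics):

    import numpy as np

    tag = np.array(b'loss')          # what constant_value yields for a tag
    existing_tags = [tag.tolist()]   # .tolist() on a 0-d array -> b'loss'
    assert 'loss' not in existing_tags        # str never equals bytes
    assert 'loss'.encode() in existing_tags   # hence the key.encode() check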
diff --git a/tensorflow/contrib/learn/python/learn/io/__init__.py b/tensorflow/contrib/learn/python/learn/io/__init__.py
index 2d563f648a..1a9a5fa7da 100644
--- a/tensorflow/contrib/learn/python/learn/io/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/io/__init__.py
@@ -30,7 +30,3 @@ from tensorflow.contrib.learn.python.learn.io.pandas_io import extract_pandas_da
from tensorflow.contrib.learn.python.learn.io.pandas_io import extract_pandas_labels
from tensorflow.contrib.learn.python.learn.io.pandas_io import extract_pandas_matrix
from tensorflow.contrib.learn.python.learn.io.pandas_io import HAS_PANDAS
-
-# pylint: disable=g-import-not-at-top
-if HAS_PANDAS:
- from tensorflow.contrib.learn.python.learn.io.pandas_io import pd
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
index ec7001bc94..9b2923a540 100644
--- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
@@ -38,7 +38,7 @@ from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size):
"""Returns shape for input and output of the data feeder."""
- if batch_size < 0:
+ if batch_size is None or batch_size < 0:
batch_size = x_shape[0]
x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
input_shape = [batch_size] + x_shape
@@ -321,7 +321,7 @@ class DataFeeder(object):
feed_dict[self._epoch_placeholder.name] = [self.epoch]
# take next batch of indices
- if self.batch_size < 0:
+ if self.batch_size is None or self.batch_size < 0:
batch_indices = self.indices
else:
end = min(self.X.shape[0], self.offset + self.batch_size)
@@ -334,7 +334,7 @@ class DataFeeder(object):
feed_dict[self._input_placeholder.name] = inp
# move offset and reset it if necessary
- if self.batch_size > 0:
+ if self.batch_size is not None and self.batch_size > 0:
self.offset += self.batch_size
if self.offset >= self.X.shape[0]:
self.indices = self.random_state.permutation(self.X.shape[0])
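The three hunks above implement a single rule, distilled here as a sketch (the helper name is illustrative):

    def effective_batch_size(batch_size, num_examples):
        # None and negative values both mean "feed the whole data set as a
        # single batch"; only a positive batch_size advances the offset.
        if batch_size is None or batch_size < 0:
            return num_examples
        return batch_size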
diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py
index 82a8833d0a..9ed3f9a7e8 100644
--- a/tensorflow/contrib/learn/python/learn/models.py
+++ b/tensorflow/contrib/learn/python/learn/models.py
@@ -137,8 +137,8 @@ def logistic_regression(X,
uniform_unit_scaling_initializer will be used.
"""
with vs.variable_scope('logistic_regression'):
- logging_ops.histogram_summary('logistic_regression.X', X)
- logging_ops.histogram_summary('logistic_regression.y', y)
+ logging_ops.histogram_summary('%s.X' % vs.get_variable_scope().name, X)
+ logging_ops.histogram_summary('%s.y' % vs.get_variable_scope().name, y)
# Set up the requested initialization.
if init_mean is None:
weights = vs.get_variable('weights',
@@ -152,8 +152,10 @@ def logistic_regression(X,
bias = vs.get_variable('bias', [y.get_shape()[-1]],
initializer=init_ops.random_normal_initializer(
init_mean, init_stddev))
- logging_ops.histogram_summary('logistic_regression.weights', weights)
- logging_ops.histogram_summary('logistic_regression.bias', bias)
+ logging_ops.histogram_summary('%s.weights' % vs.get_variable_scope().name,
+ weights)
+ logging_ops.histogram_summary('%s.bias' % vs.get_variable_scope().name,
+ bias)
# If no class weight provided, try to retrieve one from pre-defined
# tensor name in the graph.
if not class_weight:
diff --git a/tensorflow/contrib/learn/python/learn/ops/conv_ops.py b/tensorflow/contrib/learn/python/learn/ops/conv_ops.py
index fda98869e6..15c21cdae3 100644
--- a/tensorflow/contrib/learn/python/learn/ops/conv_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/conv_ops.py
@@ -49,6 +49,8 @@ def conv2d(tensor_in,
strides: A list of ints, 1-D of length 4. The stride of the sliding
window for each dimension of input.
padding: A string: 'SAME' or 'VALID'. The type of padding algorithm to use.
+ See the [comment here]
+ (https://www.tensorflow.org/api_docs/python/nn.html#convolution)
bias: Boolean, if to add bias.
activation: Activation Op, optional. If provided applied on the output.
batch_norm: Whether to apply batch normalization.
diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/tests/test_categorical.py b/tensorflow/contrib/learn/python/learn/preprocessing/tests/test_categorical.py
index 7090f6dd73..434e66cade 100644
--- a/tensorflow/contrib/learn/python/learn/preprocessing/tests/test_categorical.py
+++ b/tensorflow/contrib/learn/python/learn/preprocessing/tests/test_categorical.py
@@ -26,7 +26,6 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.io import HAS_PANDAS
-from tensorflow.contrib.learn.python.learn.io import pd
from tensorflow.contrib.learn.python.learn.preprocessing import categorical
@@ -41,6 +40,7 @@ class CategoricalTest(tf.test.TestCase):
def testSingleCategoricalProcessorPandasSingleDF(self):
if HAS_PANDAS:
+ import pandas as pd # pylint: disable=g-import-not-at-top
cat_processor = categorical.CategoricalProcessor()
data = pd.DataFrame({"Gender": ["Male", "Female", "Male"]})
x = list(cat_processor.fit_transform(data))
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_io.py b/tensorflow/contrib/learn/python/learn/tests/test_io.py
index c0fb547738..3bbec5fde9 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_io.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_io.py
@@ -38,6 +38,7 @@ class IOTest(tf.test.TestCase):
def test_pandas_dataframe(self):
if HAS_PANDAS:
+ import pandas as pd # pylint: disable=g-import-not-at-top
random.seed(42)
iris = datasets.load_iris()
data = pd.DataFrame(iris.data)
@@ -51,6 +52,7 @@ class IOTest(tf.test.TestCase):
def test_pandas_series(self):
if HAS_PANDAS:
+ import pandas as pd # pylint: disable=g-import-not-at-top
random.seed(42)
iris = datasets.load_iris()
data = pd.DataFrame(iris.data)
@@ -62,6 +64,7 @@ class IOTest(tf.test.TestCase):
def test_string_data_formats(self):
if HAS_PANDAS:
+ import pandas as pd # pylint: disable=g-import-not-at-top
with self.assertRaises(ValueError):
learn.io.extract_pandas_data(pd.DataFrame({"Test": ["A", "B"]}))
with self.assertRaises(ValueError):
@@ -69,6 +72,8 @@ class IOTest(tf.test.TestCase):
def test_dask_io(self):
if HAS_DASK and HAS_PANDAS:
+ import pandas as pd # pylint: disable=g-import-not-at-top
+ import dask.dataframe as dd # pylint: disable=g-import-not-at-top
# test dask.dataframe
df = pd.DataFrame(
dict(a=list("aabbcc"), b=list(range(6))),
@@ -95,6 +100,8 @@ class IOTest(tf.test.TestCase):
def test_dask_iris_classification(self):
if HAS_DASK and HAS_PANDAS:
+ import pandas as pd # pylint: disable=g-import-not-at-top
+ import dask.dataframe as dd # pylint: disable=g-import-not-at-top
random.seed(42)
iris = datasets.load_iris()
data = pd.DataFrame(iris.data)
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py b/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py
index 119a51b6bf..007bb6ea6d 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py
@@ -24,8 +24,6 @@ import random
import numpy as np
import tensorflow as tf
-from tensorflow.contrib.learn.python import learn
-from tensorflow.contrib.learn.python.learn import datasets
from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score
from tensorflow.contrib.learn.python.learn.estimators._sklearn import mean_squared_error
@@ -35,9 +33,9 @@ class NonLinearTest(tf.test.TestCase):
def testIrisDNN(self):
random.seed(42)
- iris = datasets.load_iris()
- classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3)
+ iris = tf.contrib.learn.datasets.load_iris()
+ classifier = tf.contrib.learn.TensorFlowDNNClassifier(
+ hidden_units=[10, 20, 10], n_classes=3)
classifier.fit(iris.data, iris.target)
score = accuracy_score(iris.target, classifier.predict(iris.data))
self.assertGreater(score, 0.9, "Failed with score = {0}".format(score))
@@ -51,12 +49,10 @@ class NonLinearTest(tf.test.TestCase):
def testBostonDNN(self):
random.seed(42)
- boston = datasets.load_boston()
- regressor = learn.TensorFlowDNNRegressor(hidden_units=[10, 20, 10],
- n_classes=0,
- batch_size=boston.data.shape[0],
- steps=300,
- learning_rate=0.001)
+ boston = tf.contrib.learn.datasets.load_boston()
+ regressor = tf.contrib.learn.TensorFlowDNNRegressor(
+ hidden_units=[10, 20, 10], n_classes=0,
+ batch_size=boston.data.shape[0], steps=300, learning_rate=0.001)
regressor.fit(boston.data, boston.target)
score = mean_squared_error(boston.target, regressor.predict(boston.data))
self.assertLess(score, 110, "Failed with score = {0}".format(score))
@@ -70,20 +66,18 @@ class NonLinearTest(tf.test.TestCase):
def testDNNDropout0(self):
# Dropout prob == 0.
- iris = datasets.load_iris()
- classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3,
- dropout=0.0)
+ iris = tf.contrib.learn.datasets.load_iris()
+ classifier = tf.contrib.learn.TensorFlowDNNClassifier(
+ hidden_units=[10, 20, 10], n_classes=3, dropout=0.0)
classifier.fit(iris.data, iris.target)
score = accuracy_score(iris.target, classifier.predict(iris.data))
self.assertGreater(score, 0.9, "Failed with score = {0}".format(score))
def testDNNDropout0_1(self):
# Dropping only a little.
- iris = datasets.load_iris()
- classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3,
- dropout=0.1)
+ iris = tf.contrib.learn.datasets.load_iris()
+ classifier = tf.contrib.learn.TensorFlowDNNClassifier(
+ hidden_units=[10, 20, 10], n_classes=3, dropout=0.1)
classifier.fit(iris.data, iris.target)
score = accuracy_score(iris.target, classifier.predict(iris.data))
# If the quality is lower - dropout is not working.
@@ -91,10 +85,9 @@ class NonLinearTest(tf.test.TestCase):
def testDNNDropout0_9(self):
# Dropping out most of it.
- iris = datasets.load_iris()
- classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3,
- dropout=0.9)
+ iris = tf.contrib.learn.datasets.load_iris()
+ classifier = tf.contrib.learn.TensorFlowDNNClassifier(
+ hidden_units=[10, 20, 10], n_classes=3, dropout=0.9)
classifier.fit(iris.data, iris.target)
score = accuracy_score(iris.target, classifier.predict(iris.data))
self.assertGreater(score, 0.3, "Failed with score = {0}".format(score))
@@ -118,10 +111,10 @@ class NonLinearTest(tf.test.TestCase):
return tf.split(1, 5, X)
# Classification
- classifier = learn.TensorFlowRNNClassifier(rnn_size=2,
- cell_type="lstm",
- n_classes=2,
- input_op_fn=_input_fn)
+ classifier = tf.contrib.learn.TensorFlowRNNClassifier(rnn_size=2,
+ cell_type="lstm",
+ n_classes=2,
+ input_op_fn=_input_fn)
classifier.fit(data, labels)
# pylint: disable=pointless-statement
classifier.weights_
@@ -130,24 +123,22 @@ class NonLinearTest(tf.test.TestCase):
predictions = classifier.predict(test_data)
self.assertAllClose(predictions, np.array([1, 0]))
- classifier = learn.TensorFlowRNNClassifier(rnn_size=2,
- cell_type="rnn",
- n_classes=2,
- input_op_fn=_input_fn,
- num_layers=2)
+ classifier = tf.contrib.learn.TensorFlowRNNClassifier(rnn_size=2,
+ cell_type="rnn",
+ n_classes=2,
+ input_op_fn=_input_fn,
+ num_layers=2)
classifier.fit(data, labels)
- classifier = learn.TensorFlowRNNClassifier(rnn_size=2,
- cell_type="invalid_cell_type",
- n_classes=2,
- input_op_fn=_input_fn,
- num_layers=2)
+ classifier = tf.contrib.learn.TensorFlowRNNClassifier(
+ rnn_size=2, cell_type="invalid_cell_type", n_classes=2,
+ input_op_fn=_input_fn, num_layers=2)
with self.assertRaises(ValueError):
classifier.fit(data, labels)
# Regression
- regressor = learn.TensorFlowRNNRegressor(rnn_size=2,
- cell_type="gru",
- input_op_fn=_input_fn)
+ regressor = tf.contrib.learn.TensorFlowRNNRegressor(rnn_size=2,
+ cell_type="gru",
+ input_op_fn=_input_fn)
regressor.fit(data, targets)
# pylint: disable=pointless-statement
regressor.weights_
@@ -168,11 +159,11 @@ class NonLinearTest(tf.test.TestCase):
return tf.split(1, 5, X)
# Classification
- classifier = learn.TensorFlowRNNClassifier(rnn_size=2,
- cell_type="lstm",
- n_classes=2,
- input_op_fn=_input_fn,
- bidirectional=True)
+ classifier = tf.contrib.learn.TensorFlowRNNClassifier(rnn_size=2,
+ cell_type="lstm",
+ n_classes=2,
+ input_op_fn=_input_fn,
+ bidirectional=True)
classifier.fit(data, labels)
test_data = np.array(list([[1, 3, 3, 2, 1], [2, 3, 4,
5, 6]]), dtype=np.float32)
diff --git a/tensorflow/examples/skflow/iris.py b/tensorflow/examples/skflow/iris.py
index ea44428d54..f21c2342f0 100644
--- a/tensorflow/examples/skflow/iris.py
+++ b/tensorflow/examples/skflow/iris.py
@@ -25,11 +25,10 @@ X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data,
test_size=0.2, random_state=42)
# Build 3 layer DNN with 10, 20, 10 units respectively.
-classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3, steps=200)
+classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
# Fit and predict.
-classifier.fit(X_train, y_train)
+classifier.fit(X_train, y_train, steps=200)
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))
diff --git a/tensorflow/examples/skflow/iris_custom_decay_dnn.py b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
index b8b1a1dd14..c1e7d22d53 100644
--- a/tensorflow/examples/skflow/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
@@ -19,7 +19,6 @@ from sklearn import datasets, metrics
from sklearn.cross_validation import train_test_split
import tensorflow as tf
-from tensorflow.contrib import skflow
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data,
@@ -33,8 +32,9 @@ def exp_decay(global_step):
decay_steps=100, decay_rate=0.001)
# use customized decay function in learning_rate
-classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3, steps=800,
- learning_rate=exp_decay)
-classifier.fit(X_train, y_train)
+optimizer = tf.train.AdagradOptimizer(learning_rate=exp_decay)
+classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
+ n_classes=3,
+ optimizer=optimizer)
+classifier.fit(X_train, y_train, steps=800)
score = metrics.accuracy_score(y_test, classifier.predict(X_test))