aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-08 09:06:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-08 10:17:45 -0700
commitc00c073f52c2fc7b6672022c75d0b2abb9d9af3a (patch)
treed5aad86afbf6697bcb6eaffc50c1c1f6f48cd0d0
parent9a9219be3531d12c804f671e3e236a0d05c01d70 (diff)
Begin removing feature column inference from linear and dnn estimators. Currently, the fit operation of each of them will infer feature columns from the passed in features. But it only works for dense float inputs.
Also, fixed some lint warnings. Change: 126921818
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py22
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/__init__.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn.py20
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_test.py13
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py72
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py133
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py20
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py16
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/base_test.py135
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/estimators_test.py40
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/io_test.py12
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/multioutput_test.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/regression_test.py4
-rw-r--r--tensorflow/examples/skflow/boston.py11
-rw-r--r--tensorflow/examples/skflow/hdf5_classification.py33
-rw-r--r--tensorflow/examples/skflow/iris.py8
-rw-r--r--tensorflow/examples/skflow/iris_custom_decay_dnn.py9
-rw-r--r--tensorflow/examples/skflow/iris_custom_model.py4
-rw-r--r--tensorflow/examples/skflow/iris_run_config.py12
-rw-r--r--tensorflow/examples/skflow/iris_save_restore.py28
-rw-r--r--tensorflow/examples/skflow/iris_val_based_early_stopping.py5
-rw-r--r--tensorflow/examples/skflow/iris_with_pipeline.py15
-rw-r--r--tensorflow/examples/skflow/mnist.py60
-rw-r--r--tensorflow/examples/skflow/mnist_weights.py70
-rw-r--r--tensorflow/examples/skflow/out_of_core_data_classification.py41
25 files changed, 632 insertions, 158 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index 500effa09f..207f86dc8b 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -1314,12 +1314,32 @@ class ParseExampleTest(tf.test.TestCase):
class InferRealValuedColumnTest(tf.test.TestCase):
- def testTensor(self):
+ def testTensorInt32(self):
self.assertEqual(
tf.contrib.layers.infer_real_valued_columns(
tf.zeros(shape=[33, 4], dtype=tf.int32)),
[tf.contrib.layers.real_valued_column("", dimension=4, dtype=tf.int32)])
+ def testTensorInt64(self):
+ self.assertEqual(
+ tf.contrib.layers.infer_real_valued_columns(
+ tf.zeros(shape=[33, 4], dtype=tf.int64)),
+ [tf.contrib.layers.real_valued_column("", dimension=4, dtype=tf.int64)])
+
+ def testTensorFloat32(self):
+ self.assertEqual(
+ tf.contrib.layers.infer_real_valued_columns(
+ tf.zeros(shape=[33, 4], dtype=tf.float32)),
+ [tf.contrib.layers.real_valued_column(
+ "", dimension=4, dtype=tf.float32)])
+
+ def testTensorFloat64(self):
+ self.assertEqual(
+ tf.contrib.layers.infer_real_valued_columns(
+ tf.zeros(shape=[33, 4], dtype=tf.float64)),
+ [tf.contrib.layers.real_valued_column(
+ "", dimension=4, dtype=tf.float64)])
+
def testDictionary(self):
self.assertItemsEqual(
tf.contrib.layers.infer_real_valued_columns({
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index b6e6e57ebe..817fce7377 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -32,6 +32,8 @@ from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import
from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedRegressor
from tensorflow.contrib.learn.python.learn.estimators.estimator import BaseEstimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
+from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
+from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index 79b1e7b2ec..fdb598efc5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -24,6 +24,20 @@ from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators.base import DeprecatedMixin
from tensorflow.python.ops import nn
+from tensorflow.python.platform import tf_logging as logging
+
+
+# TODO(b/29580537): Replace with @changing decorator.
+def _changing(feature_columns):
+ if feature_columns is not None:
+ return
+ logging.warn(
+ "Change warning: `feature_columns` will be required after 2016-08-01.\n"
+ "Instructions for updating:\n"
+ "Pass `tf.contrib.learn.infer_real_valued_columns_from_input(x)` or"
+ " `tf.contrib.learn.infer_real_valued_columns_from_input_fn(input_fn)`"
+ " as `feature_columns`, where `x` or `input_fn` is your argument to"
+ " `fit`, `evaluate`, or `predict`.")
class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
@@ -125,6 +139,7 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
Returns:
A `DNNClassifier` estimator.
"""
+ _changing(feature_columns)
super(DNNClassifier, self).__init__(
model_dir=model_dir,
n_classes=n_classes,
@@ -139,8 +154,7 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
config=config)
self._feature_columns_inferred = False
- # TODO(ptucker): Update this class to require caller pass `feature_columns` to
- # ctor, so we can remove feature_column inference.
+ # TODO(b/29580537): Remove feature_columns inference.
def _validate_dnn_feature_columns(self, features):
if self._dnn_feature_columns is None:
self._dnn_feature_columns = layers.infer_real_valued_columns(features)
@@ -273,6 +287,7 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
Returns:
A `DNNRegressor` estimator.
"""
+ _changing(feature_columns)
super(DNNRegressor, self).__init__(
model_dir=model_dir,
weight_column_name=weight_column_name,
@@ -286,6 +301,7 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
config=config)
self._feature_columns_inferred = False
+ # TODO(b/29580537): Remove feature_columns inference.
def _validate_dnn_feature_columns(self, features):
if self._dnn_feature_columns is None:
self._dnn_feature_columns = layers.infer_real_valued_columns(features)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
index 4ceb6a996b..ea09f71785 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
@@ -81,13 +81,22 @@ def boston_input_fn():
return features, target
-class InferedColumnTest(tf.test.TestCase):
+class FeatureColumnTest(tf.test.TestCase):
- def testTrain(self):
+ # TODO(b/29580537): Remove when we deprecate feature column inference.
+ def testTrainWithInferredFeatureColumns(self):
est = tf.contrib.learn.DNNRegressor(hidden_units=[3, 3])
est.fit(input_fn=boston_input_fn, steps=1)
_ = est.evaluate(input_fn=boston_input_fn, steps=1)
+ def testTrain(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ boston_input_fn)
+ est = tf.contrib.learn.DNNRegressor(
+ feature_columns=feature_columns, hidden_units=[3, 3])
+ est.fit(input_fn=boston_input_fn, steps=1)
+ _ = est.evaluate(input_fn=boston_input_fn, steps=1)
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 28b87bd421..5afcd3dbbe 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -29,6 +29,7 @@ import numpy as np
import six
from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
@@ -72,6 +73,9 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
contrib_framework.is_tensor(y)):
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
+ if feed_fn is not None:
+ raise ValueError('Can not provide both feed_fn and x or y.')
+
df = data_feeder.setup_train_data_feeder(x, y, n_classes=None,
batch_size=batch_size,
shuffle=shuffle,
@@ -86,6 +90,41 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
return input_fn, feed_fn
+def infer_real_valued_columns_from_input_fn(input_fn):
+ """Creates `FeatureColumn` objects for inputs defined by `input_fn`.
+
+ This interprets all inputs as dense, fixed-length float values. This creates
+ a local graph in which it calls `input_fn` to build the tensors, then discards
+ it.
+
+ Args:
+ input_fn: Function returning a tuple of input and target `Tensor` objects.
+
+ Returns:
+ List of `FeatureColumn` objects.
+ """
+ with ops.Graph().as_default():
+ features, _ = input_fn()
+ return layers.infer_real_valued_columns(features)
+
+
+def infer_real_valued_columns_from_input(x):
+ """Creates `FeatureColumn` objects for inputs defined by input `x`.
+
+ This interprets all inputs as dense, fixed-length float values.
+
+ Args:
+ x: Real-valued matrix of shape [n_samples, n_features...]. Can be
+ iterator that returns arrays of features.
+
+ Returns:
+ List of `FeatureColumn` objects.
+ """
+ input_fn, _ = _get_input_fn(
+ x=x, y=None, input_fn=None, feed_fn=None, batch_size=None)
+ return infer_real_valued_columns_from_input_fn(input_fn)
+
+
def _get_arguments(func):
"""Returns list of arguments this function has."""
if hasattr(func, '__code__'):
@@ -156,10 +195,10 @@ class BaseEstimator(sklearn.BaseEstimator):
"""Trains a model given training data `x` predictions and `y` targets.
Args:
- x: matrix or tensor of shape [n_samples, n_features...]. Can be
- iterator that returns arrays of features. The training input
- samples for fitting the model. If set, `input_fn` must be `None`.
- y: vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
+ x: Matrix of shape [n_samples, n_features...]. Can be iterator that
+ returns arrays of features. The training input samples for fitting the
+ model. If set, `input_fn` must be `None`.
+ y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
iterator that returns array of targets. The training target values
(class labels in classification, real numbers in regression). If set,
`input_fn` must be `None`.
@@ -214,12 +253,12 @@ class BaseEstimator(sklearn.BaseEstimator):
to converge, and you want to split up training into subparts.
Args:
- x: matrix or tensor of shape [n_samples, n_features...]. Can be
- iterator that returns arrays of features. The training input
- samples for fitting the model. If set, `input_fn` must be `None`.
- y: vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
- iterator that returns array of targets. The training target values
- (class label in classification, real numbers in regression). If set,
+ x: Matrix of shape [n_samples, n_features...]. Can be iterator that
+ returns arrays of features. The training input samples for fitting the
+ model. If set, `input_fn` must be `None`.
+ y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
+ iterator that returns array of targets. The training target values
+ (class labels in classification, real numbers in regression). If set,
`input_fn` must be `None`.
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
@@ -264,8 +303,13 @@ class BaseEstimator(sklearn.BaseEstimator):
for which this evaluation was performed.
Args:
- x: features.
- y: targets.
+ x: Matrix of shape [n_samples, n_features...]. Can be iterator that
+ returns arrays of features. The training input samples for fitting the
+ model. If set, `input_fn` must be `None`.
+ y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
+ iterator that returns array of targets. The training target values
+ (class labels in classification, real numbers in regression). If set,
+ `input_fn` must be `None`.
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
feed_fn: Function creating a feed dict every time it is called. Called
@@ -316,7 +360,9 @@ class BaseEstimator(sklearn.BaseEstimator):
"""Returns predictions for given features.
Args:
- x: Features. If set, `input_fn` must be `None`.
+ x: Matrix of shape [n_samples, n_features...]. Can be iterator that
+ returns arrays of features. The training input samples for fitting the
+ model. If set, `input_fn` must be `None`.
input_fn: Input function. If set, `x` and 'batch_size' must be `None`.
batch_size: Override default batch size. If set, 'input_fn' must be
'None'.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 868ccaa5f1..0722699d73 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -23,30 +23,31 @@ import itertools
import tempfile
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
+_BOSTON_INPUT_DIM = 13
+_IRIS_INPUT_DIM = 4
+
+
def boston_input_fn():
boston = tf.contrib.learn.datasets.load_boston()
features = tf.cast(
- tf.reshape(
- tf.constant(boston.data), [-1, 13]), tf.float32)
+ tf.reshape(tf.constant(boston.data), [-1, _BOSTON_INPUT_DIM]), tf.float32)
target = tf.cast(
- tf.reshape(
- tf.constant(boston.target), [-1, 1]), tf.float32)
+ tf.reshape(tf.constant(boston.target), [-1, 1]), tf.float32)
return features, target
def iris_input_fn():
iris = tf.contrib.learn.datasets.load_iris()
features = tf.cast(
- tf.reshape(
- tf.constant(iris.data), [-1, 4]), tf.float32)
+ tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM]), tf.float32)
target = tf.cast(
- tf.reshape(
- tf.constant(iris.target), [-1]), tf.int32)
+ tf.reshape(tf.constant(iris.target), [-1]), tf.int32)
return features, target
@@ -54,11 +55,10 @@ def boston_eval_fn():
boston = tf.contrib.learn.datasets.load_boston()
n_examples = len(boston.target)
features = tf.cast(
- tf.reshape(
- tf.constant(boston.data), [n_examples, 13]), tf.float32)
+ tf.reshape(tf.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]),
+ tf.float32)
target = tf.cast(
- tf.reshape(
- tf.constant(boston.target), [n_examples, 1]), tf.float32)
+ tf.reshape(tf.constant(boston.target), [n_examples, 1]), tf.float32)
return tf.concat(0, [features, features]), tf.concat(0, [target, target])
@@ -327,5 +327,114 @@ class EstimatorTest(tf.test.TestCase):
est.fit(input_fn=boston_input_fn, steps=21, monitors=[CheckCallsMonitor()])
+class InferRealValuedColumnsTest(tf.test.TestCase):
+
+ def testInvalidArgs(self):
+ with self.assertRaisesRegexp(ValueError, 'x or input_fn must be provided'):
+ tf.contrib.learn.infer_real_valued_columns_from_input(None)
+
+ with self.assertRaisesRegexp(ValueError, 'cannot be tensors'):
+ tf.contrib.learn.infer_real_valued_columns_from_input(tf.constant(1.0))
+
+ def _assert_single_feature_column(
+ self, expected_shape, expected_dtype, feature_columns):
+ self.assertEqual(1, len(feature_columns))
+ feature_column = feature_columns[0]
+ self.assertEqual('', feature_column.name)
+ self.assertEqual({
+ '': tf.FixedLenFeature(shape=expected_shape, dtype=expected_dtype)
+ }, feature_column.config)
+
+ # Note: See tf.contrib.learn.io.data_feeder for why int32 converts to float32.
+ def testInt32Input(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ np.ones(shape=[7, 8], dtype=np.int32))
+ self._assert_single_feature_column([8], tf.float32, feature_columns)
+
+ def testInt32InputFn(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ lambda: (tf.ones(shape=[7, 8], dtype=tf.int32), None))
+ self._assert_single_feature_column([8], tf.int32, feature_columns)
+
+ # Note: See tf.contrib.learn.io.data_feeder for why int64 doesn't convert to
+ # float64.
+ def testInt64Input(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ np.ones(shape=[7, 8], dtype=np.int64))
+ self._assert_single_feature_column([8], tf.int64, feature_columns)
+
+ def testInt64InputFn(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ lambda: (tf.ones(shape=[7, 8], dtype=tf.int64), None))
+ self._assert_single_feature_column([8], tf.int64, feature_columns)
+
+ def testFloat32Input(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ np.ones(shape=[7, 8], dtype=np.float32))
+ self._assert_single_feature_column([8], tf.float32, feature_columns)
+
+ def testFloat32InputFn(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ lambda: (tf.ones(shape=[7, 8], dtype=tf.float32), None))
+ self._assert_single_feature_column([8], tf.float32, feature_columns)
+
+ # Note: See tf.contrib.learn.io.data_feeder for why float64 converts to
+ # float32.
+ def testFloat64Input(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ np.ones(shape=[7, 8], dtype=np.float64))
+ self._assert_single_feature_column([8], tf.float32, feature_columns)
+
+ def testFloat64InputFn(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ lambda: (tf.ones(shape=[7, 8], dtype=tf.float64), None))
+ self._assert_single_feature_column([8], tf.float64, feature_columns)
+
+ def testBoolInput(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ np.array([[False for _ in xrange(8)] for _ in xrange(7)]))
+ self._assert_single_feature_column([8], tf.float32, feature_columns)
+
+ def testBoolInputFn(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'on integer or non floating types are not supported'):
+ # pylint: disable=g-long-lambda
+ tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ lambda: (tf.constant(False, shape=[7, 8], dtype=tf.bool), None))
+
+ def testInvalidStringInput(self):
+ # pylint: disable=g-long-lambda
+ with self.assertRaisesRegexp(
+ ValueError, 'could not convert string to float'):
+ tf.contrib.learn.infer_real_valued_columns_from_input(
+ np.array([['foo%d' % i for i in xrange(8)] for _ in xrange(7)]))
+
+ def testStringInput(self):
+ # pylint: disable=g-long-lambda
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ np.array([['%d.0' % i for i in xrange(8)] for _ in xrange(7)]))
+ self._assert_single_feature_column([8], tf.float32, feature_columns)
+
+ def testStringInputFn(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'on integer or non floating types are not supported'):
+ # pylint: disable=g-long-lambda
+ tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ lambda: (
+ tf.constant([['%d.0' % i for i in xrange(8)] for _ in xrange(7)]),
+ None))
+
+ def testBostonInputFn(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ boston_input_fn)
+ self._assert_single_feature_column(
+ [_BOSTON_INPUT_DIM], tf.float32, feature_columns)
+
+ def testIrisInputFn(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ iris_input_fn)
+ self._assert_single_feature_column(
+ [_IRIS_INPUT_DIM], tf.float32, feature_columns)
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 2b4920b59c..f7016c5c21 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -27,6 +27,20 @@ from tensorflow.contrib.learn.python.learn.estimators import sdca_optimizer
from tensorflow.contrib.learn.python.learn.estimators.base import DeprecatedMixin
from tensorflow.python.framework import ops
from tensorflow.python.ops import logging_ops
+from tensorflow.python.platform import tf_logging as logging
+
+
+# TODO(b/29580537): Replace with @changing decorator.
+def _changing(feature_columns):
+ if feature_columns is not None:
+ return
+ logging.warn(
+ "Change warning: `feature_columns` will be required after 2016-08-01.\n"
+ "Instructions for updating:\n"
+ "Pass `tf.contrib.learn.infer_real_valued_columns_from_input(x)` or"
+ " `tf.contrib.learn.infer_real_valued_columns_from_input_fn(input_fn)`"
+ " as `feature_columns`, where `x` or `input_fn` is your argument to"
+ " `fit`, `evaluate`, or `predict`.")
class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
@@ -124,6 +138,7 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
Returns:
A `LinearClassifier` estimator.
"""
+ _changing(feature_columns)
super(LinearClassifier, self).__init__(
model_dir=model_dir,
n_classes=n_classes,
@@ -135,8 +150,7 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
config=config)
self._feature_columns_inferred = False
- # TODO(ptucker): Update this class to require caller pass `feature_columns` to
- # ctor, so we can remove feature_column inference.
+ # TODO(b/29580537): Remove feature_columns inference.
def _validate_linear_feature_columns(self, features):
if self._linear_feature_columns is None:
self._linear_feature_columns = layers.infer_real_valued_columns(features)
@@ -275,6 +289,7 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
Returns:
A `LinearRegressor` estimator.
"""
+ _changing(feature_columns)
super(LinearRegressor, self).__init__(
model_dir=model_dir,
weight_column_name=weight_column_name,
@@ -286,6 +301,7 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
config=config)
self._feature_columns_inferred = False
+ # TODO(b/29580537): Remove feature_columns inference.
def _validate_linear_feature_columns(self, features):
if self._linear_feature_columns is None:
self._linear_feature_columns = layers.infer_real_valued_columns(features)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index 9fd2e3af2e..b9e87fbacf 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -326,7 +326,9 @@ class LinearRegressorTest(tf.test.TestCase):
weights = 10 * rng.randn(n_weights)
y = np.dot(x, weights)
y += rng.randn(len(x)) * 0.05 + rng.normal(bias, 0.01)
- regressor = tf.contrib.learn.LinearRegressor()
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(x)
+ regressor = tf.contrib.learn.LinearRegressor(
+ feature_columns=feature_columns)
regressor.fit(x, y, batch_size=32, steps=20000)
# Have to flatten weights since they come in (x, 1) shape.
self.assertAllClose(weights, regressor.weights_.flatten(), rtol=1)
@@ -341,13 +343,21 @@ def boston_input_fn():
return features, target
-class InferedColumnTest(tf.test.TestCase):
+class FeatureColumnTest(tf.test.TestCase):
- def testTrain(self):
+ # TODO(b/29580537): Remove when we deprecate feature column inference.
+ def testTrainWithInferredFeatureColumns(self):
est = tf.contrib.learn.LinearRegressor()
est.fit(input_fn=boston_input_fn, steps=1)
_ = est.evaluate(input_fn=boston_input_fn, steps=1)
+ def testTrain(self):
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
+ boston_input_fn)
+ est = tf.contrib.learn.LinearRegressor(feature_columns=feature_columns)
+ est.fit(input_fn=boston_input_fn, steps=1)
+ _ = est.evaluate(input_fn=boston_input_fn, steps=1)
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/tests/base_test.py b/tensorflow/contrib/learn/python/learn/tests/base_test.py
index 34a58cddfc..180fb09701 100644
--- a/tensorflow/contrib/learn/python/learn/tests/base_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/base_test.py
@@ -34,7 +34,8 @@ from tensorflow.contrib.learn.python.learn.estimators._sklearn import log_loss
from tensorflow.contrib.learn.python.learn.estimators._sklearn import mean_squared_error
-class BaseTest(tf.test.TestCase):
+# TODO(b/29580537): Remove when we deprecate feature column inference.
+class InferredfeatureColumnTest(tf.test.TestCase):
"""Test base estimators."""
def testOneDim(self):
@@ -150,6 +151,138 @@ class BaseTest(tf.test.TestCase):
score = mean_squared_error(boston.target, regressor.predict(boston.data))
self.assertLess(score, 150, "Failed with score = {0}".format(score))
+
+class BaseTest(tf.test.TestCase):
+ """Test base estimators."""
+
+ def testOneDim(self):
+ random.seed(42)
+ x = np.random.rand(1000)
+ y = 2 * x + 3
+ feature_columns = learn.infer_real_valued_columns_from_input(x)
+ regressor = learn.TensorFlowLinearRegressor(feature_columns=feature_columns)
+ regressor.fit(x, y)
+ score = mean_squared_error(y, regressor.predict(x))
+ self.assertLess(score, 1.0, "Failed with score = {0}".format(score))
+
+ def testIris(self):
+ iris = datasets.load_iris()
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(iris.data),
+ n_classes=3)
+ classifier.fit(iris.data, [x for x in iris.target])
+ score = accuracy_score(iris.target, classifier.predict(iris.data))
+ self.assertGreater(score, 0.7, "Failed with score = {0}".format(score))
+
+ def testIrisClassWeight(self):
+ iris = datasets.load_iris()
+ # Note, class_weight are not supported anymore :( Use weight_column.
+ with self.assertRaises(ValueError):
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(iris.data),
+ n_classes=3, class_weight=[0.1, 0.8, 0.1])
+ classifier.fit(iris.data, iris.target)
+ score = accuracy_score(iris.target, classifier.predict(iris.data))
+ self.assertLess(score, 0.7, "Failed with score = {0}".format(score))
+
+ def testIrisAllVariables(self):
+ iris = datasets.load_iris()
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(iris.data),
+ n_classes=3)
+ classifier.fit(iris.data, [x for x in iris.target])
+ self.assertEqual(
+ classifier.get_variable_names(),
+ ["centered_bias_weight",
+ "centered_bias_weight/Adagrad",
+ "global_step",
+ "linear/_weight",
+ "linear/_weight/Ftrl",
+ "linear/_weight/Ftrl_1",
+ "linear/bias_weight",
+ "linear/bias_weight/Ftrl",
+ "linear/bias_weight/Ftrl_1"])
+
+ def testIrisSummaries(self):
+ iris = datasets.load_iris()
+ output_dir = tempfile.mkdtemp() + "learn_tests/"
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(iris.data),
+ n_classes=3, model_dir=output_dir)
+ classifier.fit(iris.data, iris.target)
+ score = accuracy_score(iris.target, classifier.predict(iris.data))
+ self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
+ # TODO(ipolosukhin): Check that summaries are correclty written.
+
+ def testIrisContinueTraining(self):
+ iris = datasets.load_iris()
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(iris.data),
+ n_classes=3,
+ learning_rate=0.01,
+ continue_training=True,
+ steps=250)
+ classifier.fit(iris.data, iris.target)
+ score1 = accuracy_score(iris.target, classifier.predict(iris.data))
+ classifier.fit(iris.data, iris.target, steps=500)
+ score2 = accuracy_score(iris.target, classifier.predict(iris.data))
+ self.assertGreater(
+ score2, score1,
+ "Failed with score2 {0} <= score1 {1}".format(score2, score1))
+
+ def testIrisStreaming(self):
+ iris = datasets.load_iris()
+
+ def iris_data():
+ while True:
+ for x in iris.data:
+ yield x
+
+ def iris_predict_data():
+ for x in iris.data:
+ yield x
+
+ def iris_target():
+ while True:
+ for y in iris.target:
+ yield y
+
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(iris.data),
+ n_classes=3, steps=100)
+ classifier.fit(iris_data(), iris_target())
+ score1 = accuracy_score(iris.target, classifier.predict(iris.data))
+ score2 = accuracy_score(iris.target,
+ classifier.predict(iris_predict_data()))
+ self.assertGreater(score1, 0.5, "Failed with score = {0}".format(score1))
+ self.assertEqual(score2, score1, "Scores from {0} iterator doesn't "
+ "match score {1} from full "
+ "data.".format(score2, score1))
+
+ def testIris_proba(self):
+ # If sklearn available.
+ if log_loss:
+ random.seed(42)
+ iris = datasets.load_iris()
+ classifier = learn.TensorFlowClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(iris.data),
+ n_classes=3, steps=250)
+ classifier.fit(iris.data, iris.target)
+ score = log_loss(iris.target, classifier.predict_proba(iris.data))
+ self.assertLess(score, 0.8, "Failed with score = {0}".format(score))
+
+ def testBoston(self):
+ random.seed(42)
+ boston = datasets.load_boston()
+ regressor = learn.TensorFlowLinearRegressor(
+ feature_columns=learn.infer_real_valued_columns_from_input(boston.data),
+ batch_size=boston.data.shape[0],
+ steps=500,
+ learning_rate=0.001)
+ regressor.fit(boston.data, boston.target)
+ score = mean_squared_error(boston.target, regressor.predict(boston.data))
+ self.assertLess(score, 150, "Failed with score = {0}".format(score))
+
def testUnfitted(self):
estimator = learn.TensorFlowEstimator(model_fn=None, n_classes=1)
with self.assertRaises(base.NotFittedError):
diff --git a/tensorflow/contrib/learn/python/learn/tests/estimators_test.py b/tensorflow/contrib/learn/python/learn/tests/estimators_test.py
index 3b5db85e17..f7aaefbd50 100644
--- a/tensorflow/contrib/learn/python/learn/tests/estimators_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/estimators_test.py
@@ -29,6 +29,34 @@ from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_s
from tensorflow.contrib.learn.python.learn.estimators._sklearn import train_test_split
+# TODO(b/29580537): Remove when we deprecate feature column inference.
+class InferredfeatureColumnTest(tf.test.TestCase):
+ """Custom optimizer tests."""
+
+ def testIrisMomentum(self):
+ random.seed(42)
+
+ iris = datasets.load_iris()
+ x_train, x_test, y_train, y_test = train_test_split(iris.data,
+ iris.target,
+ test_size=0.2,
+ random_state=42)
+
+ def custom_optimizer(learning_rate):
+ return tf.train.MomentumOptimizer(learning_rate, 0.9)
+
+ classifier = learn.TensorFlowDNNClassifier(
+ hidden_units=[10, 20, 10],
+ n_classes=3,
+ steps=400,
+ learning_rate=0.01,
+ optimizer=custom_optimizer)
+ classifier.fit(x_train, y_train)
+ score = accuracy_score(y_test, classifier.predict(x_test))
+
+ self.assertGreater(score, 0.65, "Failed with score = {0}".format(score))
+
+
class CustomOptimizer(tf.test.TestCase):
"""Custom optimizer tests."""
@@ -44,11 +72,13 @@ class CustomOptimizer(tf.test.TestCase):
def custom_optimizer(learning_rate):
return tf.train.MomentumOptimizer(learning_rate, 0.9)
- classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3,
- steps=400,
- learning_rate=0.01,
- optimizer=custom_optimizer)
+ classifier = learn.TensorFlowDNNClassifier(
+ hidden_units=[10, 20, 10],
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
+ n_classes=3,
+ steps=400,
+ learning_rate=0.01,
+ optimizer=custom_optimizer)
classifier.fit(x_train, y_train)
score = accuracy_score(y_test, classifier.predict(x_test))
diff --git a/tensorflow/contrib/learn/python/learn/tests/io_test.py b/tensorflow/contrib/learn/python/learn/tests/io_test.py
index 9643923dc5..6459c3a53e 100644
--- a/tensorflow/contrib/learn/python/learn/tests/io_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/io_test.py
@@ -43,7 +43,9 @@ class IOTest(tf.test.TestCase):
iris = datasets.load_iris()
data = pd.DataFrame(iris.data)
labels = pd.DataFrame(iris.target)
- classifier = learn.TensorFlowLinearClassifier(n_classes=3)
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(data),
+ n_classes=3)
classifier.fit(data, labels)
score = accuracy_score(labels[0], classifier.predict(data))
self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
@@ -57,7 +59,9 @@ class IOTest(tf.test.TestCase):
iris = datasets.load_iris()
data = pd.DataFrame(iris.data)
labels = pd.Series(iris.target)
- classifier = learn.TensorFlowLinearClassifier(n_classes=3)
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(data),
+ n_classes=3)
classifier.fit(data, labels)
score = accuracy_score(labels, classifier.predict(data))
self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
@@ -108,7 +112,9 @@ class IOTest(tf.test.TestCase):
data = dd.from_pandas(data, npartitions=2)
labels = pd.DataFrame(iris.target)
labels = dd.from_pandas(labels, npartitions=2)
- classifier = learn.TensorFlowLinearClassifier(n_classes=3)
+ classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(data),
+ n_classes=3)
classifier.fit(data, labels)
predictions = data.map_partitions(classifier.predict).compute()
score = accuracy_score(labels.compute(), predictions)
diff --git a/tensorflow/contrib/learn/python/learn/tests/multioutput_test.py b/tensorflow/contrib/learn/python/learn/tests/multioutput_test.py
index 4e86e6f3cb..7a2a3ecc9d 100644
--- a/tensorflow/contrib/learn/python/learn/tests/multioutput_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/multioutput_test.py
@@ -36,8 +36,9 @@ class MultiOutputTest(tf.test.TestCase):
rng = np.random.RandomState(1)
x = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(x).ravel(), np.pi * np.cos(x).ravel()]).T
- regressor = learn.TensorFlowLinearRegressor(learning_rate=0.01,
- target_dimension=2)
+ regressor = learn.TensorFlowLinearRegressor(
+ feature_columns=learn.infer_real_valued_columns_from_input(x),
+ learning_rate=0.01, target_dimension=2)
regressor.fit(x, y)
score = mean_squared_error(regressor.predict(x), y)
self.assertLess(score, 10, "Failed with score = {0}".format(score))
diff --git a/tensorflow/contrib/learn/python/learn/tests/regression_test.py b/tensorflow/contrib/learn/python/learn/tests/regression_test.py
index d58252c6a1..17cc2a2a27 100644
--- a/tensorflow/contrib/learn/python/learn/tests/regression_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/regression_test.py
@@ -38,7 +38,9 @@ class RegressionTest(tf.test.TestCase):
weights = 10 * rng.randn(n_weights)
y = np.dot(x, weights)
y += rng.randn(len(x)) * 0.05 + rng.normal(bias, 0.01)
- regressor = learn.TensorFlowLinearRegressor(optimizer="SGD")
+ regressor = learn.TensorFlowLinearRegressor(
+ feature_columns=learn.infer_real_valued_columns_from_input(x),
+ optimizer="SGD")
regressor.fit(x, y, steps=200)
# Have to flatten weights since they come in (x, 1) shape.
self.assertAllClose(weights, regressor.weights_.flatten(), rtol=0.01)
diff --git a/tensorflow/examples/skflow/boston.py b/tensorflow/examples/skflow/boston.py
index c2953abd58..7aacb1b9ff 100644
--- a/tensorflow/examples/skflow/boston.py
+++ b/tensorflow/examples/skflow/boston.py
@@ -16,7 +16,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import cross_validation, metrics
+from sklearn import cross_validation
+from sklearn import metrics
from sklearn import preprocessing
import tensorflow as tf
from tensorflow.contrib import learn
@@ -28,15 +29,17 @@ def main(unused_argv):
x, y = boston.data, boston.target
# Split dataset into train / test
- x_train, x_test, y_train, y_test = cross_validation.train_test_split(x, y,
- test_size=0.2, random_state=42)
+ x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ x, y, test_size=0.2, random_state=42)
# Scale data (training set) to 0 mean and unit standard deviation.
scaler = preprocessing.StandardScaler()
x_train = scaler.fit_transform(x_train)
# Build 2 layer fully connected DNN with 10, 10 units respectively.
- regressor = learn.DNNRegressor(hidden_units=[10, 10])
+ feature_columns = learn.infer_real_valued_columns_from_input(x_train)
+ regressor = learn.DNNRegressor(
+ feature_columns=feature_columns, hidden_units=[10, 10])
# Fit
regressor.fit(x_train, y_train, steps=5000, batch_size=1)
diff --git a/tensorflow/examples/skflow/hdf5_classification.py b/tensorflow/examples/skflow/hdf5_classification.py
index 0a4a7fd731..edcce6fe6f 100644
--- a/tensorflow/examples/skflow/hdf5_classification.py
+++ b/tensorflow/examples/skflow/hdf5_classification.py
@@ -11,39 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Example of DNNClassifier for Iris plant dataset, h5 format."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import metrics, cross_validation
+from sklearn import cross_validation
+from sklearn import metrics
from tensorflow.contrib import learn
-import h5py
+import h5py # pylint: disable=g-bad-import-order
# Load dataset.
iris = learn.datasets.load_dataset('iris')
-X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
- test_size=0.2, random_state=42)
+x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ iris.data, iris.target, test_size=0.2, random_state=42)
-# Note that we are saving and load iris data as h5 format as a simple demonstration here.
+# Note that we are saving and load iris data as h5 format as a simple
+# demonstration here.
h5f = h5py.File('test_hdf5.h5', 'w')
-h5f.create_dataset('X_train', data=X_train)
-h5f.create_dataset('X_test', data=X_test)
+h5f.create_dataset('X_train', data=x_train)
+h5f.create_dataset('X_test', data=x_test)
h5f.create_dataset('y_train', data=y_train)
h5f.create_dataset('y_test', data=y_test)
h5f.close()
h5f = h5py.File('test_hdf5.h5', 'r')
-X_train = h5f['X_train']
-X_test = h5f['X_test']
+x_train = h5f['X_train']
+x_test = h5f['X_test']
y_train = h5f['y_train']
y_test = h5f['y_test']
# Build 3 layer DNN with 10, 20, 10 units respectively.
-classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3, steps=200)
+feature_columns = learn.infer_real_valued_columns_from_input(x_train)
+classifier = learn.TensorFlowDNNClassifier(
+ feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3,
+ steps=200)
# Fit and predict.
-classifier.fit(X_train, y_train)
-score = metrics.accuracy_score(y_test, classifier.predict(X_test))
+classifier.fit(x_train, y_train)
+score = metrics.accuracy_score(y_test, classifier.predict(x_test))
print('Accuracy: {0:f}'.format(score))
diff --git a/tensorflow/examples/skflow/iris.py b/tensorflow/examples/skflow/iris.py
index 62a58440e7..9bd3faa942 100644
--- a/tensorflow/examples/skflow/iris.py
+++ b/tensorflow/examples/skflow/iris.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
"""Example of DNNClassifier for Iris plant dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import metrics, cross_validation
+from sklearn import cross_validation
+from sklearn import metrics
import tensorflow as tf
from tensorflow.contrib import learn
@@ -28,7 +30,9 @@ def main(unused_argv):
iris.data, iris.target, test_size=0.2, random_state=42)
# Build 3 layer DNN with 10, 20, 10 units respectively.
- classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
+ feature_columns = learn.infer_real_valued_columns_from_input(x_train)
+ classifier = learn.DNNClassifier(
+ feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)
# Fit and predict.
classifier.fit(x_train, y_train, steps=200)
diff --git a/tensorflow/examples/skflow/iris_custom_decay_dnn.py b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
index 1ce6a830e4..7a34ca9f13 100644
--- a/tensorflow/examples/skflow/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
@@ -15,7 +15,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import datasets, metrics
+from sklearn import datasets
+from sklearn import metrics
from sklearn.cross_validation import train_test_split
import tensorflow as tf
@@ -27,12 +28,16 @@ def optimizer_exp_decay():
decay_steps=100, decay_rate=0.001)
return tf.train.AdagradOptimizer(learning_rate=learning_rate)
+
def main(unused_argv):
iris = datasets.load_iris()
x_train, x_test, y_train, y_test = train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
- classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ x_train)
+ classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
+ hidden_units=[10, 20, 10],
n_classes=3,
optimizer=optimizer_exp_decay)
diff --git a/tensorflow/examples/skflow/iris_custom_model.py b/tensorflow/examples/skflow/iris_custom_model.py
index 009e375274..afce504b74 100644
--- a/tensorflow/examples/skflow/iris_custom_model.py
+++ b/tensorflow/examples/skflow/iris_custom_model.py
@@ -16,7 +16,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import datasets, metrics, cross_validation
+from sklearn import cross_validation
+from sklearn import datasets
+from sklearn import metrics
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import learn
diff --git a/tensorflow/examples/skflow/iris_run_config.py b/tensorflow/examples/skflow/iris_run_config.py
index c678c7c738..de9b44d460 100644
--- a/tensorflow/examples/skflow/iris_run_config.py
+++ b/tensorflow/examples/skflow/iris_run_config.py
@@ -11,11 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Example of DNNClassifier for Iris plant dataset, with run config."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import datasets, metrics, cross_validation
+from sklearn import cross_validation
+from sklearn import datasets
+from sklearn import metrics
import tensorflow as tf
@@ -32,7 +37,10 @@ def main(unused_argv):
num_cores=3, gpu_memory_fraction=0.6)
# Build 3 layer DNN with 10, 20, 10 units respectively.
- classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
+ feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+ x_train)
+ classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
+ hidden_units=[10, 20, 10],
n_classes=3,
config=run_config)
diff --git a/tensorflow/examples/skflow/iris_save_restore.py b/tensorflow/examples/skflow/iris_save_restore.py
index d29237a26f..84b2cb343f 100644
--- a/tensorflow/examples/skflow/iris_save_restore.py
+++ b/tensorflow/examples/skflow/iris_save_restore.py
@@ -11,35 +11,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Example of DNNClassifier for Iris plant dataset, with save & restore."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import shutil
-from sklearn import datasets, metrics, cross_validation
+from sklearn import cross_validation
+from sklearn import datasets
+from sklearn import metrics
from tensorflow.contrib import learn
iris = datasets.load_iris()
-X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
- test_size=0.2, random_state=42)
+x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ iris.data, iris.target, test_size=0.2, random_state=42)
-classifier = learn.TensorFlowLinearClassifier(n_classes=3)
-classifier.fit(X_train, y_train)
-score = metrics.accuracy_score(y_test, classifier.predict(X_test))
+classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
+ n_classes=3)
+classifier.fit(x_train, y_train)
+score = metrics.accuracy_score(y_test, classifier.predict(x_test))
print('Accuracy: {0:f}'.format(score))
# Clean checkpoint folder if exists
try:
- shutil.rmtree('/tmp/skflow_examples/iris_custom_model')
+ shutil.rmtree('/tmp/skflow_examples/iris_custom_model')
except OSError:
- pass
+ pass
# Save model, parameters and learned variables.
classifier.save('/tmp/skflow_examples/iris_custom_model')
classifier = None
## Restore everything
-new_classifier = learn.TensorFlowEstimator.restore('/tmp/skflow_examples/iris_custom_model')
-score = metrics.accuracy_score(y_test, new_classifier.predict(X_test))
+new_classifier = learn.TensorFlowEstimator.restore(
+ '/tmp/skflow_examples/iris_custom_model')
+score = metrics.accuracy_score(y_test, new_classifier.predict(x_test))
print('Accuracy: {0:f}'.format(score))
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
index 05dfa96a07..70dd8053aa 100644
--- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py
+++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Example of DNNClassifier for Iris plant dataset, with early stopping."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -35,6 +38,7 @@ def main(unused_argv):
# classifier with early stopping on training data
classifier1 = learn.DNNClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model/')
classifier1.fit(x=x_train, y=y_train, steps=2000)
score1 = metrics.accuracy_score(y_test, classifier1.predict(x_test))
@@ -42,6 +46,7 @@ def main(unused_argv):
# classifier with early stopping on validation data, save frequently for
# monitor to pick up new checkpoints.
classifier2 = learn.DNNClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model_val/',
config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))
classifier2.fit(x=x_train, y=y_train, steps=2000, monitors=[val_monitor])
diff --git a/tensorflow/examples/skflow/iris_with_pipeline.py b/tensorflow/examples/skflow/iris_with_pipeline.py
index 5535cd9e3b..ee5f9aed81 100644
--- a/tensorflow/examples/skflow/iris_with_pipeline.py
+++ b/tensorflow/examples/skflow/iris_with_pipeline.py
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Example of DNNClassifier for Iris plant dataset, with pipeline."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn.pipeline import Pipeline
-from sklearn.datasets import load_iris
from sklearn import cross_validation
-from sklearn.preprocessing import StandardScaler
+from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow.contrib import learn
@@ -34,8 +37,10 @@ def main(unused_argv):
# will do the right thing.
scaler = StandardScaler()
- # DNN classifier
- classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
+ # DNN classifier.
+ classifier = learn.DNNClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
+ hidden_units=[10, 20, 10], n_classes=3)
pipeline = Pipeline([('scaler', scaler),
('DNNclassifier', classifier)])
diff --git a/tensorflow/examples/skflow/mnist.py b/tensorflow/examples/skflow/mnist.py
index d1288a31e9..3b11708a27 100644
--- a/tensorflow/examples/skflow/mnist.py
+++ b/tensorflow/examples/skflow/mnist.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-This example showcases how simple it is to build image classification networks.
+"""This showcases how simple it is to build image classification networks.
+
It follows description from this TensorFlow tutorial:
https://www.tensorflow.org/versions/master/tutorials/mnist/pros/index.html#deep-mnist-for-experts
"""
@@ -32,42 +32,50 @@ mnist = learn.datasets.load_dataset('mnist')
### Linear classifier.
+feature_columns = learn.infer_real_valued_columns_from_input(mnist.train.images)
classifier = learn.TensorFlowLinearClassifier(
- n_classes=10, batch_size=100, steps=1000, learning_rate=0.01)
+ feature_columns=feature_columns, n_classes=10, batch_size=100, steps=1000,
+ learning_rate=0.01)
classifier.fit(mnist.train.images, mnist.train.labels)
-score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
+score = metrics.accuracy_score(
+ mnist.test.labels, classifier.predict(mnist.test.images))
print('Accuracy: {0:f}'.format(score))
### Convolutional network
+
def max_pool_2x2(tensor_in):
- return tf.nn.max_pool(tensor_in, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
- padding='SAME')
+ return tf.nn.max_pool(
+ tensor_in, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
+
def conv_model(X, y):
- # reshape X to 4d tensor with 2nd and 3rd dimensions being image width and height
- # final dimension being the number of color channels
- X = tf.reshape(X, [-1, 28, 28, 1])
- # first conv layer will compute 32 features for each 5x5 patch
- with tf.variable_scope('conv_layer1'):
- h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
- bias=True, activation=tf.nn.relu)
- h_pool1 = max_pool_2x2(h_conv1)
- # second conv layer will compute 64 features for each 5x5 patch
- with tf.variable_scope('conv_layer2'):
- h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
- bias=True, activation=tf.nn.relu)
- h_pool2 = max_pool_2x2(h_conv2)
- # reshape tensor into a batch of vectors
- h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
- # densely connected layer with 1024 neurons
- h_fc1 = learn.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
- return learn.models.logistic_regression(h_fc1, y)
+ # pylint: disable=invalid-name,missing-docstring
+ # reshape X to 4d tensor with 2nd and 3rd dimensions being image width and
+ # height final dimension being the number of color channels.
+ X = tf.reshape(X, [-1, 28, 28, 1])
+ # first conv layer will compute 32 features for each 5x5 patch
+ with tf.variable_scope('conv_layer1'):
+ h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
+ bias=True, activation=tf.nn.relu)
+ h_pool1 = max_pool_2x2(h_conv1)
+ # second conv layer will compute 64 features for each 5x5 patch.
+ with tf.variable_scope('conv_layer2'):
+ h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
+ bias=True, activation=tf.nn.relu)
+ h_pool2 = max_pool_2x2(h_conv2)
+ # reshape tensor into a batch of vectors
+ h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
+ # densely connected layer with 1024 neurons.
+ h_fc1 = learn.ops.dnn(
+ h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
+ return learn.models.logistic_regression(h_fc1, y)
-# Training and predicting
+# Training and predicting.
classifier = learn.TensorFlowEstimator(
model_fn=conv_model, n_classes=10, batch_size=100, steps=20000,
learning_rate=0.001)
classifier.fit(mnist.train.images, mnist.train.labels)
-score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
+score = metrics.accuracy_score(
+ mnist.test.labels, classifier.predict(mnist.test.images))
print('Accuracy: {0:f}'.format(score))
diff --git a/tensorflow/examples/skflow/mnist_weights.py b/tensorflow/examples/skflow/mnist_weights.py
index b0c2ea583e..9ad019f9a4 100644
--- a/tensorflow/examples/skflow/mnist_weights.py
+++ b/tensorflow/examples/skflow/mnist_weights.py
@@ -1,4 +1,4 @@
-#t Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-This example demonstrates one way to access the weights of a custom skflow
-model. It is otherwise identical to the standard MNIST convolutional code.
+"""This demonstrates one way to access the weights of a custom skflow model.
+
+It is otherwise identical to the standard MNIST convolutional code.
"""
from __future__ import absolute_import
@@ -23,7 +23,6 @@ from __future__ import print_function
from sklearn import metrics
import tensorflow as tf
-from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import learn
### Download and load MNIST data.
@@ -32,54 +31,63 @@ mnist = learn.datasets.load_dataset('mnist')
### Linear classifier.
+feature_columns = learn.infer_real_valued_columns_from_input(mnist.train.images)
classifier = learn.TensorFlowLinearClassifier(
- n_classes=10, batch_size=100, steps=1000, learning_rate=0.01)
+ feature_columns=feature_columns, n_classes=10, batch_size=100, steps=1000,
+ learning_rate=0.01)
classifier.fit(mnist.train.images, mnist.train.labels)
-score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
+score = metrics.accuracy_score(
+ mnist.test.labels, classifier.predict(mnist.test.images))
print('Accuracy: {0:f}'.format(score))
### Convolutional network
+
def max_pool_2x2(tensor_in):
- return tf.nn.max_pool(tensor_in, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
- padding='SAME')
+ return tf.nn.max_pool(tensor_in, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+ padding='SAME')
+
def conv_model(X, y):
- # reshape X to 4d tensor with 2nd and 3rd dimensions being image width and height
- # final dimension being the number of color channels
- X = tf.reshape(X, [-1, 28, 28, 1])
- # first conv layer will compute 32 features for each 5x5 patch
- with tf.variable_scope('conv_layer1'):
- h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
- bias=True, activation=tf.nn.relu)
- h_pool1 = max_pool_2x2(h_conv1)
- # second conv layer will compute 64 features for each 5x5 patch
- with tf.variable_scope('conv_layer2'):
- h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
- bias=True, activation=tf.nn.relu)
- h_pool2 = max_pool_2x2(h_conv2)
- # reshape tensor into a batch of vectors
- h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
- # densely connected layer with 1024 neurons
- h_fc1 = learn.ops.dnn(h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
- return learn.models.logistic_regression(h_fc1, y)
+ # pylint: disable=invalid-name,missing-docstring
+ # reshape X to 4d tensor with 2nd and 3rd dimensions being image width and
+ # height final dimension being the number of color channels
+ X = tf.reshape(X, [-1, 28, 28, 1])
+ # first conv layer will compute 32 features for each 5x5 patch
+ with tf.variable_scope('conv_layer1'):
+ h_conv1 = learn.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
+ bias=True, activation=tf.nn.relu)
+ h_pool1 = max_pool_2x2(h_conv1)
+ # second conv layer will compute 64 features for each 5x5 patch
+ with tf.variable_scope('conv_layer2'):
+ h_conv2 = learn.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
+ bias=True, activation=tf.nn.relu)
+ h_pool2 = max_pool_2x2(h_conv2)
+ # reshape tensor into a batch of vectors
+ h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
+ # densely connected layer with 1024 neurons
+ h_fc1 = learn.ops.dnn(
+ h_pool2_flat, [1024], activation=tf.nn.relu, dropout=0.5)
+ return learn.models.logistic_regression(h_fc1, y)
# Training and predicting
classifier = learn.TensorFlowEstimator(
model_fn=conv_model, n_classes=10, batch_size=100, steps=20000,
learning_rate=0.001)
classifier.fit(mnist.train.images, mnist.train.labels)
-score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
+score = metrics.accuracy_score(
+ mnist.test.labels, classifier.predict(mnist.test.images))
print('Accuracy: {0:f}'.format(score))
# Examining fitted weights
## General usage is classifier.get_tensor_value('foo')
## 'foo' must be the variable scope of the desired tensor followed by the
-## graph path.
+## graph path.
-## To understand the mechanism and figure out the right scope and path, you can do logging.
-## Then use TensorBoard or a text editor on the log file to look at available strings.
+## To understand the mechanism and figure out the right scope and path, you can
+## do logging. Then use TensorBoard or a text editor on the log file to look at
+## available strings.
## First Convolutional Layer
print('1st Convolutional Layer weights and Bias')
diff --git a/tensorflow/examples/skflow/out_of_core_data_classification.py b/tensorflow/examples/skflow/out_of_core_data_classification.py
index 6328941e6d..5f612db3d7 100644
--- a/tensorflow/examples/skflow/out_of_core_data_classification.py
+++ b/tensorflow/examples/skflow/out_of_core_data_classification.py
@@ -11,41 +11,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Example of loading karge data sets into out-of-core dataframe."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import datasets, metrics, cross_validation
-
-import pandas as pd
+from sklearn import cross_validation
+from sklearn import datasets
+from sklearn import metrics
+# pylint: disable=g-bad-import-order
import dask.dataframe as dd
-
+import pandas as pd
from tensorflow.contrib import learn
+# pylint: enable=g-bad-import-order
# Sometimes when your dataset is too large to hold in the memory
-# you may want to load it into a out-of-core dataframe as provided by dask library
-# to firstly draw sample batches and then load into memory for training.
+# you may want to load it into a out-of-core dataframe as provided by dask
+# library to firstly draw sample batches and then load into memory for training.
# Load dataset.
iris = datasets.load_iris()
-X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
- test_size=0.2, random_state=42)
+x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ iris.data, iris.target, test_size=0.2, random_state=42)
# Note that we use iris here just for demo purposes
# You can load your own large dataset into a out-of-core dataframe
# using dask's methods, e.g. read_csv() in dask
# details please see: http://dask.pydata.org/en/latest/dataframe.html
-# We firstly load them into pandas dataframe and then convert into dask dataframe
-X_train, y_train, X_test, y_test = [pd.DataFrame(data) for data in [X_train, y_train, X_test, y_test]]
-X_train, y_train, X_test, y_test = [dd.from_pandas(data, npartitions=2) for data in [X_train, y_train, X_test, y_test]]
+# We firstly load them into pandas dataframe and then convert into dask
+# dataframe.
+x_train, y_train, x_test, y_test = [
+ pd.DataFrame(data) for data in [x_train, y_train, x_test, y_test]]
+x_train, y_train, x_test, y_test = [
+ dd.from_pandas(data, npartitions=2)
+ for data in [x_train, y_train, x_test, y_test]]
# Initialize a TensorFlow linear classifier
-classifier = learn.TensorFlowLinearClassifier(n_classes=3)
+classifier = learn.TensorFlowLinearClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
+ n_classes=3)
-# Fit the model using training set
-classifier.fit(X_train, y_train)
+# Fit the model using training set.
+classifier.fit(x_train, y_train)
# Make predictions on each partitions of testing data
-predictions = X_test.map_partitions(classifier.predict).compute()
+predictions = x_test.map_partitions(classifier.predict).compute()
# Calculate accuracy
score = metrics.accuracy_score(y_test.compute(), predictions)