author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-01 10:33:30 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-11-01 11:53:35 -0700
commit    3fbd3ce4d6fc2898ac50348bb8221e4e49fefe10 (patch)
tree      3b23a1fb1e85625f15019f419cbecf31655f8782
parent    26f20568e42cdc1e9a522392c11ef6d85924e57f (diff)
Refactor DNNLinearCombinedClassifier from inheritance to composition.
Change: 137848991
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py       | 453
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py  | 114
2 files changed, 532 insertions(+), 35 deletions(-)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 64a92f5ffb..57a98e419c 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -19,7 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
+import math
+import re
import six
from tensorflow.contrib import layers
@@ -27,13 +28,22 @@ from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import feature_column_ops
+from tensorflow.contrib.layers.python.layers import optimizers
+from tensorflow.contrib.learn.python.learn import evaluable
+from tensorflow.contrib.learn.python.learn import session_run_hook
+from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import composable_model
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
@@ -307,7 +317,236 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
return logits
-class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
+_CENTERED_BIAS_WEIGHT = "centered_bias_weight"
+
+# The default learning rates are a historical artifact of the initial
+# implementation, but seem a reasonable choice.
+_DNN_LEARNING_RATE = 0.05
+_LINEAR_LEARNING_RATE = 0.2
+
+
+def _as_iterable(preds, output):
+ for pred in preds:
+ yield pred[output]
+
+
+def _get_feature_dict(features):
+ if isinstance(features, dict):
+ return features
+ return {"": features}
+
+
+def _get_optimizer(optimizer):
+ if callable(optimizer):
+ return optimizer()
+ else:
+ return optimizer
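
_get_optimizer accepts either a ready optimizer (or its string name) or a zero-argument callable that builds one. A quick self-contained check, with hypothetical stand-in values:

    def _get_optimizer(optimizer):
        # Strings and Optimizer instances pass through; factories are invoked.
        return optimizer() if callable(optimizer) else optimizer

    assert _get_optimizer("Ftrl") == "Ftrl"
    assert _get_optimizer(lambda: "Adagrad") == "Adagrad"
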
+
+
+def _linear_learning_rate(num_linear_feature_columns):
+ """Returns the default learning rate of the linear model.
+
+ The calculation is a historical artifact of this initial implementation, but
+ has proven a reasonable choice.
+
+ Args:
+ num_linear_feature_columns: The number of feature columns of the linear
+ model.
+
+ Returns:
+ A float.
+ """
+ default_learning_rate = 1. / math.sqrt(num_linear_feature_columns)
+ return min(_LINEAR_LEARNING_RATE, default_learning_rate)
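
A quick check of the capping behavior, as a standalone sketch with hypothetical column counts:

    import math

    _LINEAR_LEARNING_RATE = 0.2

    def _linear_learning_rate(num_linear_feature_columns):
        default_learning_rate = 1. / math.sqrt(num_linear_feature_columns)
        return min(_LINEAR_LEARNING_RATE, default_learning_rate)

    # Many columns: 1/sqrt(400) = 0.05 is below the 0.2 cap and is used as-is.
    assert _linear_learning_rate(400) == 0.05
    # Few columns: 1/sqrt(4) = 0.5 exceeds the cap, so 0.2 wins.
    assert _linear_learning_rate(4) == 0.2
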
+
+
+def _add_hidden_layer_summary(value, tag):
+ logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
+ nn.zero_fraction(value))
+ logging_ops.histogram_summary("%s:activation" % tag, value)
+
+
+def _dnn_linear_combined_model_fn(features, labels, mode, params):
+ """Deep Neural Net and Linear combined model_fn.
+
+ Args:
+ features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`).
+ labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype
+ `int32` or `int64` in the range `[0, n_classes)`.
+ mode: Defines whether this is training, evaluation or prediction.
+ See `ModeKeys`.
+ params: A dict of hyperparameters.
+ The following hyperparameters are expected:
+ * head: A `Head` instance.
+ * linear_feature_columns: An iterable containing all the feature columns
+ used by the Linear model.
+ * linear_optimizer: string, `Optimizer` object, or callable that defines
+ the optimizer to use for training the Linear model.
+ * joint_linear_weights: If True, a single (possibly partitioned) variable
+ will be used to store the linear model weights. It's faster, but
+ requires that all columns are sparse and use the 'sum' combiner.
+ * dnn_feature_columns: An iterable containing all the feature columns used
+ by the DNN model.
+ * dnn_optimizer: string, `Optimizer` object, or callable that defines the
+ optimizer to use for training the DNN model.
+ * dnn_hidden_units: List of hidden units per DNN layer.
+ * dnn_activation_fn: Activation function applied to each DNN layer. If
+ `None`, will use `tf.nn.relu`.
+ * dnn_dropout: When not `None`, the probability we will drop out a given
+ DNN coordinate.
+ * gradient_clip_norm: A float > 0. If provided, gradients are
+ clipped to their global norm with this clipping ratio.
+ * num_ps_replicas: The number of parameter server replicas.
+
+ Returns:
+ `estimator.ModelFnOps`
+
+ Raises:
+ ValueError: If both `linear_feature_columns` and `dnn_feature_columns`
+ are empty at the same time.
+ """
+ head = params["head"]
+ linear_feature_columns = params.get("linear_feature_columns")
+ linear_optimizer = params.get("linear_optimizer")
+ joint_linear_weights = params.get("joint_linear_weights")
+ dnn_feature_columns = params.get("dnn_feature_columns")
+ dnn_optimizer = params.get("dnn_optimizer")
+ dnn_hidden_units = params.get("dnn_hidden_units")
+ dnn_activation_fn = params.get("dnn_activation_fn")
+ dnn_dropout = params.get("dnn_dropout")
+ gradient_clip_norm = params.get("gradient_clip_norm")
+ num_ps_replicas = params["num_ps_replicas"]
+
+ if not linear_feature_columns and not dnn_feature_columns:
+ raise ValueError(
+ "Either linear_feature_columns or dnn_feature_columns must be defined.")
+
+ features = _get_feature_dict(features)
+
+ # Build DNN Logits.
+ dnn_parent_scope = "dnn"
+
+ if not dnn_feature_columns:
+ dnn_logits = None
+ else:
+ input_layer_partitioner = (
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas,
+ min_slice_size=64 << 20))
+ with variable_scope.variable_scope(
+ dnn_parent_scope + "/input_from_feature_columns",
+ values=features.values(),
+ partitioner=input_layer_partitioner) as scope:
+ net = layers.input_from_feature_columns(
+ columns_to_tensors=features,
+ feature_columns=dnn_feature_columns,
+ weight_collections=[dnn_parent_scope],
+ scope=scope)
+
+ hidden_layer_partitioner = (
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas))
+ for layer_id, num_hidden_units in enumerate(dnn_hidden_units):
+ with variable_scope.variable_scope(
+ dnn_parent_scope + "/hiddenlayer_%d" % layer_id,
+ values=[net],
+ partitioner=hidden_layer_partitioner) as scope:
+ net = layers.fully_connected(
+ net,
+ num_hidden_units,
+ activation_fn=dnn_activation_fn,
+ variables_collections=[dnn_parent_scope],
+ scope=scope)
+ if dnn_dropout is not None and mode == estimator.ModeKeys.TRAIN:
+ net = layers.dropout(
+ net,
+ keep_prob=(1.0 - dnn_dropout))
+ # TODO(b/31209633): Consider adding summary before dropout.
+ _add_hidden_layer_summary(net, scope.name)
+
+ with variable_scope.variable_scope(
+ dnn_parent_scope + "/logits",
+ values=[net],
+ partitioner=hidden_layer_partitioner) as scope:
+ dnn_logits = layers.fully_connected(
+ net,
+ head.logits_dimension,
+ activation_fn=None,
+ variables_collections=[dnn_parent_scope],
+ scope=scope)
+ _add_hidden_layer_summary(dnn_logits, scope.name)
+
+ # Build Linear logits.
+ linear_parent_scope = "linear"
+
+ if not linear_feature_columns:
+ linear_logits = None
+ else:
+ linear_partitioner = partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas,
+ min_slice_size=64 << 20)
+ with variable_scope.variable_scope(
+ linear_parent_scope,
+ values=features.values(),
+ partitioner=linear_partitioner) as scope:
+ if joint_linear_weights:
+ linear_logits, _, _ = layers.joint_weighted_sum_from_feature_columns(
+ columns_to_tensors=features,
+ feature_columns=linear_feature_columns,
+ num_outputs=head.logits_dimension,
+ weight_collections=[linear_parent_scope],
+ scope=scope)
+ else:
+ linear_logits, _, _ = layers.weighted_sum_from_feature_columns(
+ columns_to_tensors=features,
+ feature_columns=linear_feature_columns,
+ num_outputs=head.logits_dimension,
+ weight_collections=[linear_parent_scope],
+ scope=scope)
+
+ # Combine logits and build full model.
+ if dnn_logits is not None and linear_logits is not None:
+ logits = dnn_logits + linear_logits
+ elif dnn_logits is not None:
+ logits = dnn_logits
+ else:
+ logits = linear_logits
+
+ def _make_training_op(training_loss):
+ """Training op for the DNN linear combined model."""
+ train_ops = []
+ if dnn_logits is not None:
+ train_ops.append(
+ optimizers.optimize_loss(
+ loss=training_loss,
+ global_step=contrib_variables.get_global_step(),
+ learning_rate=_DNN_LEARNING_RATE,
+ optimizer=_get_optimizer(dnn_optimizer),
+ clip_gradients=gradient_clip_norm,
+ variables=ops.get_collection(dnn_parent_scope),
+ name=dnn_parent_scope,
+ # Empty summaries, because head already logs "loss" summary.
+ summaries=[]))
+ if linear_logits is not None:
+ train_ops.append(
+ optimizers.optimize_loss(
+ loss=training_loss,
+ global_step=contrib_variables.get_global_step(),
+ learning_rate=_linear_learning_rate(len(linear_feature_columns)),
+ optimizer=_get_optimizer(linear_optimizer),
+ clip_gradients=gradient_clip_norm,
+ variables=ops.get_collection(linear_parent_scope),
+ name=linear_parent_scope,
+ # Empty summaries, because head already logs "loss" summary.
+ summaries=[]))
+
+ return control_flow_ops.group(*train_ops)
+
+ return head.head_ops(
+ features, labels, mode, _make_training_op, logits=logits)
+
+
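The heart of the new model_fn is the wide-and-deep combination: the two towers feed one loss through summed logits, and each tower is trained by its own optimizer, with the two train ops grouped into one. A minimal standalone sketch of that pattern in 1.x-era TensorFlow (the placeholder inputs and layer sizes are hypothetical, not from this commit):

    import tensorflow as tf

    labels = tf.placeholder(tf.float32, [None, 1])
    wide_input = tf.placeholder(tf.float32, [None, 5])
    deep_input = tf.placeholder(tf.float32, [None, 8])

    with tf.variable_scope("linear"):
        linear_logits = tf.contrib.layers.fully_connected(
            wide_input, 1, activation_fn=None)
    with tf.variable_scope("dnn"):
        hidden = tf.contrib.layers.fully_connected(deep_input, 3)
        dnn_logits = tf.contrib.layers.fully_connected(
            hidden, 1, activation_fn=None)

    # Summed logits feed a single loss, as in the model_fn above.
    logits = linear_logits + dnn_logits
    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

    # Each tower trains only its own variables, at its own learning rate.
    linear_train = tf.train.FtrlOptimizer(0.2).minimize(
        loss, var_list=tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope="linear"))
    dnn_train = tf.train.AdagradOptimizer(0.05).minimize(
        loss, var_list=tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope="dnn"))

    # One combined train op, as _make_training_op builds via control_flow_ops.group.
    train_op = tf.group(linear_train, dnn_train)

Routing each tower's gradients through its own optimizer is why optimize_loss is called twice above, each time with a different variables= collection.
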
+class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
"""A classifier for TensorFlow Linear and DNN joined training models.
Example:
@@ -423,30 +662,71 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
ValueError: If both `linear_feature_columns` and `dnn_feature_columns`
are empty at the same time.
"""
-
if n_classes < 2:
raise ValueError("n_classes should be greater than 1. Given: {}".format(
n_classes))
+ self._linear_optimizer = linear_optimizer or "Ftrl"
+ linear_feature_columns = linear_feature_columns or []
+ dnn_feature_columns = dnn_feature_columns or []
+ self._feature_columns = linear_feature_columns + dnn_feature_columns
+ if not self._feature_columns:
+ raise ValueError("Either linear_feature_columns or dnn_feature_columns "
+ "must be defined.")
+ self._dnn_hidden_units = dnn_hidden_units
+ self._enable_centered_bias = enable_centered_bias
+
head = head_lib._multi_class_head( # pylint: disable=protected-access
n_classes=n_classes,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias)
- super(DNNLinearCombinedClassifier, self).__init__(
+ self._estimator = estimator.Estimator(
+ model_fn=_dnn_linear_combined_model_fn,
model_dir=model_dir,
- linear_feature_columns=linear_feature_columns,
- linear_optimizer=linear_optimizer,
- _joint_linear_weights=_joint_linear_weights,
- dnn_feature_columns=dnn_feature_columns,
- dnn_optimizer=dnn_optimizer,
- dnn_hidden_units=dnn_hidden_units,
- dnn_activation_fn=dnn_activation_fn,
- dnn_dropout=dnn_dropout,
- gradient_clip_norm=gradient_clip_norm,
- head=head,
config=config,
- feature_engineering_fn=feature_engineering_fn,
- default_prediction_key=head_lib.PredictionKey.CLASSES,
- enable_centered_bias=enable_centered_bias)
+ params={
+ "head": head,
+ "linear_feature_columns": linear_feature_columns,
+ "linear_optimizer": self._linear_optimizer,
+ "joint_linear_weights": _joint_linear_weights,
+ "dnn_feature_columns": dnn_feature_columns,
+ "dnn_optimizer": dnn_optimizer or "Adagrad",
+ "dnn_hidden_units": dnn_hidden_units,
+ "dnn_activation_fn": dnn_activation_fn,
+ "dnn_dropout": dnn_dropout,
+ "gradient_clip_norm": gradient_clip_norm,
+ "num_ps_replicas": config.num_ps_replicas if config else 0,
+ },
+ feature_engineering_fn=feature_engineering_fn)
+
+ def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
+ monitors=None, max_steps=None):
+ """See trainable.Trainable."""
+ # TODO(roumposg): Remove when deprecated monitors are removed.
+ if monitors is not None:
+ deprecated_monitors = [
+ m for m in monitors
+ if not isinstance(m, session_run_hook.SessionRunHook)
+ ]
+ for monitor in deprecated_monitors:
+ monitor.set_estimator(self)
+ monitor._lock_estimator() # pylint: disable=protected-access
+
+ result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
+ batch_size=batch_size, monitors=monitors,
+ max_steps=max_steps)
+
+ if monitors is not None:
+ for monitor in deprecated_monitors:
+ monitor._unlock_estimator() # pylint: disable=protected-access
+
+ return result
+
+ def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
+ batch_size=None, steps=None, metrics=None, name=None):
+ """See evaluable.Evaluable."""
+ return self._estimator.evaluate(
+ x=x, y=y, input_fn=input_fn, feed_fn=feed_fn, batch_size=batch_size,
+ steps=steps, metrics=metrics, name=name)
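
From the user's point of view, usage is unchanged by the refactor. A minimal sketch mirroring the tests below (train_input_fn and eval_input_fn are hypothetical):

    import tensorflow as tf

    language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
    age = tf.contrib.layers.real_valued_column('age')

    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
        linear_feature_columns=[age, language],
        dnn_feature_columns=[
            tf.contrib.layers.embedding_column(language, dimension=8)],
        dnn_hidden_units=[3, 3])

    classifier.fit(input_fn=train_input_fn, steps=100)
    scores = classifier.evaluate(input_fn=eval_input_fn, steps=10)
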
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@@ -467,12 +747,13 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
Numpy array of predicted classes (or an iterable of predicted classes if
as_iterable is True).
"""
- predictions = self.predict_proba(
- x=x, input_fn=input_fn, batch_size=batch_size, as_iterable=as_iterable)
+ preds = self._estimator.predict(x=x, input_fn=input_fn,
+ batch_size=batch_size,
+ outputs=[head_lib.PredictionKey.CLASSES],
+ as_iterable=as_iterable)
if as_iterable:
- return (np.argmax(p, axis=0) for p in predictions)
- else:
- return np.argmax(predictions, axis=1)
+ return _as_iterable(preds, output=head_lib.PredictionKey.CLASSES)
+ return preds[head_lib.PredictionKey.CLASSES].reshape(-1)
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@@ -494,14 +775,132 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
Numpy array of predicted probabilities (or an iterable of predicted
probabilities if as_iterable is True).
"""
- return super(DNNLinearCombinedClassifier, self).predict(
- x=x, input_fn=input_fn, batch_size=batch_size, as_iterable=as_iterable)
+ preds = self._estimator.predict(
+ x=x, input_fn=input_fn,
+ batch_size=batch_size,
+ outputs=[head_lib.PredictionKey.PROBABILITIES],
+ as_iterable=as_iterable)
+ if as_iterable:
+ return _as_iterable(preds, output=head_lib.PredictionKey.PROBABILITIES)
+ return preds[head_lib.PredictionKey.PROBABILITIES]
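
With as_iterable=True, predictions now stream through _as_iterable one example at a time instead of materializing a numpy array. A hypothetical consumption loop (classifier and eval_input_fn as in the earlier sketch):

    for probs in classifier.predict_proba(input_fn=eval_input_fn,
                                          as_iterable=True):
        print(probs)  # one [n_classes] array of class probabilities per example
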
def _get_predict_ops(self, features):
- """See base class."""
- return super(DNNLinearCombinedClassifier, self)._get_predict_ops(features)[
+ """See `Estimator` class."""
+ # pylint: disable=protected-access
+ return self._estimator._get_predict_ops(features)[
head_lib.PredictionKey.PROBABILITIES]
+ def get_variable_names(self):
+ """Returns list of all variable names in this model.
+
+ Returns:
+ List of names.
+ """
+ return self._estimator.get_variable_names()
+
+ def get_variable_value(self, name):
+ """Returns value of the variable given by name.
+
+ Args:
+ name: string, name of the tensor.
+
+ Returns:
+ `Tensor` object.
+ """
+ return self._estimator.get_variable_value(name)
+
+ def export(self,
+ export_dir,
+ input_fn=None,
+ input_feature_key=None,
+ use_deprecated_input_fn=True,
+ signature_fn=None,
+ default_batch_size=1,
+ exports_to_keep=None):
+ """See BasEstimator.export."""
+ def default_input_fn(unused_estimator, examples):
+ return layers.parse_feature_columns_from_examples(
+ examples, self._feature_columns)
+ self._estimator.export(
+ export_dir=export_dir,
+ input_fn=input_fn or default_input_fn,
+ input_feature_key=input_feature_key,
+ use_deprecated_input_fn=use_deprecated_input_fn,
+ signature_fn=(
+ signature_fn or export.classification_signature_fn_with_prob),
+ prediction_key=head_lib.PredictionKey.PROBABILITIES,
+ default_batch_size=default_batch_size,
+ exports_to_keep=exports_to_keep)
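
A hedged usage sketch, mirroring the testExport case added below (the path and serving_input_fn are hypothetical):

    classifier.export(export_dir='/tmp/wide_n_deep_model',
                      input_fn=serving_input_fn,
                      input_feature_key='examples',
                      use_deprecated_input_fn=False)
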
+
+ @property
+ def model_dir(self):
+ return self._estimator.model_dir
+
+ @property
+ @deprecated("2016-10-30",
+ "This method will be removed after the deprecation date. "
+ "To inspect variables, use get_variable_names() and "
+ "get_variable_value().")
+ def dnn_weights_(self):
+ hiddenlayer_weights = [
+ self.get_variable_value("dnn/hiddenlayer_%d/weights" % i)
+ for i, _ in enumerate(self._dnn_hidden_units)
+ ]
+ logits_weights = [self.get_variable_value("dnn/logits/weights")]
+ return hiddenlayer_weights + logits_weights
+
+ @property
+ @deprecated("2016-10-30",
+ "This method will be removed after the deprecation date. "
+ "To inspect variables, use get_variable_names() and "
+ "get_variable_value().")
+ def linear_weights_(self):
+ values = {}
+ if isinstance(self._linear_optimizer, str):
+ optimizer_name = self._linear_optimizer
+ else:
+ optimizer_name = self._linear_optimizer.get_name()
+ optimizer_regex = r".*/"+optimizer_name + r"(_\d)?$"
+ for name in self.get_variable_names():
+ if (name.startswith("linear/") and
+ name != "linear/bias_weight" and
+ name != "linear/learning_rate" and
+ not re.match(optimizer_regex, name)):
+ values[name] = self.get_variable_value(name)
+ if len(values) == 1:
+ return values[list(values.keys())[0]]
+ return values
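
The regex keeps model weights under the "linear/" scope while filtering out the optimizer's own slot variables. A small self-contained check, with hypothetical variable names:

    import re

    optimizer_regex = r".*/" + "Ftrl" + r"(_\d)?$"
    assert re.match(optimizer_regex, "linear/age/weight/Ftrl")      # slot: excluded
    assert re.match(optimizer_regex, "linear/age/weight/Ftrl_1")    # slot: excluded
    assert not re.match(optimizer_regex, "linear/age/weight")       # weight: kept
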
+
+ @property
+ @deprecated("2016-10-30",
+ "This method will be removed after the deprecation date. "
+ "To inspect variables, use get_variable_names() and "
+ "get_variable_value().")
+ def dnn_bias_(self):
+ hiddenlayer_bias = [self.get_variable_value("dnn/hiddenlayer_%d/biases" % i)
+ for i, _ in enumerate(self._dnn_hidden_units)]
+ logits_bias = [self.get_variable_value("dnn/logits/biases")]
+ if not self._enable_centered_bias:
+ return hiddenlayer_bias + logits_bias
+ centered_bias = [self.get_variable_value(_CENTERED_BIAS_WEIGHT)]
+ return hiddenlayer_bias + logits_bias + centered_bias
+
+ @property
+ @deprecated("2016-10-30",
+ "This method will be removed after the deprecation date. "
+ "To inspect variables, use get_variable_names() and "
+ "get_variable_value().")
+ def linear_bias_(self):
+ linear_bias = self.get_variable_value("linear/bias_weight")
+ if not self._enable_centered_bias:
+ return linear_bias
+ centered_bias = [self.get_variable_value(_CENTERED_BIAS_WEIGHT)]
+ return linear_bias + centered_bias
+
+ @property
+ def config(self):
+ return self._estimator.config
+
class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
"""A regressor for TensorFlow Linear and DNN joined training models.
@@ -649,5 +1048,3 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
"""See base class."""
return super(DNNLinearCombinedRegressor, self)._get_predict_ops(features)[
head_lib.PredictionKey.SCORES]
-
-
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index ad574f0790..dae1879646 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -27,6 +27,7 @@ import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
+from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
def _get_quantile_based_buckets(feature_values, num_buckets):
@@ -65,6 +66,15 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
estimator_test_utils.assert_estimator_contract(
self, tf.contrib.learn.DNNLinearCombinedClassifier)
+ def testNoFeatureColumns(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Either linear_feature_columns or dnn_feature_columns must be defined'):
+ tf.contrib.learn.DNNLinearCombinedClassifier(
+ linear_feature_columns=None,
+ dnn_feature_columns=None,
+ dnn_hidden_units=[3, 3])
+
def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""
iris = _prepare_iris_data_for_logistic_regression()
@@ -80,6 +90,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
classifier.fit(input_fn=_iris_input_logistic_fn, steps=100)
scores = classifier.evaluate(input_fn=_iris_input_logistic_fn, steps=100)
+ self.assertIn('auc', scores.keys())
self.assertGreater(scores['accuracy'], 0.9)
def testLogisticRegression_TensorData(self):
@@ -120,6 +131,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
classifier.fit(input_fn=_input_fn, steps=100)
scores = classifier.evaluate(input_fn=_input_fn, steps=100)
+ self.assertIn('auc', scores.keys())
self.assertGreater(scores['accuracy'], 0.9)
def testTrainWithPartitionedVariables(self):
@@ -397,9 +409,15 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
input_fn=_input_fn,
steps=100,
metrics={
- 'my_accuracy': tf.contrib.metrics.streaming_accuracy,
- ('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
- ('my_metric', 'probabilities'): _my_metric_op
+ 'my_accuracy': MetricSpec(
+ metric_fn=tf.contrib.metrics.streaming_accuracy,
+ prediction_key='classes'),
+ 'my_precision': MetricSpec(
+ metric_fn=tf.contrib.metrics.streaming_precision,
+ prediction_key='classes'),
+ 'my_metric': MetricSpec(
+ metric_fn=_my_metric_op,
+ prediction_key='probabilities')
})
self.assertTrue(
set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
@@ -412,7 +430,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
# Test the case where the 2nd element of the key is neither "classes" nor
# "probabilities".
- with self.assertRaises(KeyError):
+ with self.assertRaisesRegexp(KeyError, 'bad_type'):
classifier.evaluate(
input_fn=_input_fn,
steps=100,
@@ -428,6 +446,17 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
tf.contrib.metrics.streaming_accuracy
})
+ # Test the case where the prediction_key is neither "classes" nor
+ # "probabilities".
+ with self.assertRaisesRegexp(KeyError, 'bad_type'):
+ classifier.evaluate(
+ input_fn=_input_fn,
+ steps=100,
+ metrics={
+ 'bad_name': MetricSpec(
+ metric_fn=tf.contrib.metrics.streaming_auc,
+ prediction_key='bad_type')})
+
def testVariableQuery(self):
"""Tests bias is centered or not."""
def _input_fn_train():
@@ -447,6 +476,39 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
for name in var_names:
classifier.get_variable_value(name)
+ def testExport(self):
+ """Tests export model for servo."""
+
+ def input_fn():
+ return {
+ 'age': tf.constant([1]),
+ 'language': tf.SparseTensor(values=['english'],
+ indices=[[0, 0]],
+ shape=[1, 1])
+ }, tf.constant([[1]])
+
+ language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
+
+ classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+ linear_feature_columns=[
+ tf.contrib.layers.real_valued_column('age'),
+ language,
+ ],
+ dnn_feature_columns=[
+ tf.contrib.layers.embedding_column(language, dimension=1),
+ ],
+ dnn_hidden_units=[3, 3])
+ classifier.fit(input_fn=input_fn, steps=100)
+
+ export_dir = tempfile.mkdtemp()
+ input_feature_key = 'examples'
+ def serving_input_fn():
+ features, targets = input_fn()
+ features[input_feature_key] = tf.placeholder(tf.string)
+ return features, targets
+ classifier.export(export_dir, serving_input_fn, input_feature_key,
+ use_deprecated_input_fn=False)
+
def testCenteredBias(self):
"""Tests bias is centered or not."""
def _input_fn_train():
@@ -461,7 +523,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
dnn_hidden_units=[3, 3],
enable_centered_bias=True)
- classifier.fit(input_fn=_input_fn_train, steps=500)
+ classifier.fit(input_fn=_input_fn_train, steps=1000)
# logodds(0.75) = 1.09861228867
self.assertAlmostEqual(
1.0986,
@@ -483,7 +545,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
enable_centered_bias=False)
classifier.fit(input_fn=_input_fn_train, steps=500)
- self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
+ self.assertNotIn('centered_bias_weight', classifier.get_variable_names())
def testLinearOnly(self):
"""Tests that linear-only instantiation works."""
@@ -822,6 +884,44 @@ class DNNLinearCombinedRegressorTest(tf.test.TestCase):
metrics={('my_error', 'predictions'
): tf.contrib.metrics.streaming_mean_squared_error})
+ def testExport(self):
+ """Tests export model for servo."""
+ labels = [1., 0., 0.2]
+ def _input_fn(num_epochs=None):
+ features = {
+ 'age': tf.train.limit_epochs(tf.constant([[0.8], [0.15], [0.]]),
+ num_epochs=num_epochs),
+ 'language': tf.SparseTensor(values=['en', 'fr', 'zh'],
+ indices=[[0, 0], [0, 1], [2, 0]],
+ shape=[3, 2])
+ }
+ return features, tf.constant(labels, dtype=tf.float32)
+
+ language_column = tf.contrib.layers.sparse_column_with_hash_bucket(
+ 'language', hash_bucket_size=20)
+
+ regressor = tf.contrib.learn.DNNLinearCombinedRegressor(
+ linear_feature_columns=[
+ language_column,
+ tf.contrib.layers.real_valued_column('age')
+ ],
+ dnn_feature_columns=[
+ tf.contrib.layers.embedding_column(language_column, dimension=1),
+ ],
+ dnn_hidden_units=[3, 3],
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
+
+ regressor.fit(input_fn=_input_fn, steps=100)
+
+ export_dir = tempfile.mkdtemp()
+ input_feature_key = 'examples'
+ def serving_input_fn():
+ features, targets = _input_fn()
+ features[input_feature_key] = tf.placeholder(tf.string)
+ return features, targets
+ regressor.export(export_dir, serving_input_fn, input_feature_key,
+ use_deprecated_input_fn=False)
+
def testTrainSaveLoad(self):
"""Tests regression with restarting training / evaluate."""
def _input_fn(num_epochs=None):
@@ -1009,7 +1109,7 @@ class FeatureEngineeringFunctionTest(tf.test.TestCase):
config=tf.contrib.learn.RunConfig(tf_random_seed=1))
estimator_without_fe_fn.fit(input_fn=input_fn, steps=100)
- # predictions = y
+ # predictions = y
prediction_with_fe_fn = next(
estimator_with_fe_fn.predict(input_fn=input_fn, as_iterable=True))
self.assertAlmostEqual(1000., prediction_with_fe_fn, delta=1.0)