aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-28 13:50:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-28 15:28:23 -0700
commit56bca499698cee57cfc7424089f0b0c183cd3bfd (patch)
tree7b720fb444929b455c6b77f6805a6b48373415ff /tensorflow
parenta6fdccc5be02dd1d3c8a70a6656db94d4d525e76 (diff)
Support label_keys in DNNLinearCombinedClassifier and in LinearClassifier.
Change: 154585848
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py39
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py53
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py35
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py43
4 files changed, 164 insertions, 6 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index b525213eb7..0ff5d6e8dc 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -500,9 +500,36 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
...
def input_fn_eval: # returns x, y (where y represents label's class index).
...
+ def input_fn_predict: # returns x, None.
+ ...
+ estimator.fit(input_fn=input_fn_train)
+ estimator.evaluate(input_fn=input_fn_eval)
+ # predict_classes returns class indices.
+ estimator.predict_classes(input_fn=input_fn_predict)
+ ```
+
+ If the user specifies `label_keys` in constructor, labels must be strings from
+ the `label_keys` vocabulary. Example:
+
+ ```python
+ label_keys = ['label0', 'label1', 'label2']
+ estimator = DNNLinearCombinedClassifier(
+ n_classes=n_classes,
+ linear_feature_columns=[sparse_feature_a_x_sparse_feature_b],
+ dnn_feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
+ dnn_hidden_units=[1000, 500, 100],
+ label_keys=label_keys)
+
+ def input_fn_train: # returns x, y (where y is one of label_keys).
+ pass
estimator.fit(input_fn=input_fn_train)
+
+ def input_fn_eval: # returns x, y (where y is one of label_keys).
+ pass
estimator.evaluate(input_fn=input_fn_eval)
- estimator.predict(x=x) # returns predicted labels (i.e. label's class index).
+ def input_fn_predict: # returns x, None
+ # predict_classes returns one of label_keys.
+ estimator.predict_classes(input_fn=input_fn_predict)
```
Input of `fit` and `evaluate` should have following features,
@@ -542,6 +569,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
feature_engineering_fn=None,
embedding_lr_multipliers=None,
input_layer_min_slice_size=None,
+ label_keys=None,
fix_global_step_increment_bug=False):
"""Constructs a DNNLinearCombinedClassifier instance.
@@ -593,6 +621,8 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
learning rate for the embedding variables.
input_layer_min_slice_size: Optional. The min slice size of input layer
partitions. If not provided, will use the default of 64M.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
fix_global_step_increment_bug: If `False`, the estimator needs two fit
steps to optimize both linear and dnn parts. If `True`, this bug is
fixed. New users must set this to `True`, but it the default value is
@@ -606,7 +636,8 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
head = head_lib.multi_class_head(
n_classes=n_classes,
weight_column_name=weight_column_name,
- enable_centered_bias=enable_centered_bias)
+ enable_centered_bias=enable_centered_bias,
+ label_keys=label_keys)
linear_feature_columns = tuple(linear_feature_columns or [])
dnn_feature_columns = tuple(dnn_feature_columns or [])
self._feature_columns = linear_feature_columns + dnn_feature_columns
@@ -817,9 +848,11 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
...
def input_fn_eval: # returns x, y
...
+ def input_fn_predict: # returns x, None
+ ...
estimator.train(input_fn_train)
estimator.evaluate(input_fn_eval)
- estimator.predict(x)
+ estimator.predict(input_fn_predict)
```
Input of `fit`, `train`, and `evaluate` should have following features,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index 301211ee82..14caa0a5b5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -493,6 +493,59 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
input_fn=test_data.iris_input_multiclass_fn, steps=100)
_assert_metrics_in_range(('accuracy',), scores)
+ def testMultiClassLabelKeys(self):
+ """Tests n_classes > 2 with label_keys vocabulary for labels."""
+ # Byte literals needed for python3 test to pass.
+ label_keys = [b'label0', b'label1', b'label2']
+
+ def _input_fn(num_epochs=None):
+ features = {
+ 'age':
+ input_lib.limit_epochs(
+ constant_op.constant([[.8], [0.2], [.1]]),
+ num_epochs=num_epochs),
+ 'language':
+ sparse_tensor.SparseTensor(
+ values=input_lib.limit_epochs(
+ ['en', 'fr', 'zh'], num_epochs=num_epochs),
+ indices=[[0, 0], [0, 1], [2, 0]],
+ dense_shape=[3, 2])
+ }
+ labels = constant_op.constant(
+ [[label_keys[1]], [label_keys[0]], [label_keys[0]]],
+ dtype=dtypes.string)
+ return features, labels
+
+ language_column = feature_column.sparse_column_with_hash_bucket(
+ 'language', hash_bucket_size=20)
+
+ classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
+ n_classes=3,
+ linear_feature_columns=[language_column],
+ dnn_feature_columns=[
+ feature_column.embedding_column(
+ language_column, dimension=1),
+ feature_column.real_valued_column('age')
+ ],
+ dnn_hidden_units=[3, 3],
+ label_keys=label_keys)
+
+ classifier.fit(input_fn=_input_fn, steps=50)
+
+ scores = classifier.evaluate(input_fn=_input_fn, steps=1)
+ _assert_metrics_in_range(('accuracy',), scores)
+ self.assertIn('loss', scores)
+ predict_input_fn = functools.partial(_input_fn, num_epochs=1)
+ predicted_classes = list(
+ classifier.predict_classes(
+ input_fn=predict_input_fn, as_iterable=True))
+ self.assertEqual(3, len(predicted_classes))
+ for pred in predicted_classes:
+ self.assertIn(pred, label_keys)
+ predictions = list(
+ classifier.predict(input_fn=predict_input_fn, as_iterable=True))
+ self.assertAllEqual(predicted_classes, predictions)
+
def testLoss(self):
"""Tests loss calculation."""
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index d1b4aedb81..bff4dc8d63 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -333,9 +333,34 @@ class LinearClassifier(estimator.Estimator):
...
def input_fn_eval: # returns x, y (where y represents label's class index).
...
+ def input_fn_predict: # returns x, None.
+ ...
estimator.fit(input_fn=input_fn_train)
estimator.evaluate(input_fn=input_fn_eval)
- estimator.predict(x=x) # returns predicted labels (i.e. label's class index).
+ # predict_classes returns class indices.
+ estimator.predict_classes(input_fn=input_fn_predict)
+ ```
+
+ If the user specifies `label_keys` in constructor, labels must be strings from
+ the `label_keys` vocabulary. Example:
+
+ ```python
+ label_keys = ['label0', 'label1', 'label2']
+ estimator = LinearClassifier(
+ n_classes=n_classes,
+ feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
+ label_keys=label_keys)
+
+ def input_fn_train: # returns x, y (where y is one of label_keys).
+ pass
+ estimator.fit(input_fn=input_fn_train)
+
+ def input_fn_eval: # returns x, y (where y is one of label_keys).
+ pass
+ estimator.evaluate(input_fn=input_fn_eval)
+ def input_fn_predict: # returns x, None
+ # predict_classes returns one of label_keys.
+ estimator.predict_classes(input_fn=input_fn_predict)
```
Input of `fit` and `evaluate` should have following features,
@@ -363,7 +388,8 @@ class LinearClassifier(estimator.Estimator):
enable_centered_bias=False,
_joint_weight=False,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ label_keys=None):
"""Construct a `LinearClassifier` estimator object.
Args:
@@ -398,6 +424,8 @@ class LinearClassifier(estimator.Estimator):
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
Returns:
A `LinearClassifier` estimator.
@@ -419,7 +447,8 @@ class LinearClassifier(estimator.Estimator):
head = head_lib.multi_class_head(
n_classes,
weight_column_name=weight_column_name,
- enable_centered_bias=enable_centered_bias)
+ enable_centered_bias=enable_centered_bias,
+ label_keys=label_keys)
params = {
"head": head,
"feature_columns": feature_columns,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index fc64377452..ededd3cdb9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -172,6 +172,49 @@ class LinearClassifierTest(test.TestCase):
scores = classifier.evaluate(x=train_x, y=train_y, steps=1)
self.assertGreater(scores['accuracy'], 0.9)
+ def testMultiClassLabelKeys(self):
+ """Tests n_classes > 2 with label_keys vocabulary for labels."""
+ # Byte literals needed for python3 test to pass.
+ label_keys = [b'label0', b'label1', b'label2']
+
+ def _input_fn(num_epochs=None):
+ features = {
+ 'language':
+ sparse_tensor.SparseTensor(
+ values=input_lib.limit_epochs(
+ ['en', 'fr', 'zh'], num_epochs=num_epochs),
+ indices=[[0, 0], [0, 1], [2, 0]],
+ dense_shape=[3, 2])
+ }
+ labels = constant_op.constant(
+ [[label_keys[1]], [label_keys[0]], [label_keys[0]]],
+ dtype=dtypes.string)
+ return features, labels
+
+ language_column = feature_column_lib.sparse_column_with_hash_bucket(
+ 'language', hash_bucket_size=20)
+
+ classifier = linear.LinearClassifier(
+ n_classes=3,
+ feature_columns=[language_column],
+ label_keys=label_keys)
+
+ classifier.fit(input_fn=_input_fn, steps=50)
+
+ scores = classifier.evaluate(input_fn=_input_fn, steps=1)
+ self.assertGreater(scores['accuracy'], 0.9)
+ self.assertIn('loss', scores)
+ predict_input_fn = functools.partial(_input_fn, num_epochs=1)
+ predicted_classes = list(
+ classifier.predict_classes(
+ input_fn=predict_input_fn, as_iterable=True))
+ self.assertEqual(3, len(predicted_classes))
+ for pred in predicted_classes:
+ self.assertIn(pred, label_keys)
+ predictions = list(
+ classifier.predict(input_fn=predict_input_fn, as_iterable=True))
+ self.assertAllEqual(predicted_classes, predictions)
+
def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""