author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-22 15:17:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-22 15:48:41 -0700
commit    3facf91c0468bec1bb8151dd42816c2827a31e6d (patch)
tree      6228402486c85bd5bdceee0e7776b39978358107 /tensorflow/contrib/estimator
parent    44e55c111fc9240cf777fa771d0128a1f8d64e0b (diff)
Added ability to forward sparse tensors in `forward_features`.
PiperOrigin-RevId: 209839159
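
The change lets `forward_features` take a `sparse_default_values` dict so that selected `SparseTensor` features are densified before being forwarded into `predictions`. A minimal usage sketch, assembled from the updated docstring and the new `test_forward_keys` test below (the feature names and values are illustrative only, not part of the API):

```python
# Sketch only: mirrors the new test added in this change (TF 1.x APIs).
import tensorflow as tf

def input_fn():
  dataset = tf.data.Dataset.from_tensors({
      'x': [[3.], [5.]],
      'sparse_id': tf.SparseTensor(
          values=[1, 2, 3],
          indices=[[0, 0], [1, 0], [1, 1]],
          dense_shape=[2, 2]),
      'labels': [[1.], [2.]],
  })

  def _split(features):
    labels = features.pop('labels')
    return features, labels

  return dataset.map(_split)

estimator = tf.estimator.LinearRegressor([tf.feature_column.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)

# Forward the sparse feature; it is converted to a dense Tensor filled with
# the given default value before being added to the predictions dict.
estimator = tf.contrib.estimator.forward_features(
    estimator, keys=['sparse_id'], sparse_default_values={'sparse_id': 0})
prediction = next(estimator.predict(input_fn=input_fn))
assert 'sparse_id' in prediction
```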
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/extenders.py       |  29
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/extenders_test.py  | 129
2 files changed, 122 insertions, 36 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
index 26449b4651..e3c44bea66 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -26,6 +26,7 @@ from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.util import function_utils
@@ -140,7 +141,7 @@ def clip_gradients_by_norm(optimizer, clip_norm):
name='ClipByNorm' + optimizer.get_name())
-def forward_features(estimator, keys=None):
+def forward_features(estimator, keys=None, sparse_default_values=None):
"""Forward features to predictions dictionary.
In some cases, user wants to see some of the features in estimators prediction
@@ -148,39 +149,36 @@ def forward_features(estimator, keys=None):
runs inference on the users graph and returns the results. Keys are essential
because there is no order guarantee on the outputs so they need to be rejoined
to the inputs via keys or transclusion of the inputs in the outputs.
-
Example:
-
```python
def input_fn():
features, labels = ...
features['unique_example_id'] = ...
features, labels
-
estimator = tf.estimator.LinearClassifier(...)
estimator = tf.contrib.estimator.forward_features(
estimator, 'unique_example_id')
estimator.train(...)
assert 'unique_example_id' in estimator.predict(...)
```
-
Args:
estimator: A `tf.estimator.Estimator` object.
- keys: a `string` or a `list` of `string`. If it is `None`, all of the
+ keys: A `string` or a `list` of `string`. If it is `None`, all of the
`features` in `dict` is forwarded to the `predictions`. If it is a
`string`, only given key is forwarded. If it is a `list` of strings, all
the given `keys` are forwarded.
+ sparse_default_values: A dict mapping the names of sparse features to the
+ default value to use when converting each one to a dense `Tensor`. Only
+ the sparse features named in this dict are converted to dense, using the
+ corresponding default value to fill in missing entries.
Returns:
A new `tf.estimator.Estimator` which forwards features to predictions.
-
Raises:
ValueError:
* if `keys` is already part of `predictions`. We don't allow
override.
* if 'keys' does not exist in `features`.
- * if feature key refers to a `SparseTensor`, since we don't support
- `SparseTensor` in `predictions`. `SparseTensor` is common in `features`.
TypeError: if `keys` type is not one of `string` or list/tuple of `string`.
"""
@@ -231,11 +229,18 @@ def forward_features(estimator, keys=None):
for key in get_keys(features):
feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
features[key])
+ if sparse_default_values and (key in sparse_default_values):
+ if not isinstance(feature, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'Feature ({}) is expected to be a `SparseTensor`.'.format(key))
+ feature = sparse_ops.sparse_tensor_to_dense(
+ feature, default_value=sparse_default_values[key])
if not isinstance(feature, ops.Tensor):
raise ValueError(
- 'Forwarded feature ({}) should be a Tensor. Please use keys '
- 'argument of forward_features to filter unwanted features. Type of '
- 'features[{}] is {}.'.format(key, key, type(feature)))
+ 'Feature ({}) should be a Tensor. Please use the `keys` '
+ 'argument of forward_features to filter unwanted features, or '
+ 'add the key to the `sparse_default_values` argument. '
+ 'Type of features[{}] is {}.'.format(key, key, type(feature)))
predictions[key] = feature
spec = spec._replace(predictions=predictions)
if spec.export_outputs:
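
For reference, the new branch above relies on `sparse_ops.sparse_tensor_to_dense` to turn a forwarded `SparseTensor` into an ordinary dense `Tensor` filled with the caller-supplied default value. A small standalone sketch of that step, using the same sparse values as the tests below (TF 1.x graph mode):

```python
import tensorflow as tf

sparse_id = tf.SparseTensor(
    values=[1, 2, 3],
    indices=[[0, 0], [1, 0], [1, 1]],
    dense_shape=[2, 2])

# Positions absent from `indices` are filled with the default value,
# producing [[1, 0], [2, 3]] -- a dense Tensor that can be placed in
# the predictions dict.
dense_id = tf.sparse_tensor_to_dense(sparse_id, default_value=0)

with tf.Session() as sess:
  print(sess.run(dense_id))  # => [[1 0]
                             #     [2 3]]
```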
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
index 407af2deaf..c8fdaa8791 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""extenders tests."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,6 +24,7 @@ import tempfile
import numpy as np
from tensorflow.contrib.estimator.python.estimator import extenders
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.predictor import from_saved_model
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
@@ -170,19 +172,53 @@ class ClipGradientsByNormTest(test.TestCase):
class ForwardFeaturesTest(test.TestCase):
"""Tests forward_features."""
- def test_forward_single_key(self):
-
- def input_fn():
- return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
+ def _export_estimator(self, estimator, serving_input_fn):
+ tmpdir = tempfile.mkdtemp()
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+ return export_dir, tmpdir
+ def make_dummy_input_fn(self):
+ def _input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': [[3.], [5.]],
+ 'id': [[101], [102]],
+ 'sparse_id': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[1.], [2.]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+ return _input_fn
+
+ def test_forward_keys(self):
+
+ input_fn = self.make_dummy_input_fn()
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
- self.assertNotIn('id', next(estimator.predict(input_fn=input_fn)))
- estimator = extenders.forward_features(estimator, 'id')
- predictions = next(estimator.predict(input_fn=input_fn))
- self.assertIn('id', predictions)
- self.assertEqual(101, predictions['id'])
+ forwarded_keys = ['id', 'sparse_id']
+
+ for key in forwarded_keys:
+ self.assertNotIn(key, next(estimator.predict(input_fn=input_fn)))
+
+ estimator = extenders.forward_features(
+ estimator, forwarded_keys, sparse_default_values={'sparse_id': 1})
+
+ expected_results = [101, 2, 102, 5]
+ predictions = estimator.predict(input_fn=input_fn)
+ for _ in range(2):
+ prediction = next(predictions)
+ for key in forwarded_keys:
+ self.assertIn(key, prediction)
+ self.assertEqual(expected_results.pop(0), sum(prediction[key]))
def test_forward_in_exported(self):
@@ -205,11 +241,7 @@ class ForwardFeaturesTest(test.TestCase):
estimator = extenders.forward_features(estimator, 'id')
# export saved model
- tmpdir = tempfile.mkdtemp()
- export_dir_base = os.path.join(
- compat.as_bytes(tmpdir), compat.as_bytes('export'))
- export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
- self.assertTrue(gfile.Exists(export_dir))
+ export_dir, tmpdir = self._export_estimator(estimator, serving_input_fn)
# restore model
predict_fn = from_saved_model(export_dir, signature_def_key='predict')
@@ -222,6 +254,47 @@ class ForwardFeaturesTest(test.TestCase):
# Clean up.
gfile.DeleteRecursively(tmpdir)
+ def test_forward_in_exported_sparse(self):
+ features_columns = [fc.indicator_column(
+ fc.categorical_column_with_vocabulary_list('x', range(10)))]
+
+ classifier = linear.LinearClassifier(feature_columns=features_columns)
+
+ def train_input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[0], [1]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+
+ classifier.train(train_input_fn, max_steps=1)
+
+ classifier = extenders.forward_features(
+ classifier, keys=['x'], sparse_default_values={'x': 0})
+
+ def serving_input_fn():
+ features_ph = array_ops.placeholder(dtype=dtypes.int32, name='x',
+ shape=[None])
+ features = {'x': layers.dense_to_sparse(features_ph)}
+ return estimator_lib.export.ServingInputReceiver(features,
+ {'x': features_ph})
+ export_dir, tmpdir = self._export_estimator(classifier, serving_input_fn)
+ prediction_fn = from_saved_model(export_dir, signature_def_key='predict')
+
+ features = (0, 2)
+ prediction = prediction_fn({'x': features})
+
+ self.assertIn('x', prediction)
+ self.assertEqual(features, tuple(prediction['x']))
+ gfile.DeleteRecursively(tmpdir)
+
def test_forward_list(self):
def input_fn():
@@ -266,7 +339,6 @@ class ForwardFeaturesTest(test.TestCase):
extenders.forward_features(estimator, ['x', estimator])
def test_key_should_be_in_features(self):
-
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
@@ -279,27 +351,36 @@ class ForwardFeaturesTest(test.TestCase):
next(estimator.predict(input_fn=input_fn))
def test_forwarded_feature_should_not_be_a_sparse_tensor(self):
-
def input_fn():
return {
'x': [[3.], [5.]],
- 'id':
- sparse_tensor.SparseTensor(
- values=['1', '2'],
- indices=[[0, 0], [1, 0]],
- dense_shape=[2, 1])
- }, [[1.], [2.]]
+ 'id': sparse_tensor.SparseTensor(
+ values=['1', '2'],
+ indices=[[0, 0], [1, 0]],
+ dense_shape=[2, 1])
+ }, [[1.], [2.]]
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
estimator = extenders.forward_features(estimator)
with self.assertRaisesRegexp(ValueError,
- 'Forwarded feature.* should be a Tensor.'):
+ 'Feature .* should be a Tensor.*'):
next(estimator.predict(input_fn=input_fn))
- def test_predictions_should_be_dict(self):
+ def test_forwarded_feature_should_be_a_sparse_tensor(self):
+ input_fn = self.make_dummy_input_fn()
+
+ estimator = linear.LinearRegressor([fc.numeric_column('x')])
+ estimator.train(input_fn=input_fn, steps=1)
+ estimator = extenders.forward_features(
+ estimator, sparse_default_values={'id': 0, 'sparse_id': 0})
+ with self.assertRaisesRegexp(
+ ValueError, 'Feature .* is expected to be a `SparseTensor`.'):
+ next(estimator.predict(input_fn=input_fn))
+
+ def test_predictions_should_be_dict(self):
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}