aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/feature_column
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-18 21:00:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 21:04:00 -0700
commit6070ae0e148f50dbc8f36e1654f0a3f53b8b067e (patch)
tree165e4c050220180a76512e304b70eee0cd02a2db /tensorflow/contrib/feature_column
parent60b78d6152e6f8d985f3086930ff986c140c36bf (diff)
Merge changes from github.
PiperOrigin-RevId: 201110240
Diffstat (limited to 'tensorflow/contrib/feature_column')
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py22
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py41
2 files changed, 59 insertions, 4 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index 84a413c791..05bcdac2ca 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -346,7 +346,8 @@ def sequence_numeric_column(
key,
shape=(1,),
default_value=0.,
- dtype=dtypes.float32):
+ dtype=dtypes.float32,
+ normalizer_fn=None):
"""Returns a feature column that represents sequences of numeric data.
Example:
@@ -370,6 +371,12 @@ def sequence_numeric_column(
default_value: A single value compatible with `dtype` that is used for
padding the sparse data into a dense `Tensor`.
dtype: The type of values.
+ normalizer_fn: If not `None`, a function that can be used to normalize the
+ value of the tensor after `default_value` is applied for parsing.
+ Normalizer function takes the input `Tensor` as its argument, and returns
+ the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
+ even though the most common use case of this function is normalization, it
+ can be used for any kind of Tensorflow transformations.
Returns:
A `_SequenceNumericColumn`.
@@ -383,12 +390,16 @@ def sequence_numeric_column(
if not (dtype.is_integer or dtype.is_floating):
raise ValueError('dtype must be convertible to float. '
'dtype: {}, key: {}'.format(dtype, key))
+ if normalizer_fn is not None and not callable(normalizer_fn):
+ raise TypeError(
+ 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
return _SequenceNumericColumn(
key,
shape=shape,
default_value=default_value,
- dtype=dtype)
+ dtype=dtype,
+ normalizer_fn=normalizer_fn)
def _assert_all_equal_and_return(tensors, name=None):
@@ -407,7 +418,7 @@ class _SequenceNumericColumn(
fc._SequenceDenseColumn,
collections.namedtuple(
'_SequenceNumericColumn',
- ['key', 'shape', 'default_value', 'dtype'])):
+ ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])):
"""Represents sequences of numeric data."""
@property
@@ -419,7 +430,10 @@ class _SequenceNumericColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- return inputs.get(self.key)
+ input_tensor = inputs.get(self.key)
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return input_tensor
@property
def _variable_shape(self):
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index ee74cf56dc..45d7b74046 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
@@ -947,6 +948,7 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertEqual((1,), a.shape)
self.assertEqual(0., a.default_value)
self.assertEqual(dtypes.float32, a.dtype)
+ self.assertIsNone(a.normalizer_fn)
def test_shape_saved_as_tuple(self):
a = sfc.sequence_numeric_column('aaa', shape=[1, 2])
@@ -965,6 +967,10 @@ class SequenceNumericColumnTest(test.TestCase):
ValueError, 'dtype must be convertible to float'):
sfc.sequence_numeric_column('aaa', dtype=dtypes.string)
+ def test_normalizer_fn_must_be_callable(self):
+ with self.assertRaisesRegexp(TypeError, 'must be a callable'):
+ sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable')
+
def test_get_sequence_dense_tensor(self):
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, values [[0.], [1]]
@@ -985,6 +991,41 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))
+ def test_get_sequence_dense_tensor_with_normalizer_fn(self):
+
+ def _increment_two(input_sparse_tensor):
+ return sparse_ops.sparse_add(
+ input_sparse_tensor,
+ sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2))
+ )
+
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0.], [1]]
+ # example 1, [[10.]]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2))
+
+ # Before _increment_two:
+ # [[0.], [1.]],
+ # [[10.], [0.]],
+ # After _increment_two:
+ # [[2.], [1.]],
+ # [[10.], [2.]],
+ expected_dense_tensor = [
+ [[2.], [1.]],
+ [[10.], [2.]],
+ ]
+ numeric_column = sfc.sequence_numeric_column(
+ 'aaa', normalizer_fn=_increment_two)
+
+ dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_dense_tensor, dense_tensor.eval(session=sess))
+
def test_get_sequence_dense_tensor_with_shape(self):
"""Tests get_sequence_dense_tensor with shape !=(1,)."""
sparse_input = sparse_tensor.SparseTensorValue(