aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py')
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index ee74cf56dc..45d7b74046 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
@@ -947,6 +948,7 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertEqual((1,), a.shape)
self.assertEqual(0., a.default_value)
self.assertEqual(dtypes.float32, a.dtype)
+ self.assertIsNone(a.normalizer_fn)
def test_shape_saved_as_tuple(self):
a = sfc.sequence_numeric_column('aaa', shape=[1, 2])
@@ -965,6 +967,10 @@ class SequenceNumericColumnTest(test.TestCase):
ValueError, 'dtype must be convertible to float'):
sfc.sequence_numeric_column('aaa', dtype=dtypes.string)
+ def test_normalizer_fn_must_be_callable(self):
+ with self.assertRaisesRegexp(TypeError, 'must be a callable'):
+ sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable')
+
def test_get_sequence_dense_tensor(self):
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, values [[0.], [1]]
@@ -985,6 +991,41 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))
+ def test_get_sequence_dense_tensor_with_normalizer_fn(self):
+
+ def _increment_two(input_sparse_tensor):
+ return sparse_ops.sparse_add(
+ input_sparse_tensor,
+ sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2))
+ )
+
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0.], [1]]
+ # example 1, [[10.]]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2))
+
+ # Before _increment_two:
+ # [[0.], [1.]],
+ # [[10.], [0.]],
+ # After _increment_two:
+ # [[2.], [1.]],
+ # [[10.], [2.]],
+ expected_dense_tensor = [
+ [[2.], [1.]],
+ [[10.], [2.]],
+ ]
+ numeric_column = sfc.sequence_numeric_column(
+ 'aaa', normalizer_fn=_increment_two)
+
+ dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_dense_tensor, dense_tensor.eval(session=sess))
+
def test_get_sequence_dense_tensor_with_shape(self):
"""Tests get_sequence_dense_tensor with shape !=(1,)."""
sparse_input = sparse_tensor.SparseTensorValue(