aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-01 13:00:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-01 13:04:20 -0800
commitdeef58ba3913c4ab9ca93876cd30744db00c4a6a (patch)
treeee3da315f141e5ae853271cc7802eee77c985a74 /tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
parent1df40b152216bde47dd9ac1fa65bec57434920e1 (diff)
Cast sequence_length to an integer.
PiperOrigin-RevId: 187520920
Diffstat (limited to 'tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py')
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 8c37ccf11b..105213680e 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -221,8 +221,9 @@ class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
+ sequence_length = sess.run(sequence_length)
+ self.assertAllEqual(expected_sequence_length, sequence_length)
+ self.assertEqual(np.int64, sequence_length.dtype)
def test_sequence_length_with_zeros(self):
column = sfc.sequence_categorical_column_with_identity(
@@ -311,8 +312,9 @@ class SequenceEmbeddingColumnTest(test.TestCase):
_LazyBuilder({'aaa': sparse_input}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
+ sequence_length = sess.run(sequence_length)
+ self.assertAllEqual(expected_sequence_length, sequence_length)
+ self.assertEqual(np.int64, sequence_length.dtype)
def test_sequence_length_with_empty_rows(self):
"""Tests _sequence_length when some examples do not have ids."""
@@ -423,8 +425,9 @@ class SequenceNumericColumnTest(test.TestCase):
_LazyBuilder({'aaa': sparse_input}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
+ sequence_length = sess.run(sequence_length)
+ self.assertAllEqual(expected_sequence_length, sequence_length)
+ self.assertEqual(np.int64, sequence_length.dtype)
def test_sequence_length_with_shape(self):
"""Tests _sequence_length with shape !=(1,)."""