diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-01 13:00:40 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-01 13:04:20 -0800 |
commit | deef58ba3913c4ab9ca93876cd30744db00c4a6a (patch) | |
tree | ee3da315f141e5ae853271cc7802eee77c985a74 /tensorflow/contrib/feature_column | |
parent | 1df40b152216bde47dd9ac1fa65bec57434920e1 (diff) |
Cast sequence_length to an integer.
PiperOrigin-RevId: 187520920
Diffstat (limited to 'tensorflow/contrib/feature_column')
-rw-r--r-- | tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py | 15 |
2 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index e99033bbec..e446043bdd 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -295,7 +295,7 @@ def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): row_ids = sp_tensor.indices[:, 0] column_ids = sp_tensor.indices[:, 1] column_ids += array_ops.ones_like(column_ids) - seq_length = ( + seq_length = math_ops.to_int64( math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) # If the last n rows do not have ids, seq_length will have shape # [batch_size - n]. Pad the remaining values with zeros. diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 8c37ccf11b..105213680e 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -221,8 +221,9 @@ class SequenceCategoricalColumnWithIdentityTest(test.TestCase): sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) def test_sequence_length_with_zeros(self): column = sfc.sequence_categorical_column_with_identity( @@ -311,8 +312,9 @@ class SequenceEmbeddingColumnTest(test.TestCase): _LazyBuilder({'aaa': sparse_input})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) def test_sequence_length_with_empty_rows(self): """Tests _sequence_length when some examples do not have ids.""" @@ -423,8 +425,9 @@ class SequenceNumericColumnTest(test.TestCase): _LazyBuilder({'aaa': sparse_input})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) def test_sequence_length_with_shape(self): """Tests _sequence_length with shape !=(1,).""" |