diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-06 10:29:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-06 10:33:42 -0800 |
commit | 155743816c0d94ca44186147a9ad1c26f93985a9 (patch) | |
tree | ca4f03bf043a0509dea6cc20270a003fc1e2e31f /tensorflow/contrib/feature_column | |
parent | 432650b580611e8a0da7bd8bbd69235bcaa1bd4c (diff) |
Checks that sequence_length is equal among sequence feature columns.
PiperOrigin-RevId: 188042426
Diffstat (limited to 'tensorflow/contrib/feature_column')
-rw-r--r-- | tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py | 17 | ||||
-rw-r--r-- | tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py | 30 |
2 files changed, 45 insertions, 2 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index ba17b568b6..b25d7e513b 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -127,8 +127,9 @@ def sequence_input_layer( shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) sequence_lengths.append(sequence_length) fc._verify_static_batch_size_equality(output_tensors, ordered_columns) - # TODO(b/73160931): Verify sequence_length equality. - return array_ops.concat(output_tensors, -1), sequence_lengths[0] + fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + sequence_length = _assert_all_equal_and_return(sequence_lengths) + return array_ops.concat(output_tensors, -1), sequence_length # TODO(b/73160931): Add remaining categorical columns. @@ -312,6 +313,18 @@ def sequence_numeric_column( dtype=dtype) +def _assert_all_equal_and_return(tensors, name=None): + """Asserts that all tensors are equal and returns the first one.""" + with ops.name_scope(name, 'assert_all_equal', values=tensors): + if len(tensors) == 1: + return tensors[0] + assert_equal_ops = [] + for t in tensors[1:]: + assert_equal_ops.append(check_ops.assert_equal(tensors[0], t)) + with ops.control_dependencies(assert_equal_ops): + return array_ops.identity(tensors[0]) + + class _SequenceDenseColumn(fc._FeatureColumn): """Represents dense sequence data.""" diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 39caa602d9..5c1e76fc62 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -198,6 +198,36 @@ class SequenceInputLayerTest(test.TestCase): self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) + def test_sequence_length_not_equal(self): + """Tests that an error is raised when sequence lengths are not equal.""" + # Input a with sequence_length = [2, 1] + sparse_input_a = sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + # Input b with sequence_length = [1, 1] + sparse_input_b = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0)), + values=(1., 10.), + dense_shape=(2, 2)) + numeric_column_a = sfc.sequence_numeric_column('aaa') + numeric_column_b = sfc.sequence_numeric_column('bbb') + + _, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=[numeric_column_a, numeric_column_b]) + + with monitored_session.MonitoredSession() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Condition x == y did not hold element-wise:\] ' + r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + sess.run(sequence_length) + def _assert_sparse_tensor_value(test_case, expected, actual): test_case.assertEqual(np.int64, np.array(actual.indices).dtype) |