aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/feature_column
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-06 10:29:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 10:33:42 -0800
commit155743816c0d94ca44186147a9ad1c26f93985a9 (patch)
treeca4f03bf043a0509dea6cc20270a003fc1e2e31f /tensorflow/contrib/feature_column
parent432650b580611e8a0da7bd8bbd69235bcaa1bd4c (diff)
Checks that sequence_length is equal among sequence feature columns.
PiperOrigin-RevId: 188042426
Diffstat (limited to 'tensorflow/contrib/feature_column')
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py17
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py30
2 files changed, 45 insertions, 2 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index ba17b568b6..b25d7e513b 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -127,8 +127,9 @@ def sequence_input_layer(
shape=array_ops.concat([shape[:2], [num_elements]], axis=0)))
sequence_lengths.append(sequence_length)
fc._verify_static_batch_size_equality(output_tensors, ordered_columns)
- # TODO(b/73160931): Verify sequence_length equality.
- return array_ops.concat(output_tensors, -1), sequence_lengths[0]
+ fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns)
+ sequence_length = _assert_all_equal_and_return(sequence_lengths)
+ return array_ops.concat(output_tensors, -1), sequence_length
# TODO(b/73160931): Add remaining categorical columns.
@@ -312,6 +313,18 @@ def sequence_numeric_column(
dtype=dtype)
+def _assert_all_equal_and_return(tensors, name=None):
+ """Asserts that all tensors are equal and returns the first one."""
+ with ops.name_scope(name, 'assert_all_equal', values=tensors):
+ if len(tensors) == 1:
+ return tensors[0]
+ assert_equal_ops = []
+ for t in tensors[1:]:
+ assert_equal_ops.append(check_ops.assert_equal(tensors[0], t))
+ with ops.control_dependencies(assert_equal_ops):
+ return array_ops.identity(tensors[0])
+
+
class _SequenceDenseColumn(fc._FeatureColumn):
"""Represents dense sequence data."""
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 39caa602d9..5c1e76fc62 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -198,6 +198,36 @@ class SequenceInputLayerTest(test.TestCase):
self.assertAllEqual(
expected_sequence_length, sequence_length.eval(session=sess))
+ def test_sequence_length_not_equal(self):
+ """Tests that an error is raised when sequence lengths are not equal."""
+ # Input a with sequence_length = [2, 1]
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2))
+ # Input b with sequence_length = [1, 1]
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0)),
+ values=(1., 10.),
+ dense_shape=(2, 2))
+ numeric_column_a = sfc.sequence_numeric_column('aaa')
+ numeric_column_b = sfc.sequence_numeric_column('bbb')
+
+ _, sequence_length = sfc.sequence_input_layer(
+ features={
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b,
+ },
+ feature_columns=[numeric_column_a, numeric_column_b])
+
+ with monitored_session.MonitoredSession() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[Condition x == y did not hold element-wise:\] '
+ r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] '
+ r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'):
+ sess.run(sequence_length)
+
def _assert_sparse_tensor_value(test_case, expected, actual):
test_case.assertEqual(np.int64, np.array(actual.indices).dtype)