diff options
author | Stephan Hoyer <shoyer@google.com> | 2017-04-28 16:42:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-28 17:55:31 -0700 |
commit | 8acd664fa1dfbaafe6399bdbb969eddd9446d6a5 (patch) | |
tree | e32484dea22a10fe9824d47be8d32b62f96f84d7 /tensorflow/contrib/labeled_tensor | |
parent | c383ff9040ee05c648ba5d0e9c65c68b5dd30159 (diff) |
Fixup tf.contrib.labeled_tensor.select to handle tuples properly
You can use a tuple to label a point along an axis, so tuples should not be
converted into lists of points for indexing.
Change: 154603014
Diffstat (limited to 'tensorflow/contrib/labeled_tensor')
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/ops.py | 34 | ||||
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/ops_test.py | 7 |
2 files changed, 23 insertions, 18 deletions
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 98842494fa..c957b41a49 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -51,8 +51,7 @@ def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None): @tc.returns(core.LabeledTensor) @tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, - tc.Union(slice, collections.Hashable, - collections.Sequence)), + tc.Union(slice, collections.Hashable, list)), tc.Optional(string_types)) def select(labeled_tensor, selection, name=None): """Slice out a subset of the tensor. @@ -110,23 +109,22 @@ def select(labeled_tensor, selection, name=None): slices[axis_name] = slice(start, stop) - else: - # We're allowing anything NumPy treats as a scalar or 1D array. - value = np.asarray(value) - if value.ndim == 0: - slices[axis_name] = axis.index(value.item()) - elif value.ndim == 1: - if indexers: - raise NotImplementedError( - 'select does not yet support more than one list selection at ' - 'the same time') - indexer = [axis.index(v) for v in value.tolist()] - indexers[axis_name] = ops.convert_to_tensor( - indexer, dtype=dtypes.int64) - else: + # Needs to be after checking for slices, since slice objects claim to be + # instances of collections.Hashable but hash() on them fails. + elif isinstance(value, collections.Hashable): + slices[axis_name] = axis.index(value) + + elif isinstance(value, list): + if indexers: raise NotImplementedError( - 'select does not yet support selections with more than one ' - 'dimension: %s on axis %r' % (value, axis_name)) + 'select does not yet support more than one list selection at ' + 'the same time') + indexer = [axis.index(v) for v in value] + indexers[axis_name] = ops.convert_to_tensor(indexer, dtype=dtypes.int64) + + else: + # If type checking is working properly, this shouldn't be possible. + raise TypeError('cannot handle arbitrary types') if indexers and slices: raise NotImplementedError( diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py index ea5e008752..0727f4cf88 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py @@ -121,6 +121,13 @@ class SelectTest(Base): golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :], [self.a2, self.a3]) self.assertLabeledTensorsEqual(select_lt, golden_lt) + def test_tuple(self): + original_lt = core.LabeledTensor(constant_op.constant([5, 6]), + [('x', [(1, 2), (3, 4)])]) + select_lt = ops.select(original_lt, {'x': (1, 2)}) + golden_lt = core.LabeledTensor(constant_op.constant(5), []) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + def test_invalid_input(self): with self.assertRaises(ValueError): ops.select(self.original_lt, {'foo': 1}) |