aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/labeled_tensor
diff options
context:
space:
mode:
authorGravatar Stephan Hoyer <shoyer@google.com>2017-04-28 16:42:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-28 17:55:31 -0700
commit8acd664fa1dfbaafe6399bdbb969eddd9446d6a5 (patch)
treee32484dea22a10fe9824d47be8d32b62f96f84d7 /tensorflow/contrib/labeled_tensor
parentc383ff9040ee05c648ba5d0e9c65c68b5dd30159 (diff)
Fixup tf.contrib.labeled_tensor.select to handle tuples properly
You can use a tuple to label a point along an axis, so tuples should not be converted into lists of points for indexing. Change: 154603014
Diffstat (limited to 'tensorflow/contrib/labeled_tensor')
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops.py34
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops_test.py7
2 files changed, 23 insertions, 18 deletions
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
index 98842494fa..c957b41a49 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
@@ -51,8 +51,7 @@ def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
tc.Mapping(string_types,
- tc.Union(slice, collections.Hashable,
- collections.Sequence)),
+ tc.Union(slice, collections.Hashable, list)),
tc.Optional(string_types))
def select(labeled_tensor, selection, name=None):
"""Slice out a subset of the tensor.
@@ -110,23 +109,22 @@ def select(labeled_tensor, selection, name=None):
slices[axis_name] = slice(start, stop)
- else:
- # We're allowing anything NumPy treats as a scalar or 1D array.
- value = np.asarray(value)
- if value.ndim == 0:
- slices[axis_name] = axis.index(value.item())
- elif value.ndim == 1:
- if indexers:
- raise NotImplementedError(
- 'select does not yet support more than one list selection at '
- 'the same time')
- indexer = [axis.index(v) for v in value.tolist()]
- indexers[axis_name] = ops.convert_to_tensor(
- indexer, dtype=dtypes.int64)
- else:
+ # Needs to be after checking for slices, since slice objects claim to be
+ # instances of collections.Hashable but hash() on them fails.
+ elif isinstance(value, collections.Hashable):
+ slices[axis_name] = axis.index(value)
+
+ elif isinstance(value, list):
+ if indexers:
raise NotImplementedError(
- 'select does not yet support selections with more than one '
- 'dimension: %s on axis %r' % (value, axis_name))
+ 'select does not yet support more than one list selection at '
+ 'the same time')
+ indexer = [axis.index(v) for v in value]
+ indexers[axis_name] = ops.convert_to_tensor(indexer, dtype=dtypes.int64)
+
+ else:
+ # If type checking is working properly, this shouldn't be possible.
+ raise TypeError('cannot handle arbitrary types')
if indexers and slices:
raise NotImplementedError(
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
index ea5e008752..0727f4cf88 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
@@ -121,6 +121,13 @@ class SelectTest(Base):
golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :], [self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
+ def test_tuple(self):
+ original_lt = core.LabeledTensor(constant_op.constant([5, 6]),
+ [('x', [(1, 2), (3, 4)])])
+ select_lt = ops.select(original_lt, {'x': (1, 2)})
+ golden_lt = core.LabeledTensor(constant_op.constant(5), [])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.select(self.original_lt, {'foo': 1})