aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/labeled_tensor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-26 11:02:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-26 11:25:23 -0800
commit6e9265f74e20e6818651470887846d8292083f66 (patch)
treed09e22fe5dfc1a0eceb0b0aa361285b8381b9188 /tensorflow/contrib/labeled_tensor
parentc36c8248302f4733cca445df4e8b3e198320a63a (diff)
Adds foldl to LabeledTensor.
Addresses a deprecation warning for tf.neg. Does a bunch of annoying autoformatting. Change: 145694642
Diffstat (limited to 'tensorflow/contrib/labeled_tensor')
-rw-r--r--tensorflow/contrib/labeled_tensor/__init__.py2
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core.py63
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops.py224
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops_test.py28
4 files changed, 210 insertions, 107 deletions
diff --git a/tensorflow/contrib/labeled_tensor/__init__.py b/tensorflow/contrib/labeled_tensor/__init__.py
index 71bb7b95f5..64c83cbad8 100644
--- a/tensorflow/contrib/labeled_tensor/__init__.py
+++ b/tensorflow/contrib/labeled_tensor/__init__.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Labels for TensorFlow."""
from __future__ import absolute_import
@@ -107,6 +106,7 @@ reshape = _ops.reshape
rename_axis = _ops.rename_axis
random_crop = _ops.random_crop
map_fn = _ops.map_fn
+foldl = _ops.foldl
squeeze = _ops.squeeze
matmul = _ops.matmul
tile = _ops.tile
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py
index 214e21aa19..393c7f93f3 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Core classes and core ops for LabeledTensor.
Core ops are ops which will eventually be called by LabeledTensor methods,
@@ -43,7 +42,6 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
-
# pylint: disable=invalid-name
# Types coercible to Axis.labels
@@ -124,10 +122,8 @@ class Axis(object):
@tc.returns(bool)
def __eq__(self, other):
- return (isinstance(other, Axis) and
- self.name == other.name and
- self.size == other.size and
- self.labels == other.labels)
+ return (isinstance(other, Axis) and self.name == other.name and
+ self.size == other.size and self.labels == other.labels)
def __hash__(self):
return hash((self.name, self.size, self.labels))
@@ -504,9 +500,8 @@ def _convert_labeled_tensor_to_tensor(value, *args, **kwargs):
return ops.internal_convert_to_tensor(value.tensor, *args, **kwargs)
-ops.register_tensor_conversion_function(
- LabeledTensor, _convert_labeled_tensor_to_tensor)
-
+ops.register_tensor_conversion_function(LabeledTensor,
+ _convert_labeled_tensor_to_tensor)
# tc class for anything that can be coerced into a LabeledTensor
# pylint: disable=invalid-name
@@ -609,14 +604,16 @@ def identity(labeled_tensor, name=None):
with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope:
labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
return LabeledTensor(
- array_ops.identity(labeled_tensor.tensor, name=scope),
+ array_ops.identity(
+ labeled_tensor.tensor, name=scope),
labeled_tensor.axes)
# We don't call this slice because that shadows a built-in. Instead, we alias
# this to lt.slice in __init__.py.
@tc.returns(LabeledTensor)
-@tc.accepts(LabeledTensorLike, tc.Mapping(string_types, tc.Union(int, slice)),
+@tc.accepts(LabeledTensorLike,
+ tc.Mapping(string_types, tc.Union(int, slice)),
tc.Optional(string_types))
def slice_function(labeled_tensor, selection, name=None):
"""Slice out a subset of the tensor.
@@ -667,13 +664,14 @@ def slice_function(labeled_tensor, selection, name=None):
# If the slice is an int this dimension now has size 1, so we remove it.
assert isinstance(s, int)
- return LabeledTensor(array_ops.identity(sliced_tensor, name=scope),
- sliced_axes)
+ return LabeledTensor(
+ array_ops.identity(
+ sliced_tensor, name=scope), sliced_axes)
@tc.returns(LabeledTensor)
-@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)),
- tc.Optional(string_types))
+@tc.accepts(LabeledTensorLike,
+ tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def transpose(labeled_tensor, axis_order=None, name=None):
"""Permute a tensor's axes.
@@ -707,9 +705,8 @@ def transpose(labeled_tensor, axis_order=None, name=None):
permutation = [axis_names.index(n) for n in axis_order]
# Note: TensorFlow doesn't copy data for the identity tranpose.
- transpose_tensor = array_ops.transpose(labeled_tensor.tensor,
- permutation,
- name=scope)
+ transpose_tensor = array_ops.transpose(
+ labeled_tensor.tensor, permutation, name=scope)
permuted_axes = [labeled_tensor.axes[n] for n in axis_order]
@@ -717,8 +714,11 @@ def transpose(labeled_tensor, axis_order=None, name=None):
@tc.returns(LabeledTensor)
-@tc.accepts(LabeledTensorLike, tc.Collection(tc.Union(string_types, tc.Tuple(
- string_types, collections.Hashable))), tc.Optional(string_types))
+@tc.accepts(
+ LabeledTensorLike,
+ tc.Collection(
+ tc.Union(string_types, tc.Tuple(string_types, collections.Hashable))),
+ tc.Optional(string_types))
def expand_dims(labeled_tensor, axes, name=None):
"""Insert dimensions of size 1.
@@ -762,11 +762,12 @@ def expand_dims(labeled_tensor, axes, name=None):
shape.append(1)
- reshaped_tensor = array_ops.reshape(labeled_tensor.tensor, shape,
- name=scope)
+ reshaped_tensor = array_ops.reshape(
+ labeled_tensor.tensor, shape, name=scope)
return LabeledTensor(reshaped_tensor, reshaped_axes)
+
# This should only be added to a graph collection once.
_AXIS_ORDER_KEY = ('__axis_order',)
@@ -881,8 +882,8 @@ def check_axis_order(labeled_tensor, axis_order=None):
@tc.returns(LabeledTensor)
-@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)),
- tc.Optional(string_types))
+@tc.accepts(LabeledTensorLike,
+ tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def impose_axis_order(labeled_tensor, axis_order=None, name=None):
"""Impose desired axis order on a labeled tensor.
@@ -1036,12 +1037,10 @@ def align(labeled_tensor_0, labeled_tensor_1, name=None):
'impose_axis_order to reorder axes on one of more of the inputs.' %
(axes_0.keys(), axes_1.keys()))
- labeled_tensor_0 = expand_dims(labeled_tensor_0,
- new_axis_names,
- name=scope + '0')
- labeled_tensor_1 = expand_dims(labeled_tensor_1,
- new_axis_names,
- name=scope + '1')
+ labeled_tensor_0 = expand_dims(
+ labeled_tensor_0, new_axis_names, name=scope + '0')
+ labeled_tensor_1 = expand_dims(
+ labeled_tensor_1, new_axis_names, name=scope + '1')
broadcast_axes = []
for axis_name in new_axis_names:
@@ -1189,8 +1188,8 @@ logical_xor = define_binary_op('logical_xor', math_ops.logical_xor)
maximum = define_binary_op('maximum', math_ops.maximum)
minimum = define_binary_op('minimum', math_ops.minimum)
-squared_difference = define_binary_op(
- 'squared_difference', math_ops.squared_difference)
+squared_difference = define_binary_op('squared_difference',
+ math_ops.squared_difference)
igamma = define_binary_op('igamma', math_ops.igamma)
igammac = define_binary_op('igammac', math_ops.igammac)
zeta = define_binary_op('zeta', math_ops.zeta)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
index 6b8514fe62..92da7e6a13 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Non-core ops for LabeledTensor."""
from __future__ import absolute_import
from __future__ import division
@@ -29,6 +28,7 @@ from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.ops import random_ops
@@ -40,18 +40,19 @@ from tensorflow.python.training import input # pylint: disable=redefined-builti
tc.Optional(string_types))
def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
- temp_axes = core.Axes(
- [axis] + list(labeled_tensor.axes.remove(axis.name).values()))
+ temp_axes = core.Axes([axis] + list(
+ labeled_tensor.axes.remove(axis.name).values()))
transposed = core.transpose(labeled_tensor, temp_axes.keys())
- indexed = core.LabeledTensor(array_ops.gather(transposed.tensor, indexer),
- temp_axes)
+ indexed = core.LabeledTensor(
+ array_ops.gather(transposed.tensor, indexer), temp_axes)
return core.transpose(indexed, labeled_tensor.axes.keys(), name=scope)
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
- tc.Mapping(string_types, tc.Union(
- slice, collections.Hashable, collections.Sequence)),
+ tc.Mapping(string_types,
+ tc.Union(slice, collections.Hashable,
+ collections.Sequence)),
tc.Optional(string_types))
def select(labeled_tensor, selection, name=None):
"""Slice out a subset of the tensor.
@@ -143,8 +144,9 @@ def select(labeled_tensor, selection, name=None):
@tc.returns(core.LabeledTensor)
-@tc.accepts(tc.Collection(core.LabeledTensorLike), string_types,
- tc.Optional(string_types))
+@tc.accepts(
+ tc.Collection(core.LabeledTensorLike), string_types,
+ tc.Optional(string_types))
def concat(labeled_tensors, axis_name, name=None):
"""Concatenate tensors along a dimension.
@@ -165,8 +167,9 @@ def concat(labeled_tensors, axis_name, name=None):
have incompatible axes, or if `axis_name` isn't the name of an axis.
"""
with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
- labeled_tensors = [core.convert_to_labeled_tensor(lt)
- for lt in labeled_tensors]
+ labeled_tensors = [
+ core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
+ ]
if len(labeled_tensors) < 1:
raise ValueError('concat expects at least 1 tensor, but received %s' %
@@ -212,8 +215,7 @@ def concat(labeled_tensors, axis_name, name=None):
@tc.returns(core.LabeledTensor)
@tc.accepts(
tc.Collection(core.LabeledTensorLike),
- tc.Union(string_types, core.AxisLike),
- int, tc.Optional(string_types))
+ tc.Union(string_types, core.AxisLike), int, tc.Optional(string_types))
def pack(labeled_tensors, new_axis, axis_position=0, name=None):
"""Pack tensors along a new axis.
@@ -235,8 +237,9 @@ def pack(labeled_tensors, new_axis, axis_position=0, name=None):
don't have identical axes.
"""
with ops.name_scope(name, 'lt_pack', labeled_tensors) as scope:
- labeled_tensors = [core.convert_to_labeled_tensor(lt)
- for lt in labeled_tensors]
+ labeled_tensors = [
+ core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
+ ]
if len(labeled_tensors) < 1:
raise ValueError('pack expects at least 1 tensors, but received %s' %
@@ -256,8 +259,8 @@ def pack(labeled_tensors, new_axis, axis_position=0, name=None):
@tc.returns(tc.List(core.LabeledTensor))
-@tc.accepts(core.LabeledTensorLike, tc.Optional(string_types),
- tc.Optional(string_types))
+@tc.accepts(core.LabeledTensorLike,
+ tc.Optional(string_types), tc.Optional(string_types))
def unpack(labeled_tensor, axis_name=None, name=None):
"""Unpack the tensor.
@@ -287,13 +290,13 @@ def unpack(labeled_tensor, axis_name=None, name=None):
axis = axis_names.index(axis_name)
unpack_ops = array_ops.unstack(labeled_tensor.tensor, axis=axis, name=scope)
- axes = [a for i, a in enumerate(labeled_tensor.axes.values())
- if i != axis]
+ axes = [a for i, a in enumerate(labeled_tensor.axes.values()) if i != axis]
return [core.LabeledTensor(t, axes) for t in unpack_ops]
@tc.returns(core.LabeledTensor)
-@tc.accepts(core.LabeledTensorLike, tc.Collection(string_types),
+@tc.accepts(core.LabeledTensorLike,
+ tc.Collection(string_types),
tc.Collection(tc.Union(string_types, core.AxisLike)),
tc.Optional(string_types))
def reshape(labeled_tensor, existing_axes, new_axes, name=None):
@@ -409,8 +412,9 @@ def _batch_helper(default_name,
allow_smaller_final_batch,
name=None):
with ops.name_scope(name, default_name, labeled_tensors) as scope:
- labeled_tensors = [core.convert_to_labeled_tensor(lt)
- for lt in labeled_tensors]
+ labeled_tensors = [
+ core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
+ ]
batch_ops = batch_fn([t.tensor for t in labeled_tensors], scope)
# TODO(shoyer): Remove this when they sanitize the TF API.
@@ -481,13 +485,14 @@ def batch(labeled_tensors,
"""
def fn(tensors, scope):
- return input.batch(tensors,
- batch_size=batch_size,
- num_threads=num_threads,
- capacity=capacity,
- enqueue_many=enqueue_many,
- allow_smaller_final_batch=allow_smaller_final_batch,
- name=scope)
+ return input.batch(
+ tensors,
+ batch_size=batch_size,
+ num_threads=num_threads,
+ capacity=capacity,
+ enqueue_many=enqueue_many,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ name=scope)
return _batch_helper('lt_batch', fn, batch_size, enqueue_many,
labeled_tensors, allow_smaller_final_batch, name)
@@ -552,7 +557,8 @@ def shuffle_batch(labeled_tensors,
@tc.returns(core.LabeledTensor)
-@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, int),
+@tc.accepts(core.LabeledTensorLike,
+ tc.Mapping(string_types, int),
tc.Optional(int), tc.Optional(string_types))
def random_crop(labeled_tensor, shape_map, seed=None, name=None):
"""Randomly crops a tensor to a given size.
@@ -593,10 +599,8 @@ def random_crop(labeled_tensor, shape_map, seed=None, name=None):
shape.append(len(axis))
axes.append(axis)
- crop_op = random_ops.random_crop(labeled_tensor.tensor,
- shape,
- seed=seed,
- name=scope)
+ crop_op = random_ops.random_crop(
+ labeled_tensor.tensor, shape, seed=seed, name=scope)
return core.LabeledTensor(crop_op, axes)
@@ -622,14 +626,82 @@ def map_fn(fn, labeled_tensor, name=None):
"""
with ops.name_scope(name, 'lt_map_fn', [labeled_tensor]) as scope:
labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
unpack_lts = unpack(labeled_tensor)
- map_lts = [fn(t) for t in unpack_lts]
- return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope)
+
+ # TODO(ericmc): Fix this upstream.
+ if labeled_tensor.dtype == dtypes.string:
+ # We must construct the full graph here, because functional_ops.map_fn
+ # doesn't work for string-valued tensors.
+ # Constructing the full graph may be slow.
+ map_lts = [fn(t) for t in unpack_lts]
+ return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope)
+ else:
+ # Figure out what the axis labels should be, but use tf.map_fn to
+ # construct the graph because it's efficient.
+ # It may be slow to construct the full graph, so we infer the labels from
+ # the first element.
+ # TODO(ericmc): This builds a subgraph which then gets thrown away.
+ # Find a more elegant solution.
+ first_map_lt = fn(unpack_lts[0])
+ final_axes = list(labeled_tensor.axes.values())[:1] + list(
+ first_map_lt.axes.values())
+
+ @tc.returns(ops.Tensor)
+ @tc.accepts(ops.Tensor)
+ def tf_fn(tensor):
+ original_axes = labeled_tensor.axes.values()[1:]
+ tensor_lt = core.LabeledTensor(tensor, original_axes)
+ return fn(tensor_lt).tensor
+
+ map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor)
+ map_lt = core.LabeledTensor(map_op, final_axes)
+
+ return core.identity(map_lt, name=scope)
@tc.returns(core.LabeledTensor)
-@tc.accepts(core.LabeledTensorLike, tc.Optional(tc.Collection(string_types)),
- tc.Optional(string_types))
+@tc.accepts(collections.Callable, core.LabeledTensorLike,
+ core.LabeledTensorLike, tc.Optional(string_types))
+def foldl(fn, labeled_tensor, initial_value, name=None):
+ """Left fold on the list of tensors unpacked from labeled_tensor.
+
+ See tf.foldl.
+
+ Args:
+ fn: The function to apply to each unpacked LabeledTensor.
+ It should have type (LabeledTensor, LabeledTensor) -> LabeledTensor.
+ Its arguments are (accumulated_value, next_value).
+ labeled_tensor: The input tensor.
+ initial_value: The initial value of the accumulator.
+ name: Optional op name.
+
+ Returns:
+ The accumulated value.
+ """
+ with ops.name_scope(name, 'lt_foldl',
+ [labeled_tensor, initial_value]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+ initial_value = core.convert_to_labeled_tensor(initial_value)
+
+ @tc.returns(ops.Tensor)
+ @tc.accepts(ops.Tensor, ops.Tensor)
+ def tf_fn(accumulator, next_element):
+ accumulator_lt = core.LabeledTensor(accumulator, initial_value.axes)
+ next_element_lt = core.LabeledTensor(next_element,
+ labeled_tensor.axes.values()[1:])
+ return fn(accumulator_lt, next_element_lt).tensor
+
+ foldl_op = functional_ops.foldl(
+ tf_fn, labeled_tensor.tensor, initializer=initial_value.tensor)
+ foldl_lt = core.LabeledTensor(foldl_op, initial_value.axes)
+
+ return core.identity(foldl_lt, name=scope)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike,
+ tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def squeeze(labeled_tensor, axis_names=None, name=None):
"""Remove size-1 dimensions.
@@ -672,17 +744,17 @@ def squeeze(labeled_tensor, axis_names=None, name=None):
axes.append(axis)
if squeeze_dimensions:
- squeeze_op = array_ops.squeeze(labeled_tensor.tensor,
- squeeze_dimensions,
- name=scope)
+ squeeze_op = array_ops.squeeze(
+ labeled_tensor.tensor, squeeze_dimensions, name=scope)
else:
squeeze_op = array_ops.identity(labeled_tensor.tensor, name=scope)
return core.LabeledTensor(squeeze_op, axes)
+
# pylint: disable=invalid-name
-ReduceAxis = tc.Union(
- string_types, tc.Tuple(string_types, collections.Hashable))
+ReduceAxis = tc.Union(string_types,
+ tc.Tuple(string_types, collections.Hashable))
ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis)))
# pylint: enable=invalid-name
@@ -766,8 +838,9 @@ def matmul(a, b, name=None):
axis_scope_order = core.get_axis_order()
if axis_scope_order is not None:
result_axis_names = [axis.name for axis in result_axes]
- new_axis_names = [name for name in axis_scope_order
- if name in result_axis_names]
+ new_axis_names = [
+ name for name in axis_scope_order if name in result_axis_names
+ ]
if new_axis_names != result_axis_names:
# switch a and b
b, a = a, b
@@ -792,10 +865,8 @@ def matmul(a, b, name=None):
b_tensor = b.tensor
transpose_b = list(b.axes.keys()).index(shared_axis) == 1
- result_op = math_ops.matmul(a_tensor,
- b_tensor,
- transpose_a=transpose_a,
- transpose_b=transpose_b)
+ result_op = math_ops.matmul(
+ a_tensor, b_tensor, transpose_a=transpose_a, transpose_b=transpose_b)
if squeeze_dims:
result_op = array_ops.squeeze(result_op, squeeze_dims)
@@ -881,9 +952,8 @@ def define_reduce_op(op_name, reduce_fn):
else:
intermediate_axes.append(axis)
- reduce_op = reduce_fn(labeled_tensor.tensor,
- reduction_dimensions,
- keep_dims=True)
+ reduce_op = reduce_fn(
+ labeled_tensor.tensor, reduction_dimensions, keep_dims=True)
reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes)
return squeeze(reduce_lt, axes_to_squeeze, name=scope)
@@ -906,7 +976,8 @@ reduce_sum = define_reduce_op('reduce_sum', math_ops.reduce_sum)
@tc.returns(core.LabeledTensor)
-@tc.accepts(core.LabeledTensorLike, tc.Mapping(str, tc.Union(int, ops.Tensor)),
+@tc.accepts(core.LabeledTensorLike,
+ tc.Mapping(str, tc.Union(int, ops.Tensor)),
tc.Optional(string_types))
def tile(labeled_tensor, multiples, name=None):
"""Constructs a tensor by tiling a given tensor.
@@ -938,16 +1009,20 @@ def tile(labeled_tensor, multiples, name=None):
'names %r on the input labeled tensor' %
(multiples.keys(), labeled_tensor.axes))
- labeled_axes = [name for name in multiples
- if labeled_tensor.axes[name].labels is not None]
+ labeled_axes = [
+ name for name in multiples
+ if labeled_tensor.axes[name].labels is not None
+ ]
if labeled_axes:
raise ValueError('cannot tile axes with tick labels: %r' % labeled_axes)
multiples_list = [multiples.get(name, 1) for name in labeled_tensor.axes]
tile_op = array_ops.tile(labeled_tensor.tensor, multiples_list, name=scope)
- new_axes = [axis.name if axis.labels is None else axis
- for axis in labeled_tensor.axes.values()]
+ new_axes = [
+ axis.name if axis.labels is None else axis
+ for axis in labeled_tensor.axes.values()
+ ]
return core.LabeledTensor(tile_op, new_axes)
@@ -997,19 +1072,21 @@ def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
new_axes.append(axis)
padding_pairs.append((0, 0))
- pad_op = array_ops.pad(
- labeled_tensor.tensor, padding_pairs, mode, name=scope)
+ pad_op = array_ops.pad(labeled_tensor.tensor,
+ padding_pairs,
+ mode,
+ name=scope)
return core.LabeledTensor(pad_op, new_axes)
@tc.returns(core.LabeledTensor)
-@tc.accepts(tc.Union(np.ndarray, list, tuple, core.Scalar),
- tc.Optional(dtypes.DType),
- tc.Optional(tc.Union(
- core.Axes,
- tc.Collection(tc.Union(string_types, core.AxisLike)))),
- tc.Optional(string_types))
+@tc.accepts(
+ tc.Union(np.ndarray, list, tuple, core.Scalar),
+ tc.Optional(dtypes.DType),
+ tc.Optional(
+ tc.Union(core.Axes, tc.Collection(
+ tc.Union(string_types, core.AxisLike)))), tc.Optional(string_types))
def constant(value, dtype=None, axes=None, name=None):
"""Creates a constant tensor.
@@ -1050,8 +1127,8 @@ def constant(value, dtype=None, axes=None, name=None):
@tc.returns(core.LabeledTensor)
-@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
- tc.Optional(string_types))
+@tc.accepts(core.LabeledTensorLike,
+ tc.Optional(dtypes.DType), tc.Optional(string_types))
def zeros_like(labeled_tensor, dtype=None, name=None):
"""Creates an identical tensor with all elements set to zero.
@@ -1070,8 +1147,8 @@ def zeros_like(labeled_tensor, dtype=None, name=None):
@tc.returns(core.LabeledTensor)
-@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
- tc.Optional(string_types))
+@tc.accepts(core.LabeledTensorLike,
+ tc.Optional(dtypes.DType), tc.Optional(string_types))
def ones_like(labeled_tensor, dtype=None, name=None):
"""Creates an identical tensor with all elements set to one.
@@ -1090,8 +1167,8 @@ def ones_like(labeled_tensor, dtype=None, name=None):
@tc.returns(core.LabeledTensor)
-@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
- tc.Optional(string_types))
+@tc.accepts(core.LabeledTensorLike,
+ tc.Optional(dtypes.DType), tc.Optional(string_types))
def cast(labeled_tensor, dtype=None, name=None):
"""Casts a labeled tensor to a new type.
@@ -1127,9 +1204,8 @@ def verify_tensor_all_finite(labeled_tensor, message, name=None):
with ops.name_scope(name, 'lt_verify_tensor_all_finite',
[labeled_tensor]) as scope:
labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
- op = numerics.verify_tensor_all_finite(labeled_tensor.tensor,
- msg=message,
- name=scope)
+ op = numerics.verify_tensor_all_finite(
+ labeled_tensor.tensor, msg=message, name=scope)
return core.LabeledTensor(op, labeled_tensor.axes)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
index bbe77f9fef..ea5e008752 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test as test_lib
@@ -477,6 +478,33 @@ class MapFnTest(Base):
slice_lt = core.slice_function(self.original_lt, {'channel': 1})
self.assertLabeledTensorsEqual(map_lt, slice_lt)
+ def test_string(self):
+
+ def fn(entry_lt):
+ op = string_ops.string_join([entry_lt, 'world'])
+ return core.LabeledTensor(op, [])
+
+ tensor_lt = ops.constant(['hi', 'bye'], axes=['batch'])
+ map_lt = ops.map_fn(fn, tensor_lt)
+ golden_lt = ops.constant(['hiworld', 'byeworld'], axes=['batch'])
+
+ self.assertLabeledTensorsEqual(map_lt, golden_lt)
+
+
+class FoldlTest(Base):
+
+ def test_name(self):
+ foldl_lt = ops.foldl(core.add, self.original_lt,
+ core.slice_function(self.original_lt, {'x': 0}))
+ self.assertIn('lt_foldl', foldl_lt.name)
+
+ def test_sum(self):
+ initializer_lt = ops.constant([0, 10], axes=['y'])
+ tensor_lt = ops.constant([[1, 2], [3, 4], [5, 6]], axes=['x', 'y'])
+ foldl_lt = ops.foldl(core.add, tensor_lt, initializer_lt)
+ golden_lt = ops.constant([9, 22], axes=['y'])
+ self.assertLabeledTensorsEqual(foldl_lt, golden_lt)
+
class SqueezeTest(Base):