aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/labeled_tensor
diff options
context:
space:
mode:
authorGravatar Stephan Hoyer <shoyer@google.com>2016-11-14 17:24:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-14 17:44:53 -0800
commit9d20f4ea4b0b5792bf88ef886d0143b7aa780522 (patch)
tree7007220d84d18a058a7c5ed02a695af728e15a3e /tensorflow/contrib/labeled_tensor
parent887892a499590fd24a052074d5d32ae9393e3a35 (diff)
Initial version of tf.contrib.labeled_tensor
Change: 139143754
Diffstat (limited to 'tensorflow/contrib/labeled_tensor')
-rw-r--r--tensorflow/contrib/labeled_tensor/BUILD166
-rw-r--r--tensorflow/contrib/labeled_tensor/README.md8
-rw-r--r--tensorflow/contrib/labeled_tensor/__init__.py139
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py322
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core.py1197
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core_test.py842
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/io_ops.py178
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py106
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/nn.py42
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/nn_test.py70
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops.py1207
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops_test.py918
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/sugar.py131
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py106
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/test_util.py47
15 files changed, 5479 insertions, 0 deletions
diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD
new file mode 100644
index 0000000000..82d9dc9a45
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/BUILD
@@ -0,0 +1,166 @@
+# Description:
+# Labels for TensorFlow.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+py_library(
+ name = "labeled_tensor",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ":io_ops",
+ ":nn",
+ ":ops",
+ ":sugar",
+ ],
+)
+
+py_library(
+ name = "_typecheck",
+ srcs = ["python/ops/_typecheck.py"],
+ srcs_version = "PY2AND3",
+ visibility = [":__subpackages__"],
+)
+
+py_library(
+ name = "core",
+ srcs = ["python/ops/core.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_typecheck",
+ ],
+)
+
+py_library(
+ name = "test_util",
+ srcs = ["python/ops/test_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_typecheck",
+ ":core",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "core_test",
+ size = "small",
+ srcs = [
+ "python/ops/core_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ":test_util",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "io_ops",
+ srcs = ["python/ops/io_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ],
+)
+
+py_test(
+ name = "io_ops_test",
+ size = "small",
+ srcs = [
+ "python/ops/io_ops_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":io_ops",
+ ":ops",
+ ":test_util",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "nn",
+ srcs = ["python/ops/nn.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ],
+)
+
+py_test(
+ name = "nn_test",
+ size = "small",
+ srcs = [
+ "python/ops/nn_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":nn",
+ ":test_util",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "ops",
+ srcs = ["python/ops/ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ],
+)
+
+py_test(
+ name = "ops_test",
+ srcs = [
+ "python/ops/ops_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops",
+ ":test_util",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "sugar",
+ srcs = ["python/ops/sugar.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ":ops",
+ ],
+)
+
+py_test(
+ name = "sugar_test",
+ size = "small",
+ srcs = [
+ "python/ops/sugar_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":sugar",
+ ":test_util",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/contrib/labeled_tensor/README.md b/tensorflow/contrib/labeled_tensor/README.md
new file mode 100644
index 0000000000..50c6750fd0
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/README.md
@@ -0,0 +1,8 @@
+# Labels for TensorFlow
+
+LabeledTensor is a library for adding semantically meaningful dimension and
+coordinate labels to tensors in Tensorflow.
+
+Maintainers:
+- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
+- Eric Christiansen (ericmc@google.com, github.com/emchristiansen)
diff --git a/tensorflow/contrib/labeled_tensor/__init__.py b/tensorflow/contrib/labeled_tensor/__init__.py
new file mode 100644
index 0000000000..75299a3a0e
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/__init__.py
@@ -0,0 +1,139 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Labels for TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.labeled_tensor.python.ops import core as _core
+from tensorflow.contrib.labeled_tensor.python.ops import io_ops as _io_ops
+from tensorflow.contrib.labeled_tensor.python.ops import nn
+from tensorflow.contrib.labeled_tensor.python.ops import ops as _ops
+from tensorflow.contrib.labeled_tensor.python.ops import sugar as _sugar
+
+# pylint: disable=invalid-name
+
+# Core types.
+Axis = _core.Axis
+Axes = _core.Axes
+LabeledTensor = _core.LabeledTensor
+
+as_axis = _core.as_axis
+convert_to_labeled_tensor = _core.convert_to_labeled_tensor
+
+identity = _core.identity
+slice = _core.slice_function # pylint: disable=redefined-builtin
+transpose = _core.transpose
+expand_dims = _core.expand_dims
+align = _core.align
+
+axis_order_scope = _core.axis_order_scope
+check_axis_order = _core.check_axis_order
+impose_axis_order = _core.impose_axis_order
+AxisOrderError = _core.AxisOrderError
+
+define_unary_op = _core.define_unary_op
+define_binary_op = _core.define_binary_op
+define_reduce_op = _ops.define_reduce_op
+
+abs = _core.abs_function # pylint: disable=redefined-builtin
+neg = _core.neg
+sign = _core.sign
+inv = _core.inv
+square = _core.square
+round = _core.round_function # pylint: disable=redefined-builtin
+sqrt = _core.sqrt
+rsqrt = _core.rsqrt
+exp = _core.exp
+log = _core.log
+ceil = _core.ceil
+floor = _core.floor
+cos = _core.cos
+sin = _core.sin
+tan = _core.tan
+acos = _core.acos
+asin = _core.asin
+atan = _core.atan
+lgamma = _core.lgamma
+digamma = _core.digamma
+erf = _core.erf
+erfc = _core.erfc
+logical_not = _core.logical_not
+
+add = _core.add
+sub = _core.sub
+mul = _core.mul
+div = _core.div
+mod = _core.mod
+pow = _core.pow_function # pylint: disable=redefined-builtin
+
+equal = _core.equal
+greater = _core.greater
+greater_equal = _core.greater_equal
+not_equal = _core.not_equal
+less = _core.less
+less_equal = _core.less_equal
+logical_and = _core.logical_and
+logical_or = _core.logical_or
+logical_xor = _core.logical_xor
+
+maximum = _core.maximum
+minimum = _core.minimum
+squared_difference = _core.squared_difference
+igamma = _core.igamma
+igammac = _core.igammac
+zeta = _core.zeta
+polygamma = _core.polygamma
+
+select = _ops.select
+concat = _ops.concat
+pack = _ops.pack
+unpack = _ops.unpack
+reshape = _ops.reshape
+rename_axis = _ops.rename_axis
+random_crop = _ops.random_crop
+map_fn = _ops.map_fn
+squeeze = _ops.squeeze
+matmul = _ops.matmul
+tile = _ops.tile
+pad = _ops.pad
+constant = _ops.constant
+zeros_like = _ops.zeros_like
+ones_like = _ops.ones_like
+cast = _ops.cast
+verify_tensor_all_finite = _ops.verify_tensor_all_finite
+boolean_mask = _ops.boolean_mask
+where = _ops.where
+
+reduce_all = _ops.reduce_all
+reduce_any = _ops.reduce_any
+reduce_logsumexp = _ops.reduce_logsumexp
+reduce_max = _ops.reduce_max
+reduce_mean = _ops.reduce_mean
+reduce_min = _ops.reduce_min
+reduce_prod = _ops.reduce_prod
+reduce_sum = _ops.reduce_sum
+
+batch = _ops.batch
+shuffle_batch = _ops.shuffle_batch
+
+FixedLenFeature = _io_ops.FixedLenFeature
+parse_example = _io_ops.parse_example
+parse_single_example = _io_ops.parse_single_example
+placeholder = _io_ops.placeholder
+
+ReshapeCoder = _sugar.ReshapeCoder
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
new file mode 100644
index 0000000000..4a939cb22c
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
@@ -0,0 +1,322 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Minimal runtime type checking library.
+
+This module should not be considered public API.
+"""
+# TODO(ericmc,shoyer): Delete this in favor of using pytype or mypy
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import functools
+import inspect
+import re
+
+
+# used for register_type_abbreviation and _type_repr below.
+_TYPE_ABBREVIATIONS = {}
+
+
+class Type(object):
+ """Base class for type checker types.
+
+ The custom types defined in this module are based on types in the standard
+ library's typing module (in Python 3.5):
+ https://docs.python.org/3/library/typing.html
+
+ The only difference should be that we use actual instances of Type classes to
+ represent custom types rather than the metaclass magic typing uses to create
+ new class objects. In practice, all this should mean is that we use
+ `List(int)` rather than `List[int]`.
+
+ Custom types should implement __instancecheck__ and inherit from Type. Every
+ argument in the constructor must be a type or Type instance, and these
+ arguments must be stored as a tuple on the `_types` attribute.
+ """
+
+ def __init__(self, *types):
+ self._types = types
+
+ def __repr__(self):
+ args_repr = ", ".join(repr(t) for t in self._types)
+ return "typecheck.%s(%s)" % (type(self).__name__, args_repr)
+
+
+class _SingleArgumentType(Type):
+ """Use this subclass for parametric types that accept only one argument."""
+
+ def __init__(self, tpe):
+ super(_SingleArgumentType, self).__init__(tpe)
+
+ @property
+ def _type(self):
+ tpe, = self._types # pylint: disable=unbalanced-tuple-unpacking
+ return tpe
+
+
+class _TwoArgumentType(Type):
+ """Use this subclass for parametric types that accept two arguments."""
+
+ def __init__(self, first_type, second_type):
+ super(_TwoArgumentType, self).__init__(first_type, second_type)
+
+
+class Union(Type):
+ """A sum type.
+
+ A correct type is any of the types provided.
+ """
+
+ def __instancecheck__(self, instance):
+ return isinstance(instance, self._types)
+
+
+class Optional(_SingleArgumentType):
+ """An optional type.
+
+ A correct type is either the provided type or NoneType.
+ """
+
+ def __instancecheck__(self, instance):
+ # types.NoneType does not exist in Python 3
+ return isinstance(instance, (self._type, type(None)))
+
+
+class List(_SingleArgumentType):
+ """A typed list.
+
+ A correct type is a list where each element has the single provided type.
+ """
+
+ def __instancecheck__(self, instance):
+ return (isinstance(instance, list)
+ and all(isinstance(x, self._type) for x in instance))
+
+
+class Sequence(_SingleArgumentType):
+ """A typed sequence.
+
+ A correct type is a sequence where each element has the single provided type.
+ """
+
+ def __instancecheck__(self, instance):
+ return (isinstance(instance, collections.Sequence)
+ and all(isinstance(x, self._type) for x in instance))
+
+
+class Collection(_SingleArgumentType):
+ """A sized, iterable container.
+
+ A correct type is an iterable and container with known size where each element
+ has the single provided type.
+
+ We use this in preference to Iterable because we check each instance of the
+ iterable at runtime, and hence need to avoid iterables that could be
+ exhausted.
+ """
+
+ def __instancecheck__(self, instance):
+ return (isinstance(instance, collections.Iterable)
+ and isinstance(instance, collections.Sized)
+ and isinstance(instance, collections.Container)
+ and all(isinstance(x, self._type) for x in instance))
+
+
+class Tuple(Type):
+ """A typed tuple.
+
+ A correct type is a tuple with the correct length where each element has
+ the correct type.
+ """
+
+ def __instancecheck__(self, instance):
+ return (isinstance(instance, tuple)
+ and len(instance) == len(self._types)
+ and all(isinstance(x, t) for x, t in zip(instance, self._types)))
+
+
+class Mapping(_TwoArgumentType):
+ """A typed mapping.
+
+ A correct type has the correct parametric types for keys and values.
+ """
+
+ def __instancecheck__(self, instance):
+ key_type, value_type = self._types # pylint: disable=unbalanced-tuple-unpacking
+ return (isinstance(instance, collections.Mapping)
+ and all(isinstance(k, key_type) for k in instance.keys())
+ and all(isinstance(k, value_type) for k in instance.values()))
+
+
+class Dict(Mapping):
+ """A typed dict.
+
+ A correct type has the correct parametric types for keys and values.
+ """
+
+ def __instancecheck__(self, instance):
+ return (isinstance(instance, dict)
+ and super(Dict, self).__instancecheck__(instance))
+
+
+def _replace_forward_references(t, context):
+ """Replace forward references in the given type."""
+ if isinstance(t, str):
+ return context[t]
+ elif isinstance(t, Type):
+ return type(t)(*[_replace_forward_references(t, context) for t in t._types]) # pylint: disable=protected-access
+ else:
+ return t
+
+
+def register_type_abbreviation(name, alias):
+ """Register an abbreviation for a type in typecheck tracebacks.
+
+ This makes otherwise very long typecheck errors much more readable.
+
+ Example:
+ typecheck.register_type_abbreviation(tf.Dimension, 'tf.Dimension')
+
+ Args:
+ name: type or class to abbreviate.
+ alias: string alias to substitute.
+ """
+ _TYPE_ABBREVIATIONS[name] = alias
+
+
+def _type_repr(t):
+ """A more succinct repr for typecheck tracebacks."""
+ string = repr(t)
+ for type_, alias in _TYPE_ABBREVIATIONS.items():
+ string = string.replace(repr(type_), alias)
+ string = re.sub(r"<(class|type) '([\w.]+)'>", r"\2", string)
+ string = re.sub(r"typecheck\.(\w+)", r"\1", string)
+ return string
+
+
+class Error(TypeError):
+ """Exception for typecheck failures."""
+
+
+def accepts(*types):
+ """A decorator which checks the input types of a function.
+
+ Based on:
+ http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments
+ The above draws from:
+ https://www.python.org/dev/peps/pep-0318/
+
+ Args:
+ *types: A list of Python types.
+
+ Returns:
+ A function to use as a decorator.
+ """
+
+ def check_accepts(f):
+ """Check the types."""
+ spec = inspect.getargspec(f)
+
+ num_function_arguments = len(spec.args)
+ if len(types) != num_function_arguments:
+ raise Error(
+ "Function %r has %d arguments but only %d types were provided in the "
+ "annotation." % (f, num_function_arguments, len(types)))
+
+ if spec.defaults:
+ num_defaults = len(spec.defaults)
+ for (name, a, t) in zip(spec.args[-num_defaults:],
+ spec.defaults,
+ types[-num_defaults:]):
+ allowed_type = _replace_forward_references(t, f.__globals__)
+ if not isinstance(a, allowed_type):
+ raise Error("default argument value %r of type %r is not an instance "
+ "of the allowed type %s for the %s argument to %r"
+ % (a, type(a), _type_repr(allowed_type), name, f))
+
+ @functools.wraps(f)
+ def new_f(*args, **kwds):
+ """A helper function."""
+ for (a, t) in zip(args, types):
+ allowed_type = _replace_forward_references(t, f.__globals__)
+ if not isinstance(a, allowed_type):
+ raise Error("%r of type %r is not an instance of the allowed type %s "
+ "for %r" % (a, type(a), _type_repr(allowed_type), f))
+ return f(*args, **kwds)
+
+ return new_f
+
+ return check_accepts
+
+
+def returns(*types):
+ """A decorator which checks the return types of a function.
+
+ Based on:
+ http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments
+ The above draws from:
+ https://www.python.org/dev/peps/pep-0318/
+
+ Args:
+ *types: A list of Python types.
+ A list of one element corresponds to a single return value.
+ A list of several elements corresponds to several return values.
+ Note that a function with no explicit return value has an implicit
+ NoneType return and should be annotated correspondingly.
+
+ Returns:
+ A function to use as a decorator.
+ """
+
+ def check_returns(f):
+ """Check the types."""
+ if not types:
+ raise TypeError("A return type annotation must contain at least one type")
+
+ @functools.wraps(f)
+ def new_f(*args, **kwds):
+ """A helper function."""
+ return_value = f(*args, **kwds)
+
+ if len(types) == 1:
+ # The function has a single return value.
+ allowed_type = _replace_forward_references(types[0], f.__globals__)
+ if not isinstance(return_value, allowed_type):
+ raise Error("%r of type %r is not an instance of the allowed type %s "
+ "for %r"
+ % (return_value, type(return_value),
+ _type_repr(allowed_type), f))
+
+ else:
+ if len(return_value) != len(types):
+ raise Error(
+ "Function %r has %d return values but only %d types were "
+ "provided in the annotation." %
+ (f, len(return_value), len(types)))
+
+ for (r, t) in zip(return_value, types):
+ allowed_type = _replace_forward_references(t, f.__globals__)
+ if not isinstance(r, allowed_type):
+ raise Error("%r of type %r is not an instance of allowed type %s "
+ "for %r" % (r, type(r), _type_repr(allowed_type), f))
+
+ return return_value
+
+ return new_f
+
+ return check_returns
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py
new file mode 100644
index 0000000000..69fd06133b
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py
@@ -0,0 +1,1197 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Core classes and core ops for LabeledTensor.
+
+Core ops are ops which will eventually be called by LabeledTensor methods,
+and ops which a core op depends upon.
+For example, `add` is a core op because we'll eventually support the `+`
+operator.
+Non-core ops should go in `ops.py`.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import contextlib
+import numbers
+import types
+
+import numpy as np
+from six import binary_type
+from six import string_types
+from six import text_type
+from six.moves import range # pylint: disable=redefined-builtin
+
+from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+# pylint: disable=invalid-name
+
+# Types coercible to Axis.labels
+# We use this instead of collections.Sequence to exclude strings.
+LabelsLike = tc.Union(np.ndarray, range, list, tuple)
+
+# Types coercible to a tf.Dimension
+DimensionLike = tc.Optional(tc.Union(tensor_shape.Dimension, int))
+
+# Types usable for axis values
+AxisValue = tc.Union(LabelsLike, DimensionLike)
+
+# Valid scalar values for TensorFlow
+Scalar = tc.Union(numbers.Number, bool, binary_type, text_type)
+
+# pylint: enable=invalid-name
+
+
+class Axis(object):
+ """Size and label information for an axis.
+
+ Axis contains either a tf.Dimension indicating the size of an axis,
+ or a tuple of tick labels for the axis.
+
+ If tick labels are provided, they must be unique.
+ """
+
+ @tc.accepts(object, string_types, AxisValue)
+ def __init__(self, name, value):
+ """Construct an Axis.
+
+ Args:
+ name: Name of the axis.
+ value: Either None, an int or tf.Dimension giving the size of the axis,
+ or a sequence that is not a string additionally providing coordinate
+ (tick) labels.
+
+ Raises:
+ ValueError: If the user provides labels with duplicate values.
+ """
+ if isinstance(value, tensor_shape.Dimension):
+ dimension = value
+ labels = None
+ elif isinstance(value, int) or value is None:
+ dimension = tensor_shape.Dimension(value)
+ labels = None
+ else:
+ dimension = tensor_shape.Dimension(len(value))
+ labels = tuple(value)
+
+ if dimension.value == 0:
+ # Treat a zero-length axis as if it has labels.
+ labels = ()
+
+ if labels is not None:
+ index = dict(zip(labels, range(len(labels))))
+ if len(index) != len(labels):
+ raise ValueError('Tick labels must be unique, but got {}'
+ .format(labels))
+ else:
+ index = None
+
+ self._name = name # type: string_types
+ self._dimension = dimension # type: tensor_shape.Dimension
+ self._labels = labels # type: Optional[tuple]
+ self._index = index # type: Optional[Dict[Any, int]]
+
+ @property
+ @tc.returns(string_types)
+ def name(self):
+ return self._name
+
+ @tc.returns(string_types)
+ def __repr__(self):
+ # Axis('x', Dimension(2))
+ # TODO(shoyer): make very long reprs more succint?
+ return "%s('%s', %r)" % (type(self).__name__, self.name, self.value)
+
+ @tc.returns(bool)
+ def __eq__(self, other):
+ return (isinstance(other, Axis) and
+ self.name == other.name and
+ self.size == other.size and
+ self.labels == other.labels)
+
+ def __hash__(self):
+ return hash((self.name, self.size, self.labels))
+
+ @tc.returns(bool)
+ def __ne__(self, other):
+ return not self == other
+
+ @tc.returns(int)
+ def __len__(self):
+ size = self.size
+ if size is None:
+ raise ValueError('axis %r has unknown length' % self.name)
+ return size
+
+ @property
+ @tc.returns(tc.Optional(tensor_shape.Dimension))
+ def dimension(self):
+ return self._dimension
+
+ @property
+ @tc.returns(tc.Optional(int))
+ def size(self):
+ return self._dimension.value
+
+ @property
+ @tc.returns(tc.Union(tuple, tensor_shape.Dimension))
+ def value(self):
+ """Returns the tf.Dimension or tuple specifying axis ticks."""
+ if self.labels is None:
+ return self.dimension
+ else:
+ return self.labels
+
+ @property
+ @tc.returns(tc.Optional(tuple))
+ def labels(self):
+ """Returns the tuple containing coordinate labels, else None."""
+ return self._labels
+
+ def index(self, value):
+ """Returns the integer position of the given tick label."""
+ if self._index is None:
+ raise ValueError('Axis does not have tick labels')
+ return self._index[value]
+
+
+# tc class for anything that can be coerced into an Axis
+# pylint: disable=invalid-name
+AxisLike = tc.Union(Axis, tc.Tuple(string_types, AxisValue))
+# pylint: enable=invalid-name
+
+
+@tc.returns(Axis)
+@tc.accepts(AxisLike)
+def as_axis(axis_data):
+ """Convert an AxisLike object into an Axis.
+
+ Args:
+ axis_data: Axis object or tuple (axis_name, axis_value) describing an axis.
+
+ Returns:
+ Axis object. This may be the original object if axis_data is an Axis.
+ """
+ if isinstance(axis_data, Axis):
+ axis = axis_data
+ else:
+ axis = Axis(*axis_data)
+ return axis
+
+
+class Axes(collections.Mapping):
+ """Axis names and indices for a tensor.
+
+ It is an ordered mapping, with keys given by axis name and values given
+ by Axis objets. Duplicate axis names are not allowed.
+ """
+
+ @tc.accepts(object, tc.List(AxisLike))
+ def __init__(self, axes):
+ """Construct an Axes.
+
+ Args:
+ axes: A list of Axis objects or (axis_name, axis_value) tuples.
+
+ Raises:
+ ValueError: If the user provides empty or duplicate axis names.
+ """
+ self._axes = collections.OrderedDict()
+
+ for axis_data in axes:
+ axis = as_axis(axis_data)
+
+ name = axis.name
+ if name in self._axes:
+ raise ValueError('Duplicate axis name: %s' % name)
+
+ self._axes[name] = axis
+
+ def __iter__(self):
+ return iter(self._axes)
+
+ @tc.returns(string_types)
+ def __repr__(self):
+ # Axes([('x', Dimension(2)),
+ # ('y', ['a', 'b', 'c']),
+ # ('z', Dimension(4))])
+ cls_name = type(self).__name__
+ values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()]
+ values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values)
+ return '%s([%s])' % (cls_name, values_repr)
+
+ @tc.returns(Axis)
+ @tc.accepts(object, string_types)
+ def __getitem__(self, name):
+ return self._axes[name]
+
+ @tc.returns(bool)
+ def __contains__(self, name):
+ return name in self._axes
+
+ @tc.returns(int)
+ def __len__(self):
+ return len(self._axes)
+
+ def __hash__(self):
+ return hash(tuple(self.items()))
+
+ @tc.accepts(object, string_types)
+ def remove(self, axis_name):
+ """Creates a new Axes object without the given axis."""
+ if axis_name not in self:
+ raise KeyError(axis_name)
+ remaining_axes = [axis for axis in self.values() if axis.name != axis_name]
+ return Axes(remaining_axes)
+
+
+class LabeledTensor(object):
+ """A tensor with annotated axes.
+
+ It has the following invariants:
+ 1) The dimensionality of the tensor is equal to the number of elements
+ in axes.
+ 2) The number of coordinate values in the ith dimension is equal to the
+ size of the tensor in the ith dimension.
+
+ Attributes:
+ tensor: tf.Tensor containing the data.
+ axes: lt.Axes containing axis names and coordinate labels.
+ """
+
+ @tc.accepts(object, ops.Output,
+ tc.Union(Axes, tc.Collection(tc.Union(string_types, AxisLike))))
+ def __init__(self, tensor, axes):
+ """Construct a LabeledTenor.
+
+ Args:
+ tensor: The underlying tensor containing the data.
+ axes: An Axes object, or a collection of strings, Axis objects or tuples
+ of (name, value) pairs indicating the axes.
+
+ Raises:
+ ValueError: If the provided axes do not satisfy the class invariants.
+ """
+ self._tensor = tensor
+ shape = tensor.get_shape()
+
+ if isinstance(axes, Axes):
+ unvalidated_axes = axes
+ else:
+ mutable_axes = []
+
+ for position, axis_like in enumerate(axes):
+ if isinstance(axis_like, string_types):
+ # The coordinates for this axes are unlabeled.
+ # Infer the size of the axis.
+ value = shape[position]
+ axis_like = (axis_like, value)
+
+ mutable_axes.append(axis_like)
+
+ # Construct the Axis object, which will additionally validate the contents
+ # of the object.
+ unvalidated_axes = Axes(mutable_axes)
+
+ # Check our invariants.
+
+ # First, the rank of the tensor must be equal to the number of axes.
+ if len(shape) != len(unvalidated_axes):
+ raise ValueError('Tensor rank was not equal to the number of axes: %r, %r'
+ % (shape, unvalidated_axes))
+
+ # Second, the size of each tensor dimension must match the size of the
+ # corresponding indices.
+ for (d, axis) in zip(shape, unvalidated_axes.values()):
+ if d != axis.size:
+ raise ValueError(
+ 'Provided axis size %d does not match tensor dimension size %d' %
+ (axis.size, d))
+
+ self._axes = unvalidated_axes
+
+ def __repr__(self):
+ # <LabeledTensor 'foo' shape=(2, 3, 4) dtype=float32
+ # axes=[('x', Dimension(2)),
+ # ('y', ('a', 'b', 'c'),
+ # ('z', Dimension(4))]>
+ axes = ["('%s', %r)" % (v.name, v.value) for v in self.axes.values()]
+ axes_repr = (',\n' + ' ' * len(' axes=[')).join(axes)
+ return ("<%s '%s' shape=%s dtype=%s\n axes=[%s]>" %
+ (type(self).__name__, self.tensor.name, self.tensor.get_shape(),
+ self.tensor.dtype.name, axes_repr))
+
+ @property
+ def tensor(self):
+ return self._tensor
+
+ def _as_graph_element(self):
+ """Support tf.Graph.as_graph_element on LabeledTensor objects.
+
+ This allows operations such as tf.name_scope to take labeled tensors.
+
+ Returns:
+ self.tensor
+ """
+ return self.tensor
+
+ @property
+ def axes(self):
+ return self._axes
+
+ # properties/methods directly borrowed from tf.Tensor:
+
+ @property
+ def dtype(self):
+ return self._tensor.dtype
+
+ @property
+ def name(self):
+ return self._tensor.name
+
+ def get_shape(self):
+ """Returns the TensorShape that represents the shape of this tensor.
+
+ See tf.Tensor.get_shape().
+
+ Returns:
+ A TensorShape representing the shape of this tensor.
+ """
+ return self._tensor.get_shape()
+
+ # TODO(shoyer): consider how/if to implement .eval(). Maybe it should return
+ # an xarray.DataArray?
+
+ def __getitem__(self, key):
+ # This should work exactly like tf.Tensor.__getitem__, except it preserves
+ # labels.
+ if not isinstance(key, tuple):
+ key = (key,)
+ if len(key) != len(self.axes):
+ raise ValueError('indexer %r must have the same length as the Tensor '
+ 'rank (%r)' % (key, len(self.axes)))
+ selection = {a: k for a, k in zip(self.axes.keys(), key)}
+ return slice_function(self, selection)
+
+ # special methods for overloading arithmetic operations:
+
+ def __abs__(self):
+ return abs_function(self)
+
+ def __neg__(self):
+ return neg(self)
+
+ def __pos__(self):
+ return self
+
+ def __add__(self, other):
+ return add(self, other)
+
+ def __radd__(self, other):
+ return add(other, self)
+
+ def __sub__(self, other):
+ return sub(self, other)
+
+ def __rsub__(self, other):
+ return sub(other, self)
+
+ def __mul__(self, other):
+ return mul(self, other)
+
+ def __rmul__(self, other):
+ return mul(other, self)
+
+ def __truediv__(self, other):
+ return div(self, other)
+
+ __div__ = __truediv__
+
+ def __rtruediv__(self, other):
+ return div(other, self)
+
+ __rdiv__ = __rtruediv__
+
+ def __mod__(self, other):
+ return mod(self, other)
+
+ def __rmod__(self, other):
+ return mod(other, self)
+
+ def __pow__(self, other):
+ return pow_function(self, other)
+
+ def __rpow__(self, other):
+ return pow_function(other, self)
+
+ # logical operations:
+
+ def __invert__(self):
+ return logical_not(self)
+
+ def __and__(self, other):
+ return logical_and(self, other)
+
+ def __or__(self, other):
+ return logical_or(self, other)
+
+ def __xor__(self, other):
+ return logical_xor(self, other)
+
+ # boolean operations:
+
+ def __lt__(self, other):
+ return less(self, other)
+
+ def __le__(self, other):
+ return less_equal(self, other)
+
+ def __gt__(self, other):
+ return greater(self, other)
+
+ def __ge__(self, other):
+ return greater_equal(self, other)
+
+ def __eq__(self, other):
+ # for consistency with tf.Tensor
+ if not isinstance(other, LabeledTensor):
+ return False
+
+ return self.tensor == other.tensor and self.axes == other.axes
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return hash((self.tensor, self.axes))
+
+
+# typecheck type abbreviations:
+# abbreviations for third-party types with very long reprs
+tc.register_type_abbreviation(tensor_shape.Dimension, 'tensorflow.Dimension')
+tc.register_type_abbreviation(ops.Output, 'tensorflow.Output')
+tc.register_type_abbreviation(dtypes.DType, 'tensorflow.DType')
+# core LabeledTensor types
+tc.register_type_abbreviation(Axis, 'labeled_tensor.Axis')
+tc.register_type_abbreviation(Axes, 'labeled_tensor.Axes')
+tc.register_type_abbreviation(LabeledTensor, 'labeled_tensor.LabeledTensor')
+
+
+@tc.returns(ops.Output)
+@tc.accepts(LabeledTensor)
+def _convert_labeled_tensor_to_tensor(value, *args, **kwargs):
+ # call ops.convert_to_tensor to handle optional arguments appropriately
+ return ops.convert_to_tensor(value.tensor, *args, **kwargs)
+
+
+ops.register_tensor_conversion_function(
+ LabeledTensor, _convert_labeled_tensor_to_tensor)
+
+
+# tc class for anything that can be coerced into a LabeledTensor
+# pylint: disable=invalid-name
+LabeledTensorLike = tc.Union(LabeledTensor, ops.Output, np.ndarray, Scalar)
+# pylint: enable=invalid-name
+
+
+@tc.returns(LabeledTensor)
+@tc.accepts(LabeledTensorLike, object, tc.Optional(string_types))
+def convert_to_labeled_tensor(value, dtype=None, name=None):
+ """Converts the given `value` to a `LabeledTensor`.
+
+ This function accepts `LabeledTensor` objects, 0-dimensional `Tensor` objects
+ and numpy arrays, and Python scalars. Higher dimensional unlabeled tensors
+ must use the `LabeledTensor` constructor explicitly.
+
+ Args:
+ value: Object to convert.
+ dtype: Optional element type for the returned tensor. If missing, the type
+ is inferred from the type of value.
+ name: Optional name to use if a new Tensor is created.
+
+ Returns:
+ `value` converted into a `LabeledTensor` object.
+
+ Raises:
+ ValueError: If the output would have rank>0 but the input was not already a
+ `LabeledTensor`.
+ """
+ # TODO(shoyer): consider extending to accept xarray.DataArray as input.
+ if isinstance(value, LabeledTensor):
+ axes = value.axes.values()
+ value = value.tensor
+ else:
+ axes = []
+
+ # We call convert_to_tensor even for LabeledTensor input because it also
+ # checks to make sure the dtype argument is compatible.
+ tensor = ops.convert_to_tensor(value, dtype=dtype, name=name)
+ if len(tensor.get_shape()) != len(axes):
+ raise ValueError('cannot automatically convert unlabeled arrays or tensors '
+ 'with rank>0 into LabeledTensors: %r' % value)
+ return LabeledTensor(tensor, axes)
+
+
+@tc.returns(Axis)
+@tc.accepts(tc.Collection(Axis))
+def concat_axes(axes):
+ """Concatenate a list of Axes.
+
+ Args:
+ axes: A collection of Axis objects.
+
+ Returns:
+ The concatenation of the axes.
+ If all axes have labels, the result has the concatenation of the labels.
+ Else, the result has no labels, and its size is the sum of the sizes
+ of the axes.
+
+ Raises:
+ ValueError: If `others` is not a collection of Axes or if it is empty.
+ """
+ if not axes:
+ raise ValueError('axes must not be empty')
+ for a in axes:
+ if not isinstance(a, Axis):
+ raise ValueError('Expected an Axis, but got %r of type %r' % (a, type(a)))
+
+ names = set(a.name for a in axes)
+ if len(names) > 1:
+ raise ValueError('axes do not all have the same name: %r' % names)
+ name, = names
+
+ all_have_labels = all(a.labels is not None for a in axes)
+ any_has_unknown_size = any(a.size is None for a in axes)
+
+ if all_have_labels:
+ value = tuple(label for a in axes for label in a.labels)
+ elif any_has_unknown_size:
+ value = None
+ else:
+ value = sum(len(a) for a in axes)
+ return Axis(name, value)
+
+
+@tc.returns(LabeledTensor)
+@tc.accepts(LabeledTensorLike, tc.Optional(string_types))
+def identity(labeled_tensor, name=None):
+ """The identity op.
+
+ See tf.identity.
+
+ Args:
+ labeled_tensor: The input tensor.
+ name: Optional op name.
+
+ Returns:
+ The tensor.
+ """
+ with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope:
+ labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
+ return LabeledTensor(
+ array_ops.identity(labeled_tensor.tensor, name=scope),
+ labeled_tensor.axes)
+
+
+# We don't call this slice because that shadows a built-in. Instead, we alias
+# this to lt.slice in __init__.py.
+@tc.returns(LabeledTensor)
+@tc.accepts(LabeledTensorLike, tc.Mapping(string_types, tc.Union(int, slice)),
+ tc.Optional(string_types))
+def slice_function(labeled_tensor, selection, name=None):
+ """Slice out a subset of the tensor.
+
+ This is an analogue of tf.slice.
+ For example:
+ >>> tensor = tf.reshape(tf.range(0, 6), [3, 2])
+ >>> labeled_tensor = lt.LabeledTensor(tensor, ['a', ('b', ['foo', 'bar'])])
+ >>> lt.slice(labeled_tensor, {'a': slice(0, 2), 'b': 1})
+ <LabeledTensor 'lt_slice:...' shape=(2,) dtype=int32
+ axes=[('a', Dimension(2))]>
+
+ Args:
+ labeled_tensor: The input tensor.
+ selection: A dictionary of type str -> Union(int, slice of int) mapping
+ axis names to sub-selections.
+ name: Optional op name.
+
+ Returns:
+ The slice as a `LabeledTensor`.
+ """
+ with ops.name_scope(name, 'lt_slice', [labeled_tensor]) as scope:
+ labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
+
+ slices = []
+
+ for axis_name in labeled_tensor.axes:
+ if axis_name not in selection:
+ # We're not sub-selecting this axis, so use the full slice.
+ slices.append(slice(None))
+ else:
+ slices.append(selection[axis_name])
+
+ sliced_tensor = labeled_tensor.tensor[tuple(slices)]
+
+ sliced_axes = []
+ for axis, s in zip(labeled_tensor.axes.values(), slices):
+ # We sub-select this axis's index with the slice s.
+
+ # `s` is either an int or a proper slice.
+ if isinstance(s, slice):
+ if axis.labels is None:
+ # We're not tracking coordinate names for this axis.
+ sliced_axes.append(axis.name)
+ else:
+ sliced_axes.append((axis.name, axis.labels[s]))
+ else:
+ # If the slice is an int this dimension now has size 1, so we remove it.
+ assert isinstance(s, int)
+
+ return LabeledTensor(array_ops.identity(sliced_tensor, name=scope),
+ sliced_axes)
+
+
+@tc.returns(LabeledTensor)
+@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)),
+ tc.Optional(string_types))
+def transpose(labeled_tensor, axis_order=None, name=None):
+ """Permute a tensor's axes.
+
+ See tf.transpose.
+
+ Args:
+ labeled_tensor: The input tensor.
+ axis_order: Optional desired axis order, as a list of names. By default, the
+ order of axes is reversed.
+ name: Optional op name.
+
+ Returns:
+ The permuted tensor.
+
+ Raises:
+ ValueError: If axis_order isn't a permutation of the existing axes.
+ """
+ with ops.name_scope(name, 'lt_transpose', [labeled_tensor]) as scope:
+ labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
+
+ original_order = list(labeled_tensor.axes.keys())
+ if axis_order is None:
+ axis_order = list(reversed(original_order))
+ elif sorted(axis_order) != sorted(original_order):
+ raise ValueError(
+ 'The new axis order must have the same names as the original axes, '
+ 'but the new order is %r while the original order is %r' %
+ (axis_order, original_order))
+
+ axis_names = list(labeled_tensor.axes.keys())
+ permutation = [axis_names.index(n) for n in axis_order]
+
+ # Note: TensorFlow doesn't copy data for the identity tranpose.
+ transpose_tensor = array_ops.transpose(labeled_tensor.tensor,
+ permutation,
+ name=scope)
+
+ permuted_axes = [labeled_tensor.axes[n] for n in axis_order]
+
+ return LabeledTensor(transpose_tensor, permuted_axes)
+
+
+@tc.returns(LabeledTensor)
+@tc.accepts(LabeledTensorLike, tc.Collection(tc.Union(string_types, tc.Tuple(
+ string_types, collections.Hashable))), tc.Optional(string_types))
+def expand_dims(labeled_tensor, axes, name=None):
+ """Insert dimensions of size 1.
+
+ See tf.expand_dims.
+
+ Args:
+ labeled_tensor: The input tensor.
+ axes: The desired axis names as strings or tuples of (name, label),
+ where `label` is the coordinate name for the new dimension `name`.
+ These must include the existing axis names, and the existing names must
+ appear in the same order in this list as they do in the input tensor.
+ name: Optional op name.
+
+ Returns:
+ A tensor with an axis for each axis in axes.
+ New axes are created with size 1 and do not have labeled coordinates.
+
+ Raises:
+ AxisOrderError: If axis names don't appear in the same order in axes
+ and the labeled tensor.
+ """
+ with ops.name_scope(name, 'lt_expand_dims', [labeled_tensor]) as scope:
+ labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
+
+ axis_names = [a if isinstance(a, string_types) else a[0] for a in axes]
+ check_axis_order(labeled_tensor, axis_names)
+
+ reshaped_axes = []
+ shape = []
+ for axis_spec in axes:
+ if axis_spec in labeled_tensor.axes:
+ axis = labeled_tensor.axes[axis_spec]
+ reshaped_axes.append(axis)
+ shape.append(-1 if axis.size is None else axis.size)
+ else:
+ if isinstance(axis_spec, string_types):
+ reshaped_axes.append((axis_spec, 1))
+ else:
+ (name, label) = axis_spec
+ reshaped_axes.append((name, (label,)))
+
+ shape.append(1)
+
+ reshaped_tensor = array_ops.reshape(labeled_tensor.tensor, shape,
+ name=scope)
+
+ return LabeledTensor(reshaped_tensor, reshaped_axes)
+
+# This should only be added to a graph collection once.
+_AXIS_ORDER_KEY = ('__axis_order',)
+
+
+@tc.returns(tc.Optional(tc.List(string_types)))
+def get_axis_order():
+ """Get the axis_order set by any containing axis_order_scope.
+
+ Returns:
+ List of strings giving an order to use for axis names, or None, if no axis
+ order is set.
+ """
+ # By storing axis_order in the graph, we can ensure that axis_order_scope is
+ # thread-safe.
+ axis_order_list = ops.get_collection(_AXIS_ORDER_KEY)
+ if axis_order_list:
+ axis_order, = axis_order_list
+ else:
+ axis_order = None
+ return axis_order
+
+
+@tc.accepts(tc.Optional(tc.List(string_types)))
+def _set_axis_order(axis_order):
+ axis_order_list = ops.get_collection_ref(_AXIS_ORDER_KEY)
+ if axis_order_list:
+ axis_order_list[0] = axis_order
+ else:
+ axis_order_list.append(axis_order)
+
+
+@contextlib.contextmanager
+@tc.accepts(tc.Optional(tc.List(string_types)))
+def axis_order_scope(axis_order=None):
+ """Set axis order for the result of broadcasting operations within a scope.
+
+ This allows you to ensure that tensors resulting from arithmetic have a
+ predictable axis order.
+
+ Example usage:
+
+ with lt.axis_order_scope(['x', 'y', 'z']):
+ # result is guranteed to have the correct axis order
+ result = w + b
+
+ You can nest scopes, in which case only the inner-most scope applies, e.g.,
+
+ with lt.axis_order(['x', 'y', 'z']):
+ with lt.axis_order():
+ result = w + b # uses the default (left-most) axis ordering
+
+ Args:
+ axis_order: optional list of strings providing axis names. By default,
+ creates a scope without axis order.
+
+ Yields:
+ The provided axis_order or `None`.
+ """
+ original_axis_order = get_axis_order()
+ _set_axis_order(axis_order)
+ try:
+ yield axis_order
+ finally:
+ _set_axis_order(original_axis_order)
+
+
+@tc.returns(tc.List(string_types))
+def _get_valid_axis_order():
+ axis_order = get_axis_order()
+ if axis_order is None:
+ raise AxisOrderError('an explicit axis order must be provided with the '
+ 'axis_order argument or by using an axis_order_scope')
+ return axis_order
+
+
+class AxisOrderError(ValueError):
+ """Error class for cases where there is no valid axis order."""
+
+
+# TODO(shoyer): should this function accept a list of labeled tensors instead?
+@tc.returns(type(None))
+@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)))
+def check_axis_order(labeled_tensor, axis_order=None):
+ """Verify that the given tensor has a consistent axis order.
+
+ Args:
+ labeled_tensor: The input tensor. All axes on this tensor must appear in
+ axis_order.
+ axis_order: Optional desired axis order, as a list of names. If not
+ provided, defaults to the current axis_order_scope (if set).
+
+ Raises:
+ AxisOrderError: If the axis_order is unavailable, inconsistent or does not
+ include all existing axes.
+ """
+ labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
+
+ if axis_order is None:
+ axis_order = _get_valid_axis_order()
+
+ relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes]
+
+ if len(relevant_axis_order) < len(labeled_tensor.axes):
+ raise AxisOrderError(
+ 'not all axis names appear in the required axis order %r: %r' %
+ (axis_order, labeled_tensor))
+
+ if relevant_axis_order != list(labeled_tensor.axes):
+ raise AxisOrderError(
+ 'axes on a labeled tensor do not appear in the same order as the '
+ 'required axis order %r: %r' % (axis_order, labeled_tensor))
+
+
+@tc.returns(LabeledTensor)
+@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)),
+ tc.Optional(string_types))
+def impose_axis_order(labeled_tensor, axis_order=None, name=None):
+ """Impose desired axis order on a labeled tensor.
+
+ Args:
+ labeled_tensor: The input tensor.
+ axis_order: Optional desired axis order, as a list of names. If not
+ provided, defaults to the current axis_order_scope (if set).
+ name: Optional op name.
+
+ Returns:
+ Labeled tensor with possibly transposed axes.
+
+ Raises:
+ AxisOrderError: If no axis_order is provided or axis_order does not contain
+ all axes on the input tensor.
+ """
+ with ops.name_scope(name, 'lt_impose_axis_order', [labeled_tensor]) as scope:
+ labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
+
+ if axis_order is None:
+ axis_order = _get_valid_axis_order()
+
+ relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes]
+
+ return transpose(labeled_tensor, relevant_axis_order, name=scope)
+
+
+@tc.returns(tc.Optional(list))
+@tc.accepts(list, list)
+def _find_consistent_ordering(a, b):
+ """Find the left-most consistent ordering between two lists of unique items.
+
+ A consistent ordering combines all elements in both a and b while keeping all
+ elements in their original order in both inputs. The left-most consistent
+ ordering orders elements from `a` not found in `b` before elements in `b` not
+ found in `a`.
+
+ For example, given ['x', 'z'] and ['y', 'z'], both ['x', 'y', 'z'] and ['y',
+ 'x', 'z'] are consistent orderings because each of the inputs appears in
+ each consistent ordering in the same order, and ['x', 'y', 'z'] is the
+ left-most, because 'x' appears only in `a` and 'y' appears only in `b`. In
+ contrast, there is no consistent ordering between ['x', 'y'] and ['y', 'x'].
+
+ Args:
+ a: list with unique elements.
+ b: list with unique elements.
+
+ Returns:
+ List containing all elements in either a or b, or None, if no consistent
+ ordering exists.
+ """
+ a_set = set(a)
+ b_set = set(b)
+ i = 0
+ j = 0
+ ordering = []
+ while i < len(a) and j < len(b):
+ if a[i] not in b_set:
+ ordering.append(a[i])
+ i += 1
+ elif b[j] not in a_set:
+ ordering.append(b[j])
+ j += 1
+ elif a[i] == b[j]:
+ ordering.append(a[i])
+ i += 1
+ j += 1
+ else:
+ return None
+
+ ordering.extend(a[i:])
+ ordering.extend(b[j:])
+
+ return ordering
+
+
+@tc.returns(LabeledTensor, LabeledTensor, Axes)
+@tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
+def align(labeled_tensor_0, labeled_tensor_1, name=None):
+ """Align the axes of two tensors so they may be broadcast to each other.
+
+ Axes are ordered by the current axis order scope, if present, or by the left-
+ most consistent ordering. An exception is raised if it is impossible to align
+ the tensors without a transpose (align never copies the input data).
+
+ Example usage:
+
+ >>> a = lt.LabeledTensor(tf.ones((2, 4)), ['x', 'z'])
+ >>> b = lt.LabeledTensor(tf.ones((3, 4)), ['y', 'z'])
+ >>> a2, b2, axes = lt.align(a, b)
+ >>> a2
+ <LabeledTensor 'lt_align_1/lt_align_1/0:...' shape=(2, 1, 4) dtype=float32
+ axes=[('x', Dimension(2)),
+ ('y', Dimension(1)),
+ ('z', Dimension(4))]>
+ >>> b2
+ <LabeledTensor 'lt_align_1/lt_align_1/1:...' shape=(1, 3, 4) dtype=float32
+ axes=[('x', Dimension(1)),
+ ('y', Dimension(3)),
+ ('z', Dimension(4))]>
+ >>> axes
+ Axes([('x', Dimension(2)),
+ ('y', Dimension(3)),
+ ('z', Dimension(4))])
+
+ Args:
+ labeled_tensor_0: An input tensor.
+ labeled_tensor_1: An input tensor.
+ name: Optional op name.
+
+ Returns:
+ The aligned tensors and the axes the resulting tensor would have if the two
+ aligned tensors were broadcast to each other. The aligned tensors have the
+ same rank but not necessarily the same shape, with axes in the same order.
+
+ Raises:
+ ValueError: If axes with the same name on the inputs are not equal.
+ AxisOrderError: If there is no way to reshape the input tensors into the
+ output without a transpose.
+ """
+ with ops.name_scope(name, 'lt_align',
+ [labeled_tensor_0, labeled_tensor_1]) as scope:
+
+ labeled_tensor_0 = convert_to_labeled_tensor(labeled_tensor_0)
+ labeled_tensor_1 = convert_to_labeled_tensor(labeled_tensor_1)
+
+ axes_0 = labeled_tensor_0.axes
+ axes_1 = labeled_tensor_1.axes
+ for axis_name in axes_0:
+ if axis_name in axes_1:
+ if axes_0[axis_name] != axes_1[axis_name]:
+ raise ValueError('Mismatched %r axis on input tensors: %r and %r' %
+ (axis_name, axes_0[axis_name], axes_1[axis_name]))
+
+ axis_scope_order = get_axis_order()
+ if axis_scope_order is not None:
+ # we are in an axis_order_scope
+ axis_names_set = set(axes_0) | set(axes_1)
+ new_axis_names = [a for a in axis_scope_order if a in axis_names_set]
+
+ check_axis_order(labeled_tensor_0, axis_scope_order)
+ check_axis_order(labeled_tensor_1, axis_scope_order)
+
+ else:
+ # attempt to find a consistent ordering
+ new_axis_names = _find_consistent_ordering(list(axes_0), list(axes_1))
+ if new_axis_names is None:
+ raise AxisOrderError(
+ 'No consistent axis order allows for aligning tensors with axis '
+ 'orders %r and %r without copying data. Use transpose or '
+ 'impose_axis_order to reorder axes on one of more of the inputs.' %
+ (axes_0.keys(), axes_1.keys()))
+
+ labeled_tensor_0 = expand_dims(labeled_tensor_0,
+ new_axis_names,
+ name=scope + '0')
+ labeled_tensor_1 = expand_dims(labeled_tensor_1,
+ new_axis_names,
+ name=scope + '1')
+
+ broadcast_axes = []
+ for axis_name in new_axis_names:
+ if axis_name in axes_0:
+ broadcast_axes.append(axes_0[axis_name])
+ else:
+ broadcast_axes.append(axes_1[axis_name])
+
+ return labeled_tensor_0, labeled_tensor_1, Axes(broadcast_axes)
+
+
+@tc.returns(types.FunctionType)
+@tc.accepts(string_types, collections.Callable)
+def define_unary_op(op_name, elementwise_function):
+ """Define a unary operation for labeled tensors.
+
+ Args:
+ op_name: string name of the TensorFlow op.
+ elementwise_function: function to call to evaluate the op on a single
+ tf.Tensor object. This function must accept two arguments: a tf.Tensor
+ object, and an optional `name`.
+
+ Returns:
+ Function defining the given op that acts on LabeledTensors.
+ """
+
+ default_name = 'lt_%s' % op_name
+
+ @tc.returns(LabeledTensor)
+ @tc.accepts(LabeledTensorLike, tc.Optional(string_types))
+ def op(labeled_tensor, name=None):
+ """LabeledTensor version of `tf.{op_name}`.
+
+ See `tf.{op_name}` for full details.
+
+ Args:
+ labeled_tensor: Input tensor.
+ name: Optional op name.
+
+ Returns:
+ A LabeledTensor with result of applying `tf.{op_name}` elementwise.
+ """
+ with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
+ labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
+ result_tensor = elementwise_function(labeled_tensor.tensor, name=scope)
+ return LabeledTensor(result_tensor, labeled_tensor.axes)
+
+ op.__doc__ = op.__doc__.format(op_name=op_name)
+ op.__name__ = op_name
+
+ return op
+
+
+abs_function = define_unary_op('abs', math_ops.abs)
+neg = define_unary_op('neg', math_ops.neg)
+sign = define_unary_op('sign', math_ops.sign)
+inv = define_unary_op('inv', math_ops.inv)
+square = define_unary_op('square', math_ops.square)
+round_function = define_unary_op('round', math_ops.round)
+sqrt = define_unary_op('sqrt', math_ops.sqrt)
+rsqrt = define_unary_op('rsqrt', math_ops.rsqrt)
+exp = define_unary_op('exp', math_ops.exp)
+log = define_unary_op('log', math_ops.log)
+ceil = define_unary_op('ceil', math_ops.ceil)
+floor = define_unary_op('floor', math_ops.floor)
+cos = define_unary_op('cos', math_ops.cos)
+sin = define_unary_op('sin', math_ops.sin)
+tan = define_unary_op('tan', math_ops.tan)
+acos = define_unary_op('acos', math_ops.acos)
+asin = define_unary_op('asin', math_ops.asin)
+atan = define_unary_op('atan', math_ops.atan)
+lgamma = define_unary_op('lgamma', math_ops.lgamma)
+digamma = define_unary_op('digamma', math_ops.digamma)
+erf = define_unary_op('erf', math_ops.erf)
+erfc = define_unary_op('erfc', math_ops.erfc)
+logical_not = define_unary_op('logical_not', math_ops.logical_not)
+tanh = define_unary_op('tanh', math_ops.tanh)
+sigmoid = define_unary_op('sigmoid', math_ops.sigmoid)
+
+
+@tc.returns(types.FunctionType)
+@tc.accepts(string_types, collections.Callable)
+def define_binary_op(op_name, elementwise_function):
+ """Define a binary operation that broadcasts labeled tensors.
+
+ Args:
+ op_name: string name of the TensorFlow op.
+ elementwise_function: function to call to evaluate the op on tf.Tensor
+ objects. This function must accept three arguments: two tf.Tensor objects,
+ and an optional `name`.
+
+ Returns:
+ Function defining the given op that acts on LabeledTensors.
+ """
+
+ default_name = 'lt_%s' % op_name
+
+ @tc.returns(LabeledTensor)
+ @tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
+ def op(labeled_tensor_0, labeled_tensor_1, name=None):
+ """LabeledTensor version of `tf.{op_name}` with label based alignment.
+
+ See `tf.{op_name}` for full details.
+
+ Args:
+ labeled_tensor_0: Input tensor.
+ labeled_tensor_1: Input tensor.
+ name: Optional op name.
+
+ Returns:
+ A LabeledTensor with result of applying `tf.{op_name}` elementwise.
+ """
+ with ops.name_scope(name, default_name,
+ [labeled_tensor_0, labeled_tensor_1]) as scope:
+
+ align_0, align_1, broadcast_axes = align(labeled_tensor_0,
+ labeled_tensor_1)
+
+ tensor = elementwise_function(align_0.tensor, align_1.tensor, name=scope)
+
+ return LabeledTensor(tensor, broadcast_axes)
+
+ op.__doc__ = op.__doc__.format(op_name=op_name)
+ op.__name__ = op_name
+
+ return op
+
+
+add = define_binary_op('add', math_ops.add)
+sub = define_binary_op('sub', math_ops.sub)
+mul = define_binary_op('mul', math_ops.mul)
+div = define_binary_op('div', math_ops.div)
+mod = define_binary_op('mod', math_ops.mod)
+pow_function = define_binary_op('pow', math_ops.pow)
+
+equal = define_binary_op('equal', math_ops.equal)
+greater = define_binary_op('greater', math_ops.greater)
+greater_equal = define_binary_op('greater_equal', math_ops.greater_equal)
+not_equal = define_binary_op('not_equal', math_ops.not_equal)
+less = define_binary_op('less', math_ops.less)
+less_equal = define_binary_op('less_equal', math_ops.less_equal)
+logical_and = define_binary_op('logical_and', math_ops.logical_and)
+logical_or = define_binary_op('logical_or', math_ops.logical_or)
+logical_xor = define_binary_op('logical_xor', math_ops.logical_xor)
+
+maximum = define_binary_op('maximum', math_ops.maximum)
+minimum = define_binary_op('minimum', math_ops.minimum)
+squared_difference = define_binary_op(
+ 'squared_difference', math_ops.squared_difference)
+igamma = define_binary_op('igamma', math_ops.igamma)
+igammac = define_binary_op('igammac', math_ops.igammac)
+zeta = define_binary_op('zeta', math_ops.zeta)
+polygamma = define_binary_op('polygamma', math_ops.polygamma)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
new file mode 100644
index 0000000000..5710dc34e8
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
@@ -0,0 +1,842 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import operator
+import re
+import textwrap
+
+import numpy as np
+from six.moves import range # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.contrib.labeled_tensor.python.ops import test_util
+
+
+class AxisTest(tf.test.TestCase):
+
+ def setUp(self):
+ d_7 = tf.Dimension(7)
+ p_rgb = ['red', 'green', 'blue']
+
+ self.i_7 = core.Axis('7', d_7)
+ self.i_7p = core.Axis('7prime', d_7)
+ self.i_rgb = core.Axis('rgb', p_rgb)
+ self.i_range = core.Axis('range', range(7))
+ self.i_unknown = core.Axis('unknown', None)
+
+ def test_equality(self):
+
+ axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
+ for i, axis_0 in enumerate(axes):
+ for j, axis_1 in enumerate(axes):
+ if i == j:
+ self.assertEqual(axis_0, axis_1)
+ else:
+ self.assertNotEqual(axis_0, axis_1)
+
+ def test_axis_value(self):
+ self.assertEqual(self.i_7.value, tf.Dimension(7))
+ self.assertTrue(self.i_range.value == tuple(range(7)))
+
+ def test_axis_input(self):
+ axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
+ for axis in axes:
+ self.assertEqual(axis, core.Axis(axis.name, axis.value))
+
+ def test_axis_value_input(self):
+ axis = self.i_range
+ for value in [range(7), list(range(7)), np.arange(7)]:
+ self.assertEqual(axis, core.Axis(axis.name, value))
+
+ def test_size(self):
+ self.assertEqual(len(self.i_7), 7)
+ self.assertEqual(len(self.i_rgb), 3)
+ self.assertEqual(len(self.i_range), 7)
+ self.assertEqual(self.i_unknown.size, None)
+
+ def test_concat_single(self):
+ red = core.Axis('rgb', ['red'])
+
+ self.assertEqual(core.concat_axes([red]), red)
+
+ def test_concat_many(self):
+ red = core.Axis('rgb', ['red'])
+ green = core.Axis('rgb', ['green'])
+ blue = core.Axis('rgb', ['blue'])
+ red_green_blue = core.Axis('rgb', ['red', 'green', 'blue'])
+
+ self.assertEqual(core.concat_axes([red, green, blue]), red_green_blue)
+
+ def test_concat_different_names(self):
+ red = core.Axis('red', ['red'])
+ green = core.Axis('green', ['red'])
+ with self.assertRaises(ValueError):
+ core.concat_axes([red, green])
+
+ def test_concat_unknown(self):
+ red = core.Axis('rgb', None)
+ green = core.Axis('rgb', None)
+ self.assertEqual(core.concat_axes([red, green]), red)
+
+ def test_repr(self):
+ self.assertEqual("Axis('7', Dimension(7))", repr(self.i_7))
+
+ def test_invalid_input(self):
+ with self.assertRaises(TypeError):
+ core.Axis('foo', [{}])
+ with self.assertRaises(ValueError):
+ core.Axis('foo', [1, 2, 3, 1])
+ red = core.Axis('foo', ['red'])
+ with self.assertRaises(tc.Error):
+ core.concat_axes([red, 1])
+
+ def test_as_axis(self):
+ self.assertEqual(self.i_7, core.as_axis(('7', 7)))
+ self.assertEqual(self.i_7, core.as_axis(self.i_7))
+
+
+class AxesTest(tf.test.TestCase):
+
+ def setUp(self):
+ d_7 = tf.Dimension(7)
+ d_8 = tf.Dimension(8)
+ p_rgb = ['red', 'green', 'blue']
+ p_range = range(7)
+
+ self.i_8 = core.Axis('8', d_8)
+
+ self.a0 = core.Axes([('d7', d_7)])
+ self.a1 = core.Axes([('d7', d_7)])
+ self.a2 = core.Axes([('d7', d_7), ('rgb', p_rgb)])
+ self.a3 = core.Axes([('8', d_8), ('range', p_range)])
+
+ def test_equality(self):
+ self.assertEqual(self.a0, self.a0)
+ self.assertEqual(self.a0, self.a1)
+ self.assertNotEqual(self.a0, self.a2)
+
+ def test_repr(self):
+ self.assertEqual("Axes([('d7', Dimension(7))])", repr(self.a0))
+
+ def test_remove(self):
+ a = self.a3.remove('range')
+ self.assertEqual(a, core.Axes([self.i_8]))
+ with self.assertRaises(KeyError):
+ self.a3.remove('foobar')
+
+ def test_typecheck_error_message(self):
+ pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., '
+ 'Union(Union(numpy.ndarray, %s, list, tuple), '
+ 'Optional(Union(tensorflow.Dimension, int))))))' %
+ range.__name__)
+ regexp = re.escape(pattern).replace(re.escape('...'), '.*')
+ with self.assertRaisesRegexp(tc.Error, 'allowed type ' + regexp):
+ core.Axes(None)
+
+
+class LabeledTensorTest(test_util.Base):
+
+ def setUp(self):
+ tensor = tf.ones([7, 3, 8, 1])
+ a0 = ('x', range(7))
+ a1 = ('channel', ['red', 'green', 'blue'])
+ a2 = ('y', 8)
+ a3 = ('z', tf.Dimension(1))
+
+ self.lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
+
+ def test_repr(self):
+ pattern = textwrap.dedent("""\
+ <LabeledTensor '...' shape=(7, 3, 8, 1) dtype=float32
+ axes=[('x', ...),
+ ('channel', ...),
+ ('y', Dimension(8)),
+ ('z', Dimension(1))]>""")
+ regexp = re.escape(pattern).replace(re.escape('...'), '.*')
+ self.assertRegexpMatches(repr(self.lt), regexp)
+
+ def test_reuse_existing_axes(self):
+ alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes)
+ self.assertLabeledTensorsEqual(alt_lt, self.lt)
+
+ def test_reuse_existing_axis_objects(self):
+ alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes.values())
+ self.assertLabeledTensorsEqual(alt_lt, self.lt)
+
+ def test_indexing_scalars(self):
+ actual = self.lt[:, :, :, 0]
+ expected = core.LabeledTensor(self.lt.tensor[:, :, :, 0],
+ list(self.lt.axes.values())[:-1])
+ self.assertLabeledTensorsEqual(actual, expected)
+
+ actual = self.lt[1, :, :, 0]
+ expected = core.LabeledTensor(self.lt.tensor[1, :, :, 0],
+ list(self.lt.axes.values())[1:-1])
+ self.assertLabeledTensorsEqual(actual, expected)
+
+ actual = self.lt[1, 2, :, 0]
+ expected = core.LabeledTensor(self.lt.tensor[1, 2, :, 0],
+ list(self.lt.axes.values())[2:-1])
+ self.assertLabeledTensorsEqual(actual, expected)
+
+ def test_indexing_1d(self):
+ lt_1d = self.lt[1, 2, :, 0]
+ actual = lt_1d[3]
+ expected = core.LabeledTensor(lt_1d.tensor[3], [])
+ self.assertLabeledTensorsEqual(actual, expected)
+
+ def test_indexing_slices(self):
+ actual = self.lt[:3, :, :, :]
+ axes = [('x', range(3))] + list(self.lt.axes.values())[1:]
+ expected = core.LabeledTensor(self.lt.tensor[:3, :, :, :], axes)
+ self.assertLabeledTensorsEqual(actual, expected)
+
+ def test_invalid_indexing(self):
+ with self.assertRaises(ValueError):
+ self.lt[0] # pylint: disable=pointless-statement
+ with self.assertRaises(ValueError):
+ self.lt[:, :, :, :, 0] # pylint: disable=pointless-statement
+
+ def test_unknown_size(self):
+ tensor = tf.placeholder(tf.string, [None])
+ actual = core.LabeledTensor(tensor, ['x'])
+ self.assertIsNone(actual.axes['x'].size)
+ self.assertIs(actual.axes['x'].value, tensor.get_shape()[0])
+
+ def test_eq(self):
+ self.assertEqual(self.lt, self.lt)
+ self.assertNotEqual(self.lt, self.lt.tensor)
+ self.assertNotEqual(self.lt.tensor, self.lt)
+
+ def test_hash(self):
+ lt1 = self.lt
+ lt2 = core.LabeledTensor(self.lt.tensor, self.lt.axes)
+ self.assertEqual(lt1, lt2)
+ self.assertEqual(hash(lt1), hash(lt2))
+
+ def test_name(self):
+ self.assertEqual(self.lt.name, self.lt.tensor.name)
+
+ def test_dtype(self):
+ self.assertEqual(self.lt.dtype, self.lt.tensor.dtype)
+
+ def test_get_shape(self):
+ self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape())
+
+ def test_convert_to_tensor(self):
+ expected = self.lt.tensor
+ actual = tf.convert_to_tensor(self.lt)
+ self.assertIs(expected, actual)
+
+
+class Base(test_util.Base):
+
+ def setUp(self):
+ self.x_size = 7
+ self.channel_size = 3
+ self.z_size = 4
+ self.probs_size = 11
+
+ tensor = tf.range(0, self.x_size * self.channel_size * self.z_size *
+ self.probs_size)
+ tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size,
+ self.probs_size])
+ a0 = ('x', range(self.x_size))
+ a1 = ('channel', ['red', 'green', 'blue'])
+ a2 = 'z'
+ a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
+
+ self.tensor = tensor
+ self.a0 = a0
+ self.a1 = a1
+ self.a2 = a2
+ self.a3 = a3
+ self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
+
+ self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0,
+ 'channel': 0})
+ self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3,
+ 'z': 0})
+
+
+class IdentityTest(Base):
+
+ def test_name(self):
+ identity_lt = core.identity(self.original_lt)
+ self.assertIn('lt_identity', identity_lt.name)
+
+
+class SliceFunctionTest(Base):
+
+ def test_name(self):
+ select_lt = core.slice_function(self.original_lt, {'channel': 1})
+ self.assertIn('lt_slice', select_lt.name)
+
+ def test_scalar(self):
+ select_lt = core.slice_function(self.original_lt, {'channel': 1})
+ golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2,
+ self.a3])
+
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_slice(self):
+ select_lt = core.slice_function(self.original_lt, {'channel': slice(0, 2)})
+
+ a1_sliced = ('channel', ['red', 'green'])
+ golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
+ [self.a0, a1_sliced, self.a2, self.a3])
+
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_slices(self):
+ select_lt = core.slice_function(self.original_lt, {'x': slice(1, 5),
+ 'channel': slice(1,
+ None)})
+
+ a0_sliced = ('x', range(1, 5))
+ a1_sliced = ('channel', ['green', 'blue'])
+ golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
+ [a0_sliced, a1_sliced, self.a2, self.a3])
+
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_slice_unlabeled(self):
+ select_lt = core.slice_function(self.original_lt, {'z': slice(1, 3)})
+
+ a2_sliced = 'z'
+ golden_lt = core.LabeledTensor(self.tensor[:, :, 1:3, :],
+ [self.a0, self.a1, a2_sliced, self.a3])
+
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_slice_unknown_shape(self):
+ lt = core.LabeledTensor(tf.placeholder(tf.float32, [None, 1]), ['x', 'y'])
+ sliced_lt = core.slice_function(lt, {'y': 0})
+ self.assertEqual(list(sliced_lt.axes.values()), [lt.axes['x']])
+
+
+class TransposeTest(Base):
+
+ def test_name(self):
+ transpose_lt = core.transpose(self.original_lt,
+ self.original_lt.axes.keys())
+ self.assertIn('lt_transpose', transpose_lt.name)
+
+ def test_identity(self):
+ transpose_lt = core.transpose(self.original_lt,
+ self.original_lt.axes.keys())
+ golden_lt = self.original_lt
+
+ self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
+
+ def test(self):
+ transpose_lt = core.transpose(self.original_lt, ['z', 'channel', 'x',
+ 'probs'])
+ golden_lt = core.LabeledTensor(
+ tf.transpose(self.tensor, [2, 1, 0, 3]),
+ [self.a2, self.a1, self.a0, self.a3])
+
+ self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
+
+ def test_default_axis_order(self):
+ transpose_lt = core.transpose(self.original_lt)
+ golden_lt = core.LabeledTensor(
+ tf.transpose(self.tensor, [3, 2, 1, 0]),
+ list(reversed(list(self.original_lt.axes.values()))))
+
+ self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ core.transpose(self.original_lt, ['channel', 'x', 'probs'])
+ with self.assertRaises(ValueError):
+ core.transpose(self.original_lt, ['z', 'foo', 'x', 'probs'])
+
+
+class ExpandDimsTest(Base):
+
+ def test_name(self):
+ expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
+ self.assertIn('lt_expand', expand_lt.name)
+
+ def test_identity(self):
+ expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
+ golden_lt = self.original_lt
+
+ self.assertLabeledTensorsEqual(expand_lt, golden_lt)
+
+ def test(self):
+ expand_lt = core.expand_dims(self.original_lt, ['foo', 'x', 'bar',
+ 'channel', 'z', 'probs',
+ 'grok'])
+ golden_lt = core.LabeledTensor(
+ tf.reshape(self.tensor, [1, self.x_size, 1, self.channel_size,
+ self.z_size, self.probs_size, 1]),
+ ['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok'])
+
+ self.assertLabeledTensorsEqual(expand_lt, golden_lt)
+
+ def test_label(self):
+ expand_lt = core.expand_dims(self.original_lt, ['x',
+ 'channel',
+ ('foo', 'bar'),
+ 'z',
+ 'probs',])
+ golden_lt = core.LabeledTensor(
+ tf.reshape(self.tensor, [self.x_size, self.channel_size, 1, self.z_size,
+ self.probs_size]),
+ [self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3])
+
+ self.assertLabeledTensorsEqual(expand_lt, golden_lt)
+
+ def test_unknown_dimension(self):
+ orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x'])
+ expand_lt = core.expand_dims(orig_lt, ['x', 'y'])
+ self.assertEqual(expand_lt.axes, core.Axes([('x', None), ('y', 1)]))
+
+ def test_invalid_input(self):
+ with self.assertRaises(core.AxisOrderError):
+ core.expand_dims(self.original_lt, ['foo', 'not_x', 'bar', 'channel', 'z',
+ 'probs', 'grok'])
+ with self.assertRaises(core.AxisOrderError):
+ core.expand_dims(self.original_lt, ['foo', 'z', 'bar', 'channel', 'x',
+ 'probs', 'grok'])
+
+
+class AxisOrderScopeTest(Base):
+
+ def test(self):
+ xyz = ['x', 'y', 'z']
+ abc = ['a', 'b', 'c']
+
+ self.assertIsNone(core.get_axis_order())
+
+ with core.axis_order_scope(xyz):
+ self.assertEqual(core.get_axis_order(), xyz)
+
+ with core.axis_order_scope():
+ self.assertIsNone(core.get_axis_order())
+
+ with core.axis_order_scope(abc):
+ self.assertEqual(core.get_axis_order(), abc)
+
+ self.assertIsNone(core.get_axis_order())
+
+ self.assertEqual(core.get_axis_order(), xyz)
+
+ self.assertIsNone(core.get_axis_order())
+
+
+class CheckAxisOrderTest(Base):
+
+ def test_passes(self):
+ axis_order = ['w', 'x', 'y', 'z']
+
+ lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
+ core.check_axis_order(lt, axis_order)
+
+ lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[1:])
+ core.check_axis_order(lt, axis_order)
+
+ lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[:-1])
+ core.check_axis_order(lt, axis_order)
+
+ def test_invalid(self):
+ axis_order = ['w', 'x', 'y', 'z']
+ lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
+ with self.assertRaises(core.AxisOrderError):
+ core.check_axis_order(lt)
+ with self.assertRaises(core.AxisOrderError):
+ core.check_axis_order(lt, axis_order[:-1])
+ with self.assertRaises(core.AxisOrderError):
+ core.check_axis_order(lt, axis_order[::-1])
+
+ def test_scope(self):
+ axis_order = ['w', 'x', 'y', 'z']
+ lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
+ with core.axis_order_scope(axis_order):
+ core.check_axis_order(lt)
+
+
+class ImposeAxisOrderTest(Base):
+
+ def test_identity(self):
+ axis_order = ['w', 'x', 'y', 'z']
+ lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
+ actual = core.impose_axis_order(lt, axis_order)
+ self.assertLabeledTensorsEqual(lt, actual)
+
+ lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3])
+ actual = core.impose_axis_order(lt, axis_order)
+ self.assertLabeledTensorsEqual(lt, actual)
+
+ def test_reverse(self):
+ axis_order = ['w', 'x', 'y', 'z']
+
+ lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
+ actual = core.impose_axis_order(lt, axis_order[::-1])
+ expected = core.transpose(lt, axis_order[::-1])
+ self.assertLabeledTensorsEqual(expected, actual)
+
+ lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3])
+ actual = core.impose_axis_order(lt, axis_order[::-1])
+ expected = core.transpose(lt, ['y', 'x', 'w'])
+ self.assertLabeledTensorsEqual(expected, actual)
+
+ def test_scope(self):
+ axis_order = ['w', 'x', 'y', 'z']
+
+ lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
+ expected = core.transpose(lt, axis_order[::-1])
+ with core.axis_order_scope(axis_order[::-1]):
+ actual = core.impose_axis_order(lt)
+ self.assertLabeledTensorsEqual(expected, actual)
+
+ def test_invalid(self):
+ lt = core.LabeledTensor(tf.reshape(tf.range(2), (1, 2)), ['x', 'y'])
+ with self.assertRaises(ValueError):
+ core.impose_axis_order(lt)
+ with self.assertRaises(ValueError):
+ core.impose_axis_order(lt, ['x'])
+
+
+class FindConsistentOrderingTest(Base):
+
+ def test(self):
+ cases = [
+ ([], [], []),
+ (['x'], [], ['x']),
+ ([], ['x'], ['x']),
+ (['x'], ['x'], ['x']),
+ (['x'], ['y'], ['x', 'y']),
+ (['y'], ['x'], ['y', 'x']),
+ (['x', 'y'], ['x', 'y'], ['x', 'y']),
+ (['x', 'y'], ['y', 'x'], None),
+ (['x', 'y'], ['y', 'z'], ['x', 'y', 'z']),
+ (['x', 'z'], ['y', 'z'], ['x', 'y', 'z']),
+ (['x', 'y'], ['x', 'z'], ['x', 'y', 'z']),
+ (['w', 'x'], ['y', 'z'], ['w', 'x', 'y', 'z']),
+ (['x', 'y', 'z'], ['z', 'x'], None),
+ (['x', 'y', 'z'], ['x'], ['x', 'y', 'z']),
+ ([], ['x', 'y', 'z'], ['x', 'y', 'z']),
+ ]
+ for a, b, expected in cases:
+ actual = core._find_consistent_ordering(a, b)
+ msg = ('unexpected ordering between %r and %r:\nexpected: %r\nactual: %r'
+ % (a, b, expected, actual))
+ self.assertEqual(expected, actual, msg=msg)
+
+
+class AlignTest(Base):
+
+ def test_name(self):
+ align_lt_0, align_lt_1, _ = core.align(self.original_lt, self.original_lt)
+ self.assertIn('lt_align', align_lt_0.name)
+ self.assertIn('/0', align_lt_0.name)
+ self.assertIn('lt_align', align_lt_1.name)
+ self.assertIn('/1', align_lt_1.name)
+
+ def test_identical_shaped_inputs(self):
+ offset_tensor = self.original_lt.tensor + 1
+ offset_lt = core.LabeledTensor(offset_tensor, self.original_lt.axes)
+
+ align_lt, align_offset_lt, broadcast_axes = core.align(self.original_lt,
+ offset_lt)
+
+ self.assertLabeledTensorsEqual(align_lt, self.original_lt)
+ self.assertLabeledTensorsEqual(align_offset_lt, offset_lt)
+ self.assertEqual(broadcast_axes, self.original_lt.axes)
+
+ def test_different_inputs(self):
+ # The correct axis ordering is ['x', 'channel', 'probs'].
+ align_x_probs_lt, align_channel_probs_lt, broadcast_axes = core.align(
+ self.x_probs_lt, self.channel_probs_lt)
+
+ x_probs_golden_lt = core.LabeledTensor(
+ tf.reshape(self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size]),
+ [self.a0, 'channel', self.a3])
+
+ self.assertLabeledTensorsEqual(align_x_probs_lt, x_probs_golden_lt)
+
+ channel_probs_golden_lt = core.LabeledTensor(
+ tf.reshape(self.channel_probs_lt.tensor,
+ [1, self.channel_size, self.probs_size]),
+ ['x', self.a1, self.a3])
+
+ self.assertLabeledTensorsEqual(align_channel_probs_lt,
+ channel_probs_golden_lt)
+
+ self.assertEqual(broadcast_axes, core.Axes([self.a0, self.a1, self.a3]))
+
+ def test_axis_order_scope(self):
+ xz_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'z'])
+ yz_lt = core.LabeledTensor(tf.ones((4, 3)), ['y', 'z'])
+
+ _, _, broadcast_axes = core.align(xz_lt, yz_lt)
+ self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])
+
+ _, _, broadcast_axes = core.align(yz_lt, xz_lt)
+ self.assertEqual(list(broadcast_axes.keys()), ['y', 'x', 'z'])
+
+ with core.axis_order_scope(['x', 'y', 'z']):
+ _, _, broadcast_axes = core.align(yz_lt, xz_lt)
+ self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])
+
+ with core.axis_order_scope(['x', 'y']):
+ with self.assertRaises(core.AxisOrderError):
+ core.align(xz_lt, yz_lt)
+ with self.assertRaises(core.AxisOrderError):
+ core.align(yz_lt, xz_lt)
+
+ def test_invalid_input(self):
+ lt_0 = core.LabeledTensor(tf.zeros([5]), [('a', range(5))])
+ lt_1 = core.LabeledTensor(tf.zeros([5]), [('a', range(1, 6))])
+ with self.assertRaises(ValueError):
+ core.align(lt_0, lt_1)
+
+
+class ConvertToLabeledTensorTest(Base):
+
+ # TODO(shoyer): Simplify these tests once we can reuse labeled tensors in
+ # assertLabeledTensorsEqual.
+
+ def test_labeled_tensor(self):
+ actual = core.convert_to_labeled_tensor(self.original_lt)
+ self.assertLabeledTensorsEqual(actual, self.original_lt)
+
+ def test_python_scalar(self):
+ actual = core.convert_to_labeled_tensor(42)
+ golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
+ self.assertLabeledTensorsEqual(actual, golden_lt)
+
+ def test_numpy_array(self):
+ actual = core.convert_to_labeled_tensor(np.array(42))
+ golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
+ self.assertLabeledTensorsEqual(actual, golden_lt)
+
+ def test_tensor(self):
+ actual = core.convert_to_labeled_tensor(tf.constant(42))
+ golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
+ self.assertLabeledTensorsEqual(actual, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ core.convert_to_labeled_tensor(tf.range(5))
+ with self.assertRaises(ValueError):
+ core.convert_to_labeled_tensor(np.array([1, 2]))
+
+
+class DocStringCheckMixin(object):
+ # requires self.ops to be defined
+
+ def test_function_docstring_and_name(self):
+ for op_name, _, _, lt_op in self.ops:
+ if lt_op is not None:
+ self.assertIn('tf.%s' % op_name, lt_op.__doc__)
+ self.assertEqual(op_name, lt_op.__name__)
+
+
+class UnaryOpsTestsMixin(object):
+ # requires self.ops and self.test_lt to be defined
+
+ def test_core_op(self):
+ for op_name, _, tf_op, lt_op in self.ops:
+ if tf_op is not None:
+ golden_lt = core.LabeledTensor(tf_op(self.test_lt.tensor),
+ self.test_lt.axes)
+ actual_lt = lt_op(self.test_lt)
+ self.assertIn(op_name, actual_lt.name)
+ self.assertLabeledTensorsEqual(golden_lt, actual_lt)
+
+ def test_infix(self):
+ for op_name, infix_op, _, _ in self.ops:
+ if infix_op is not None:
+ expected_lt = core.LabeledTensor(infix_op(self.test_lt.tensor),
+ self.test_lt.axes)
+ actual_lt = infix_op(self.test_lt)
+ self.assertIn(op_name, actual_lt.name)
+ self.assertLabeledTensorsEqual(expected_lt, actual_lt)
+
+
+class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
+
+ def setUp(self):
+ super(CoreUnaryOpsTest, self).setUp()
+
+ self.ops = [
+ ('abs', operator.abs, tf.abs, core.abs_function),
+ ('neg', operator.neg, tf.neg, core.neg),
+ # TODO(shoyer): add unary + to core TensorFlow
+ ('pos', None, None, None),
+ ('sign', None, tf.sign, core.sign),
+ ('inv', None, tf.inv, core.inv),
+ ('square', None, tf.square, core.square),
+ ('round', None, tf.round, core.round_function),
+ ('sqrt', None, tf.sqrt, core.sqrt),
+ ('rsqrt', None, tf.rsqrt, core.rsqrt),
+ ('log', None, tf.log, core.log),
+ ('exp', None, tf.exp, core.exp),
+ ('log', None, tf.log, core.log),
+ ('ceil', None, tf.ceil, core.ceil),
+ ('floor', None, tf.floor, core.floor),
+ ('cos', None, tf.cos, core.cos),
+ ('sin', None, tf.sin, core.sin),
+ ('tan', None, tf.tan, core.tan),
+ ('acos', None, tf.acos, core.acos),
+ ('asin', None, tf.asin, core.asin),
+ ('atan', None, tf.atan, core.atan),
+ ('lgamma', None, tf.lgamma, core.lgamma),
+ ('digamma', None, tf.digamma, core.digamma),
+ ('erf', None, tf.erf, core.erf),
+ ('erfc', None, tf.erfc, core.erfc),
+ ('lgamma', None, tf.lgamma, core.lgamma),
+ ]
+ total_size = np.prod([v.size for v in self.original_lt.axes.values()])
+ self.test_lt = core.LabeledTensor(
+ tf.cast(self.original_lt, tf.float32) / total_size,
+ self.original_lt.axes)
+
+
+class LogicalNotTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
+
+ def setUp(self):
+ super(LogicalNotTest, self).setUp()
+ self.ops = [
+ ('logical_not', operator.invert, tf.logical_not, core.logical_not),
+ ]
+ self.test_lt = self.original_lt < 10
+
+
+class BinaryOpsTestsMixin(object):
+ # requires self.ops, self.test_lt_1, self.test_lt_2, self.test_lt_1_broadcast
+ # and self.test_lt_2_broadcast to be defined
+
+ def test_core_op(self):
+ for op_name, _, tf_op, lt_op in self.ops:
+ golden_tensor = tf_op(self.test_lt_1_broadcast,
+ self.test_lt_2_broadcast)
+ golden_lt = core.LabeledTensor(golden_tensor, self.broadcast_axes)
+ actual_lt = lt_op(self.test_lt_1, self.test_lt_2)
+ self.assertIn(op_name, actual_lt.name)
+ self.assertLabeledTensorsEqual(golden_lt, actual_lt)
+
+ def test_infix(self):
+ for op_name, infix_op, _, lt_op in self.ops:
+ if infix_op is not None:
+ expected_lt = lt_op(self.test_lt_1, self.test_lt_2)
+ actual_lt = infix_op(self.test_lt_1, self.test_lt_2)
+ self.assertIn(op_name, actual_lt.name)
+ self.assertLabeledTensorsEqual(expected_lt, actual_lt)
+
+
+class CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
+
+ def setUp(self):
+ super(CoreBinaryOpsTest, self).setUp()
+
+ self.x_probs_broadcast_tensor = tf.reshape(
+ self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])
+
+ self.channel_probs_broadcast_tensor = tf.reshape(
+ self.channel_probs_lt.tensor, [1, self.channel_size, self.probs_size])
+
+ # == and != are not element-wise for tf.Tensor, so they shouldn't be
+ # elementwise for LabeledTensor, either.
+ self.ops = [
+ ('add', operator.add, tf.add, core.add),
+ ('sub', operator.sub, tf.sub, core.sub),
+ ('mul', operator.mul, tf.mul, core.mul),
+ ('div', operator.truediv, tf.div, core.div),
+ ('mod', operator.mod, tf.mod, core.mod),
+ ('pow', operator.pow, tf.pow, core.pow_function),
+ ('equal', None, tf.equal, core.equal),
+ ('less', operator.lt, tf.less, core.less),
+ ('less_equal', operator.le, tf.less_equal, core.less_equal),
+ ('not_equal', None, tf.not_equal, core.not_equal),
+ ('greater', operator.gt, tf.greater, core.greater),
+ ('greater_equal', operator.ge, tf.greater_equal, core.greater_equal),
+ ]
+ self.test_lt_1 = self.x_probs_lt
+ self.test_lt_2 = self.channel_probs_lt
+ self.test_lt_1_broadcast = self.x_probs_broadcast_tensor
+ self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor
+ self.broadcast_axes = [self.a0, self.a1, self.a3]
+
+ def test_reflexive(self):
+ labeled_tensor = self.x_probs_lt + 1 # all elements must be >0 for division
+ for op_name, infix_op, _, lt_op in self.ops:
+ if infix_op is not None:
+ expected_lt = lt_op(2, labeled_tensor)
+ actual_lt = infix_op(2, labeled_tensor)
+ # Python uses greater for the reflexive version of less (and vise-versa)
+ if 'less' in op_name:
+ op_name = op_name.replace('less', 'greater')
+ elif 'greater' in op_name:
+ op_name = op_name.replace('greater', 'less')
+ self.assertIn(op_name, actual_lt.name)
+ self.assertLabeledTensorsEqual(expected_lt, actual_lt)
+
+
+class LogicalBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
+
+ def setUp(self):
+ super(LogicalBinaryOpsTest, self).setUp()
+
+ self.ops = [
+ ('logical_and', operator.and_, tf.logical_and, core.logical_and),
+ ('logical_or', operator.or_, tf.logical_or, core.logical_or),
+ ('logical_xor', operator.xor, tf.logical_xor, core.logical_xor),
+ ]
+ self.test_lt_1 = self.original_lt < 10
+ self.test_lt_2 = self.original_lt < 5
+ self.test_lt_1_broadcast = self.test_lt_1.tensor
+ self.test_lt_2_broadcast = self.test_lt_2.tensor
+ self.broadcast_axes = self.test_lt_1.axes
+
+
+class FloatBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
+
+ def setUp(self):
+ super(FloatBinaryOpsTest, self).setUp()
+
+ self.ops = [
+ ('igamma', None, tf.igamma, core.igamma),
+ ('igammac', None, tf.igammac, core.igammac),
+ ('zeta', None, tf.zeta, core.zeta),
+ ('polygamma', None, tf.polygamma, core.polygamma),
+ ('maximum', None, tf.maximum, core.maximum),
+ ('minimum', None, tf.minimum, core.minimum),
+ ('squared_difference', None, tf.squared_difference,
+ core.squared_difference),
+ ]
+ total_size = np.prod([v.size for v in self.original_lt.axes.values()])
+ test_lt = core.LabeledTensor(
+ tf.cast(self.original_lt, tf.float32) / total_size,
+ self.original_lt.axes)
+ self.test_lt_1 = test_lt
+ self.test_lt_2 = 1.0 - test_lt
+ self.test_lt_1_broadcast = self.test_lt_1.tensor
+ self.test_lt_2_broadcast = self.test_lt_2.tensor
+ self.broadcast_axes = self.test_lt_1.axes
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py b/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py
new file mode 100644
index 0000000000..3bb9c21c2e
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py
@@ -0,0 +1,178 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Input parsing code for LabeledTensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six import string_types
+
+from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import parsing_ops
+
+
+class FixedLenFeature(object):
+ """Configuration for parsing a fixed-length input feature.
+
+ Fields:
+ axes: A list of Axis objects or tuples (axis_name, axis_value),
+ where `axis_name` is a string and `axis_value` is None (unknown size), an
+ integer or a list of tick labels.
+ dtype: Data type of input.
+ default_value: Value to be used if an example is missing this feature. It
+ must be compatible with `dtype`.
+ """
+
+ def __init__(self, axes, dtype, default_value=None):
+ self._axes = [core.as_axis(a) for a in axes]
+ self._dtype = dtype
+ self._default_value = default_value
+
+ @property
+ def axes(self):
+ return self._axes
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def default_value(self):
+ return self._default_value
+
+
+@tc.returns(tc.Dict(string_types, parsing_ops.FixedLenFeature))
+@tc.accepts(tc.Mapping(string_types, FixedLenFeature))
+def _labeled_to_unlabeled_features(features):
+ """Convert a dict of lt.FixedLenFeature into a dict of tf.FixedLenFeature."""
+ unlabeled_features = {}
+ for name, labeled_feature in features.items():
+ shape = [ax.size for ax in labeled_feature.axes]
+ if any(size is None for size in shape):
+ # This should be caught on the TensorFlow side, but it isn't yet:
+ # https://github.com/tensorflow/tensorflow/issues/2874
+ raise ValueError('axes with unknown size are not supported')
+ dtype = labeled_feature.dtype
+ default_value = labeled_feature.default_value
+ unlabeled_features[name] = parsing_ops.FixedLenFeature(
+ shape, dtype, default_value)
+ return unlabeled_features
+
+
+@tc.returns(tc.Dict(string_types, core.LabeledTensor))
+@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature),
+ tc.Optional(string_types), object)
+def parse_example(serialized, features, name=None, example_names=None):
+ """Parse `Example` protos into a `dict` of labeled tensors.
+
+ See tf.parse_example.
+
+ Args:
+ serialized: A 1-D LabeledTensor of strings, a batch of binary serialized
+ `Example` protos.
+ features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature`
+ values.
+ name: A name for this operation (optional).
+ example_names: A vector (1-D Tensor) of strings (optional), the names of
+ the serialized protos in the batch.
+
+ Returns:
+ A `dict` mapping feature keys to `LabeledTensor` values. The single axis
+ from `serialized` will be prepended to the axes provided by each feature.
+
+ Raises:
+ ValueError: if any feature is invalid.
+ """
+ serialized = core.convert_to_labeled_tensor(serialized)
+ unlabeled_features = _labeled_to_unlabeled_features(features)
+
+ unlabeled_parsed = parsing_ops.parse_example(
+ serialized.tensor, unlabeled_features, name, example_names)
+
+ parsed = {}
+ for name, parsed_feature in unlabeled_parsed.items():
+ axes = list(serialized.axes.values()) + features[name].axes
+ parsed[name] = core.LabeledTensor(parsed_feature, axes)
+
+ return parsed
+
+
+@tc.returns(tc.Dict(string_types, core.LabeledTensor))
+@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature),
+ tc.Optional(string_types), object)
+def parse_single_example(serialized, features, name=None, example_names=None):
+ """Parses a single `Example` proto.
+
+ See tf.parse_single_example.
+
+ Args:
+ serialized: A scalar string Tensor or LabeledTensor, a single serialized
+ Example.
+ features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature`
+ values.
+ name: A name for this operation (optional).
+ example_names: (Optional) A scalar string Tensor, the associated name.
+
+ Returns:
+ A `dict` mapping feature keys to `LabeledTensor` values.
+
+ Raises:
+ ValueError: if any feature is invalid.
+ """
+ serialized = core.convert_to_labeled_tensor(serialized)
+ unlabeled_features = _labeled_to_unlabeled_features(features)
+
+ unlabeled_parsed = parsing_ops.parse_single_example(
+ serialized.tensor, unlabeled_features, name, example_names)
+
+ parsed = {}
+ for name, parsed_feature in unlabeled_parsed.items():
+ parsed[name] = core.LabeledTensor(parsed_feature, features[name].axes)
+
+ return parsed
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(dtypes.DType, tc.Collection(tc.Union(string_types, core.AxisLike)),
+ tc.Optional(string_types))
+def placeholder(dtype, axes, name=None):
+ """Create a placeholder for a labeled tensor.
+
+ For example:
+
+ lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])
+
+ See tf.placeholder for more details.
+
+ Args:
+ dtype: The type of elements in the tensor to be fed.
+ axes: sequence of strings (denoting axes of unknown size) and/or objects
+ convertable to lt.Axis to label the result.
+ name: Optional op name.
+
+ Returns:
+ Placeholder labeled tensor.
+ """
+ with ops.name_scope(name, 'lt_placeholder', []) as scope:
+ axes = core.Axes([(axis, None) if isinstance(axis, string_types) else axis
+ for axis in axes])
+ shape = [axis.size for axis in axes.values()]
+ tensor = array_ops.placeholder(dtype, shape, name=scope)
+ return core.LabeledTensor(tensor, axes)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py
new file mode 100644
index 0000000000..b9d3d9cec2
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py
@@ -0,0 +1,106 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.contrib.labeled_tensor.python.ops import io_ops
+from tensorflow.contrib.labeled_tensor.python.ops import test_util
+
+
+class ParseBase(test_util.Base):
+
+ def setUp(self):
+ super(ParseBase, self).setUp()
+ examples = [
+ tf.train.Example(features=tf.train.Features(feature={
+ 'a': tf.train.Feature(
+ int64_list=tf.train.Int64List(value=[1])),
+ 'b': tf.train.Feature(
+ int64_list=tf.train.Int64List(value=[2, 3, 4])),
+ })),
+ tf.train.Example(features=tf.train.Features(feature={
+ 'a': tf.train.Feature(
+ int64_list=tf.train.Int64List(value=[5])),
+ 'b': tf.train.Feature(
+ int64_list=tf.train.Int64List(value=[6, 7, 8])),
+ })),
+ ]
+ self.serialized = core.LabeledTensor(
+ tf.constant([ex.SerializeToString() for ex in examples]), ['batch'])
+ self.features = {'a': io_ops.FixedLenFeature([], tf.int64),
+ 'b': io_ops.FixedLenFeature([('x', 3)], tf.int64)}
+
+
+class TestParseExample(ParseBase):
+
+ def test(self):
+ expected_a = core.LabeledTensor(tf.constant([1, 5]), ['batch'])
+ expected_b = core.LabeledTensor(tf.constant([[2, 3, 4], [6, 7, 8]]),
+ ['batch', 'x'])
+ parsed = io_ops.parse_example(self.serialized, self.features)
+ self.assertLabeledTensorsEqual(expected_a, parsed['a'])
+ self.assertLabeledTensorsEqual(expected_b, parsed['b'])
+
+ def test_placeholder(self):
+ serialized = core.LabeledTensor(tf.placeholder(tf.string, [None]),
+ ['batch'])
+ # should not raise
+ io_ops.parse_example(serialized, self.features)
+
+
+class TestParseSingleExample(ParseBase):
+
+ def test(self):
+ expected_a = core.LabeledTensor(tf.constant(1), [])
+ expected_b = core.LabeledTensor(tf.constant([2, 3, 4]), ['x'])
+ parsed = io_ops.parse_single_example(self.serialized[0], self.features)
+ self.assertLabeledTensorsEqual(expected_a, parsed['a'])
+ self.assertLabeledTensorsEqual(expected_b, parsed['b'])
+
+ def test_unknown_size(self):
+ features = {'a': io_ops.FixedLenFeature([('x', None)], tf.int64)}
+ serialized = tf.placeholder(tf.string, [])
+ with self.assertRaisesRegexp(ValueError, 'unknown size'):
+ io_ops.parse_single_example(serialized, features)
+
+
+class PlaceholderTest(test_util.Base):
+
+ def test_name(self):
+ placeholder_lt = io_ops.placeholder(tf.float32, [])
+ self.assertIn('lt_placeholder', placeholder_lt.name)
+
+ def test(self):
+ placeholder_lt = io_ops.placeholder(tf.float32,
+ ['batch', ('x', ['a', 'b'])])
+ self.assertEqual(placeholder_lt.dtype, tf.float32)
+ self.assertEqual(placeholder_lt.axes,
+ core.Axes([('batch', None), ('x', ['a', 'b'])]))
+
+ def test_feed(self):
+ sess = tf.Session()
+ placeholder_lt = io_ops.placeholder(tf.float32, [])
+ two_times = 2.0 * placeholder_lt
+ result = sess.run(two_times, {placeholder_lt.tensor: 1})
+ self.assertEqual(result, 2.0)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/nn.py b/tensorflow/contrib/labeled_tensor/python/ops/nn.py
new file mode 100644
index 0000000000..dce16ccf27
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/nn.py
@@ -0,0 +1,42 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Neural network ops for LabeledTensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.python.ops import nn
+
+
+relu = core.define_unary_op('relu', nn.relu)
+relu6 = core.define_unary_op('relu6', nn.relu6)
+crelu = core.define_unary_op('crelu', nn.crelu)
+elu = core.define_unary_op('elu', nn.elu)
+softplus = core.define_unary_op('softplus', nn.softplus)
+
+l2_loss = core.define_unary_op('l2_loss', nn.l2_loss)
+sigmoid_cross_entropy_with_logits = core.define_binary_op(
+ 'sigmoid_cross_entropy_with_logits',
+ nn.sigmoid_cross_entropy_with_logits)
+softmax = core.define_unary_op('softmax', nn.softmax)
+log_softmax = core.define_unary_op('log_softmax', nn.log_softmax)
+softmax_cross_entropy_with_logits = core.define_binary_op(
+ 'softmax_cross_entropy_with_logits',
+ nn.softmax_cross_entropy_with_logits)
+sparse_softmax_cross_entropy_with_logits = core.define_binary_op(
+ 'sparse_softmax_cross_entropy_with_logits',
+ nn.sparse_softmax_cross_entropy_with_logits)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/nn_test.py b/tensorflow/contrib/labeled_tensor/python/ops/nn_test.py
new file mode 100644
index 0000000000..18cbd8b4ed
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/nn_test.py
@@ -0,0 +1,70 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.contrib.labeled_tensor.python.ops import nn
+from tensorflow.contrib.labeled_tensor.python.ops import test_util
+
+
+class NNTests(test_util.Base):
+
+ def setUp(self):
+ super(NNTests, self).setUp()
+ self.axes = ['x']
+ self.original_lt = core.LabeledTensor([0.0, 0.5, 1.0], self.axes)
+ self.other_lt = 1 - self.original_lt
+
+ def test_unary_ops(self):
+ ops = [
+ ('relu', tf.nn.relu, nn.relu),
+ ('relu6', tf.nn.relu6, nn.relu6),
+ ('crelu', tf.nn.crelu, nn.crelu),
+ ('elu', tf.nn.elu, nn.elu),
+ ('softplus', tf.nn.softplus, nn.softplus),
+ ('l2_loss', tf.nn.l2_loss, nn.l2_loss),
+ ('softmax', tf.nn.softmax, nn.softmax),
+ ('log_softmax', tf.nn.log_softmax, nn.log_softmax),
+ ]
+ for op_name, tf_op, lt_op in ops:
+ golden_tensor = tf_op(self.original_lt.tensor)
+ golden_lt = core.LabeledTensor(golden_tensor, self.axes)
+ actual_lt = lt_op(self.original_lt)
+ self.assertIn(op_name, actual_lt.name)
+ self.assertLabeledTensorsEqual(golden_lt, actual_lt)
+
+ def test_binary_ops(self):
+ ops = [
+ ('sigmoid_cross_entropy_with_logits',
+ tf.nn.sigmoid_cross_entropy_with_logits,
+ nn.sigmoid_cross_entropy_with_logits),
+ ('softmax_cross_entropy_with_logits',
+ tf.nn.softmax_cross_entropy_with_logits,
+ nn.softmax_cross_entropy_with_logits),
+ ('sparse_softmax_cross_entropy_with_logits',
+ tf.nn.sparse_softmax_cross_entropy_with_logits,
+ nn.sparse_softmax_cross_entropy_with_logits),
+ ]
+ for op_name, tf_op, lt_op in ops:
+ golden_tensor = tf_op(self.original_lt.tensor, self.other_lt.tensor)
+ golden_lt = core.LabeledTensor(golden_tensor, self.axes)
+ actual_lt = lt_op(self.original_lt, self.other_lt)
+ self.assertIn(op_name, actual_lt.name)
+ self.assertLabeledTensorsEqual(golden_lt, actual_lt)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
new file mode 100644
index 0000000000..a9ddbbd2cf
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
@@ -0,0 +1,1207 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Non-core ops for LabeledTensor."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import types
+
+import numpy as np
+from six import string_types
+
+from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import numerics
+from tensorflow.python.ops import random_ops
+from tensorflow.python.training import input # pylint: disable=redefined-builtin
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensor, ops.Output, core.Axis,
+ tc.Optional(string_types))
+def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
+ with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
+ temp_axes = core.Axes(
+ [axis] + list(labeled_tensor.axes.remove(axis.name).values()))
+ transposed = core.transpose(labeled_tensor, temp_axes.keys())
+ indexed = core.LabeledTensor(array_ops.gather(transposed.tensor, indexer),
+ temp_axes)
+ return core.transpose(indexed, labeled_tensor.axes.keys(), name=scope)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike,
+ tc.Mapping(string_types, tc.Union(
+ slice, collections.Hashable, collections.Sequence)),
+ tc.Optional(string_types))
+def select(labeled_tensor, selection, name=None):
+ """Slice out a subset of the tensor.
+
+ Args:
+ labeled_tensor: The input tensor.
+ selection: A dictionary mapping an axis name to a scalar, slice or list of
+ values to select. Currently supports two types of selections:
+ (a) Any number of scalar and/or slice selections.
+ (b) Exactly one list selection, without any scalars or slices.
+ name: Optional op name.
+
+ Returns:
+ The selection as a `LabeledTensor`.
+
+ Raises:
+ ValueError: If the tensor doesn't have an axis in the selection or if
+ that axis lacks labels.
+ KeyError: If any labels in a selection are not found in the original axis.
+ NotImplementedError: If you attempt to combine a list selection with
+ scalar selection or another list selection.
+ """
+ with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ slices = {}
+ indexers = {}
+ for axis_name, value in selection.items():
+ if axis_name not in labeled_tensor.axes:
+ raise ValueError(
+ 'The tensor does not have an axis named %s. Its axes are: %r' %
+ (axis_name, labeled_tensor.axes.keys()))
+ axis = labeled_tensor.axes[axis_name]
+ if axis.labels is None:
+ raise ValueError(
+ 'The axis named %s does not have labels. The axis is: %r' %
+ (axis_name, axis))
+
+ if isinstance(value, slice):
+ # TODO(shoyer): consider deprecating using slices in favor of lists
+ if value.start is None:
+ start = None
+ else:
+ start = axis.index(value.start)
+
+ if value.stop is None:
+ stop = None
+ else:
+ # For now, follow the pandas convention of making labeled slices
+ # inclusive of both bounds.
+ stop = axis.index(value.stop) + 1
+
+ if value.step is not None:
+ raise NotImplementedError('slicing with a step is not yet supported')
+
+ slices[axis_name] = slice(start, stop)
+
+ else:
+ # We're allowing anything NumPy treats as a scalar or 1D array.
+ value = np.asarray(value)
+ if value.ndim == 0:
+ slices[axis_name] = axis.index(value.item())
+ elif value.ndim == 1:
+ if indexers:
+ raise NotImplementedError(
+ 'select does not yet support more than one list selection at '
+ 'the same time')
+ indexer = [axis.index(v) for v in value.tolist()]
+ indexers[axis_name] = ops.convert_to_tensor(
+ indexer, dtype=dtypes.int64)
+ else:
+ raise NotImplementedError(
+ 'select does not yet support selections with more than one '
+ 'dimension: %s on axis %r' % (value, axis_name))
+
+ if indexers and slices:
+ raise NotImplementedError(
+ 'select does not yet support combined scalar and list selection')
+
+ # For now, handle array selection separately, because tf.gather_nd does
+ # not support gradients yet. Later, using gather_nd will let us combine
+ # these paths.
+ if indexers:
+ (axis_name, indexer), = indexers.items()
+ axis = core.Axis(axis_name, selection[axis_name])
+ return _gather_1d_on_axis(labeled_tensor, indexer, axis, name=scope)
+ else:
+ return core.slice_function(labeled_tensor, slices, name=scope)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(tc.Collection(core.LabeledTensorLike), string_types,
+ tc.Optional(string_types))
+def concat(labeled_tensors, axis_name, name=None):
+ """Concatenate tensors along a dimension.
+
+ See tf.concat.
+
+ Args:
+ labeled_tensors: A list of input LabeledTensors.
+ axis_name: The name of the axis along which to concatenate.
+ name: Optional op name.
+
+ Returns:
+ The concatenated tensor.
+ The coordinate labels for the concatenation dimension are also concatenated,
+ if they are available for every tensor.
+
+ Raises:
+ ValueError: If fewer than one tensor inputs is provided, if the tensors
+ have incompatible axes, or if `axis_name` isn't the name of an axis.
+ """
+ with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
+ labeled_tensors = [core.convert_to_labeled_tensor(lt)
+ for lt in labeled_tensors]
+
+ if len(labeled_tensors) < 1:
+ raise ValueError('concat expects at least 1 tensor, but received %s' %
+ labeled_tensors)
+
+ # All tensors must have these axes.
+ axes_0 = labeled_tensors[0].axes
+ axis_names = list(axes_0.keys())
+
+ if axis_name not in axis_names:
+ raise ValueError('%s not in %s' % (axis_name, axis_names))
+
+ shared_axes = axes_0.remove(axis_name)
+
+ tensors = [labeled_tensors[0].tensor]
+ concat_axis_list = [axes_0[axis_name]]
+ for labeled_tensor in labeled_tensors[1:]:
+ current_shared_axes = labeled_tensor.axes.remove(axis_name)
+ if current_shared_axes != shared_axes:
+ # TODO(shoyer): add more specific checks about what went wrong,
+ # including raising AxisOrderError when appropriate
+ raise ValueError('Mismatched shared axes: the first tensor '
+ 'had axes %r but this tensor has axes %r.' %
+ (shared_axes, current_shared_axes))
+
+ # Accumulate the axis labels, if they're available.
+ concat_axis_list.append(labeled_tensor.axes[axis_name])
+ tensors.append(labeled_tensor.tensor)
+
+ concat_axis = core.concat_axes(concat_axis_list)
+ concat_dimension = axis_names.index(axis_name)
+ concat_tensor = array_ops.concat(concat_dimension, tensors, name=scope)
+ values = list(axes_0.values())
+ concat_axes = (values[:concat_dimension] + [concat_axis] +
+ values[concat_dimension + 1:])
+
+ return core.LabeledTensor(concat_tensor, concat_axes)
+
+
+# TODO(shoyer): rename pack/unpack to stack/unstack
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(
+ tc.Collection(core.LabeledTensorLike),
+ tc.Union(string_types, core.AxisLike),
+ int, tc.Optional(string_types))
+def pack(labeled_tensors, new_axis, axis_position=0, name=None):
+ """Pack tensors along a new axis.
+
+ See tf.pack.
+
+ Args:
+ labeled_tensors: The input tensors, which must have identical axes.
+ new_axis: The name of the new axis, or a tuple containing the name
+ and coordinate labels.
+ axis_position: Optional integer position at which to insert the new axis.
+ name: Optional op name.
+
+ Returns:
+ The packed tensors as a single LabeledTensor, with `new_axis` in the given
+ `axis_position`.
+
+ Raises:
+ ValueError: If fewer than one input tensors is provided, or if the tensors
+ don't have identical axes.
+ """
+ with ops.name_scope(name, 'lt_pack', labeled_tensors) as scope:
+ labeled_tensors = [core.convert_to_labeled_tensor(lt)
+ for lt in labeled_tensors]
+
+ if len(labeled_tensors) < 1:
+ raise ValueError('pack expects at least 1 tensors, but received %s' %
+ labeled_tensors)
+
+ axes_0 = labeled_tensors[0].axes
+ for t in labeled_tensors:
+ if t.axes != axes_0:
+ raise ValueError('Non-identical axes. Expected %s but got %s' %
+ (axes_0, t.axes))
+
+ pack_op = array_ops.stack(
+ [t.tensor for t in labeled_tensors], axis=axis_position, name=scope)
+ axes = list(axes_0.values())
+ axes.insert(axis_position, new_axis)
+ return core.LabeledTensor(pack_op, axes)
+
+
+@tc.returns(tc.List(core.LabeledTensor))
+@tc.accepts(core.LabeledTensorLike, tc.Optional(string_types),
+ tc.Optional(string_types))
+def unpack(labeled_tensor, axis_name=None, name=None):
+ """Unpack the tensor.
+
+ See tf.unpack.
+
+ Args:
+ labeled_tensor: The input tensor.
+ axis_name: Optional name of axis to unpack. By default, the first axis is
+ used.
+ name: Optional op name.
+
+ Returns:
+ The list of unpacked LabeledTensors.
+
+ Raises:
+ ValueError: If `axis_name` is not an axis on the input.
+ """
+ with ops.name_scope(name, 'lt_unpack', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ axis_names = list(labeled_tensor.axes.keys())
+ if axis_name is None:
+ axis_name = axis_names[0]
+
+ if axis_name not in axis_names:
+ raise ValueError('%s not in %s' % (axis_name, axis_names))
+ axis = axis_names.index(axis_name)
+
+ unpack_ops = array_ops.unstack(labeled_tensor.tensor, axis=axis, name=scope)
+ axes = [a for i, a in enumerate(labeled_tensor.axes.values())
+ if i != axis]
+ return [core.LabeledTensor(t, axes) for t in unpack_ops]
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Collection(string_types),
+ tc.Collection(tc.Union(string_types, core.AxisLike)),
+ tc.Optional(string_types))
+def reshape(labeled_tensor, existing_axes, new_axes, name=None):
+ """Reshape specific axes of a LabeledTensor.
+
+ Non-indicated axes remain in their original locations.
+
+ Args:
+ labeled_tensor: The input tensor.
+ existing_axes: List of axis names found on the input tensor. These must
+ appear sequentially in the list of axis names on the input. In other
+ words, they must be a valid slice of `list(labeled_tensor.axes.keys())`.
+ new_axes: List of strings, tuples of (axis_name, axis_value) or Axis objects
+ providing new axes with which to replace `existing_axes` in the reshaped
+ result. At most one element of `new_axes` may be a string, indicating an
+ axis with unknown size.
+ name: Optional op name.
+
+ Returns:
+ The reshaped LabeledTensor.
+
+ Raises:
+ ValueError: If `existing_axes` are not all axes on the input, or if more
+ than one of `new_axes` has unknown size.
+ AxisOrderError: If `existing_axes` are not a slice of axis names on the
+ input.
+ """
+ with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ original_axis_names = list(labeled_tensor.axes.keys())
+ existing_axes = list(existing_axes)
+ if not set(existing_axes) <= set(original_axis_names):
+ raise ValueError('existing_axes %r are not contained in the set of axis '
+ 'names %r on the input labeled tensor' %
+ (existing_axes, original_axis_names))
+
+ start = original_axis_names.index(existing_axes[0])
+ stop = original_axis_names.index(existing_axes[-1]) + 1
+
+ if existing_axes != original_axis_names[start:stop]:
+ # We could support existing_axes that aren't a slice by using transpose,
+ # but that could lead to unpredictable performance consequences because
+ # transposes are not free in TensorFlow. If we did transpose
+ # automatically, the user might never realize that their data is being
+ # produced with the wrong order. (The later will occur with some frequency
+ # because of how broadcasting automatically choose axis order.)
+ # So for now we've taken the strict approach.
+ raise core.AxisOrderError(
+ 'existing_axes %r are not a slice of axis names %r on the input '
+ 'labeled tensor. Use `transpose` or `impose_axis_order` to reorder '
+ 'axes on the input explicitly.' %
+ (existing_axes, original_axis_names))
+
+ if sum(isinstance(axis, string_types) for axis in new_axes) > 1:
+ raise ValueError(
+ 'at most one axis in new_axes can have unknown size. All other '
+ 'axes must have an indicated integer size or labels: %r' % new_axes)
+
+ original_values = list(labeled_tensor.axes.values())
+ axis_size = lambda axis: -1 if axis.size is None else axis.size
+ shape = [axis_size(axis) for axis in original_values[:start]]
+ for axis_ref in new_axes:
+ if isinstance(axis_ref, string_types):
+ shape.append(-1)
+ else:
+ axis = core.as_axis(axis_ref)
+ shape.append(axis_size(axis))
+ shape.extend(axis_size(axis) for axis in original_values[stop:])
+
+ reshaped_tensor = array_ops.reshape(
+ labeled_tensor.tensor, shape, name=scope)
+ axes = original_values[:start] + list(new_axes) + original_values[stop:]
+ return core.LabeledTensor(reshaped_tensor, axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, string_types, string_types,
+ tc.Optional(string_types))
+def rename_axis(labeled_tensor, existing_name, new_name, name=None):
+ """Rename an axis of LabeledTensor.
+
+ Args:
+ labeled_tensor: The input tensor.
+ existing_name: Name for an existing axis on the input.
+ new_name: Desired replacement name.
+ name: Optional op name.
+
+ Returns:
+ LabeledTensor with renamed axis.
+
+ Raises:
+ ValueError: If `existing_name` is not an axis on the input.
+ """
+ with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope:
+ if existing_name not in labeled_tensor.axes:
+ raise ValueError('existing_name %r are not contained in the set of axis '
+ 'names %r on the input labeled tensor' %
+ (existing_name, labeled_tensor.axes.keys()))
+ new_axis = core.Axis(new_name, labeled_tensor.axes[existing_name].value)
+ return reshape(labeled_tensor, [existing_name], [new_axis], name=scope)
+
+
+@tc.returns(tc.List(core.LabeledTensor))
+@tc.accepts(string_types, collections.Callable, int, bool,
+ tc.Collection(core.LabeledTensorLike), bool,
+ tc.Optional(string_types))
+def _batch_helper(default_name,
+ batch_fn,
+ batch_size,
+ enqueue_many,
+ labeled_tensors,
+ allow_smaller_final_batch,
+ name=None):
+ with ops.name_scope(name, default_name, labeled_tensors) as scope:
+ labeled_tensors = [core.convert_to_labeled_tensor(lt)
+ for lt in labeled_tensors]
+
+ batch_ops = batch_fn([t.tensor for t in labeled_tensors], scope)
+ # TODO(shoyer): Remove this when they sanitize the TF API.
+ if not isinstance(batch_ops, list):
+ assert isinstance(batch_ops, ops.Tensor)
+ batch_ops = [batch_ops]
+
+ if allow_smaller_final_batch:
+ batch_size = None
+
+ @tc.returns(core.Axes)
+ @tc.accepts(core.Axes)
+ def output_axes(axes):
+ if enqueue_many:
+ if 'batch' not in axes or list(axes.keys()).index('batch') != 0:
+ raise ValueError(
+ 'When enqueue_many is True, input tensors must have an axis '
+ 'called "batch" as their first dimension, '
+ 'but axes were %s' % axes)
+ culled_axes = axes.remove('batch')
+ return core.Axes([('batch', batch_size)] + list(culled_axes.values()))
+ else:
+ return core.Axes([('batch', batch_size)] + list(axes.values()))
+
+ output_labeled_tensors = []
+ for i, tensor in enumerate(batch_ops):
+ axes = output_axes(labeled_tensors[i].axes)
+ output_labeled_tensors.append(core.LabeledTensor(tensor, axes))
+
+ return output_labeled_tensors
+
+
+@tc.returns(tc.List(core.LabeledTensor))
+@tc.accepts(
+ tc.Collection(core.LabeledTensorLike), int, int, int, bool, bool,
+ tc.Optional(string_types))
+def batch(labeled_tensors,
+ batch_size,
+ num_threads=1,
+ capacity=32,
+ enqueue_many=False,
+ allow_smaller_final_batch=False,
+ name=None):
+ """Rebatch a tensor.
+
+ See tf.batch.
+
+ Args:
+ labeled_tensors: The input tensors.
+ batch_size: The output batch size.
+ num_threads: See tf.batch.
+ capacity: See tf.batch.
+ enqueue_many: If true, the input tensors must contain a 'batch' axis as
+ their first axis.
+ If false, the input tensors must not contain a 'batch' axis.
+ See tf.batch.
+ allow_smaller_final_batch: See tf.batch.
+ name: Optional op name.
+
+ Returns:
+ The rebatched tensors.
+ If enqueue_many is false, the output tensors will have a new 'batch' axis
+ as their first axis.
+
+ Raises:
+ ValueError: If enqueue_many is True and the first axis of the tensors
+ isn't "batch".
+ """
+
+ def fn(tensors, scope):
+ return input.batch(tensors,
+ batch_size=batch_size,
+ num_threads=num_threads,
+ capacity=capacity,
+ enqueue_many=enqueue_many,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ name=scope)
+
+ return _batch_helper('lt_batch', fn, batch_size, enqueue_many,
+ labeled_tensors, allow_smaller_final_batch, name)
+
+
+@tc.returns(tc.List(core.LabeledTensor))
+@tc.accepts(
+ tc.Collection(core.LabeledTensorLike), int, int, int, bool, int,
+ tc.Optional(int), bool, tc.Optional(string_types))
+def shuffle_batch(labeled_tensors,
+ batch_size,
+ num_threads=1,
+ capacity=32,
+ enqueue_many=False,
+ min_after_dequeue=0,
+ seed=None,
+ allow_smaller_final_batch=False,
+ name=None):
+ """Rebatch a tensor, with shuffling.
+
+ See tf.batch.
+
+ Args:
+ labeled_tensors: The input tensors.
+ batch_size: The output batch size.
+ num_threads: See tf.batch.
+ capacity: See tf.batch.
+ enqueue_many: If true, the input tensors must contain a 'batch' axis as
+ their first axis.
+ If false, the input tensors must not contain a 'batch' axis.
+ See tf.batch.
+ min_after_dequeue: Minimum number of elements in the queue after a dequeue,
+ used to ensure mixing.
+ seed: Optional random seed.
+ allow_smaller_final_batch: See tf.batch.
+ name: Optional op name.
+
+ Returns:
+ The rebatched tensors.
+ If enqueue_many is false, the output tensors will have a new 'batch' axis
+ as their first axis.
+
+ Raises:
+ ValueError: If enqueue_many is True and the first axis of the tensors
+ isn't "batch".
+ """
+
+ def fn(tensors, scope):
+ return input.shuffle_batch(
+ tensors,
+ batch_size=batch_size,
+ num_threads=num_threads,
+ capacity=capacity,
+ enqueue_many=enqueue_many,
+ min_after_dequeue=min_after_dequeue,
+ seed=seed,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ name=scope)
+
+ return _batch_helper('lt_shuffle_batch', fn, batch_size, enqueue_many,
+ labeled_tensors, allow_smaller_final_batch, name)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, int),
+ tc.Optional(int), tc.Optional(string_types))
+def random_crop(labeled_tensor, shape_map, seed=None, name=None):
+ """Randomly crops a tensor to a given size.
+
+ See tf.random_crop.
+
+ Args:
+ labeled_tensor: The input tensor.
+ shape_map: A dictionary mapping axis names to the size of the random crop
+ for that dimension.
+ seed: An optional random seed.
+ name: An optional op name.
+
+ Returns:
+ A tensor of the same rank as `labeled_tensor`, cropped randomly in the
+ selected dimensions.
+
+ Raises:
+ ValueError: If the shape map contains an axis name not in the input tensor.
+ """
+ with ops.name_scope(name, 'lt_random_crop', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ for axis_name in shape_map:
+ if axis_name not in labeled_tensor.axes:
+ raise ValueError('Selection axis %s not in axes %s' %
+ (axis_name, labeled_tensor.axes))
+
+ shape = []
+ axes = []
+ for axis in labeled_tensor.axes.values():
+ if axis.name in shape_map:
+ size = shape_map[axis.name]
+ shape.append(size)
+ # We lose labels for the axes we crop, leaving just the size.
+ axes.append((axis.name, size))
+ else:
+ shape.append(len(axis))
+ axes.append(axis)
+
+ crop_op = random_ops.random_crop(labeled_tensor.tensor,
+ shape,
+ seed=seed,
+ name=scope)
+
+ return core.LabeledTensor(crop_op, axes)
+
+
+# TODO(shoyer): Allow the user to select the axis over which to map.
+@tc.returns(core.LabeledTensor)
+@tc.accepts(collections.Callable, core.LabeledTensorLike,
+ tc.Optional(string_types))
+def map_fn(fn, labeled_tensor, name=None):
+ """Map on the list of tensors unpacked from labeled_tensor.
+
+ See tf.map_fn.
+
+ Args:
+ fn: The function to apply to each unpacked LabeledTensor.
+ It should have type LabeledTensor -> LabeledTensor.
+ labeled_tensor: The input tensor.
+ name: Optional op name.
+
+ Returns:
+ A tensor that packs the results of applying fn to the list of tensors
+ unpacked from labeled_tensor.
+ """
+ with ops.name_scope(name, 'lt_map_fn', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+ unpack_lts = unpack(labeled_tensor)
+ map_lts = [fn(t) for t in unpack_lts]
+ return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Optional(tc.Collection(string_types)),
+ tc.Optional(string_types))
+def squeeze(labeled_tensor, axis_names=None, name=None):
+ """Remove size-1 dimensions.
+
+ See tf.squeeze.
+
+ Args:
+ labeled_tensor: The input tensor.
+ axis_names: The names of the dimensions to remove, or None to remove
+ all size-1 dimensions.
+ name: Optional op name.
+
+ Returns:
+ A tensor with the specified dimensions removed.
+
+ Raises:
+ ValueError: If the named axes are not in the tensor, or if they are
+ not size-1.
+ """
+ with ops.name_scope(name, 'lt_squeeze', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ if axis_names is None:
+ axis_names = [a.name for a in labeled_tensor.axes.values() if len(a) == 1]
+
+ for axis_name in axis_names:
+ if axis_name not in labeled_tensor.axes:
+ raise ValueError('axis %s is not in tensor axes %s' %
+ (axis_name, labeled_tensor.axes))
+ elif len(labeled_tensor.axes[axis_name]) != 1:
+ raise ValueError(
+ 'cannot squeeze axis with size greater than 1: (%s, %s)' %
+ (axis_name, labeled_tensor.axes[axis_name]))
+
+ squeeze_dimensions = []
+ axes = []
+ for i, axis in enumerate(labeled_tensor.axes.values()):
+ if axis.name in axis_names:
+ squeeze_dimensions.append(i)
+ else:
+ axes.append(axis)
+
+ if squeeze_dimensions:
+ squeeze_op = array_ops.squeeze(labeled_tensor.tensor,
+ squeeze_dimensions,
+ name=scope)
+ else:
+ squeeze_op = array_ops.identity(labeled_tensor.tensor, name=scope)
+
+ return core.LabeledTensor(squeeze_op, axes)
+
+# pylint: disable=invalid-name
+ReduceAxis = tc.Union(
+ string_types, tc.Tuple(string_types, collections.Hashable))
+ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis)))
+# pylint: enable=invalid-name
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
+ tc.Optional(string_types))
+def matmul(a, b, name=None):
+ """Matrix multiply two tensors with rank 1 or 2.
+
+ If both tensors have rank 2, a matrix-matrix product is performed.
+ If one tensor has rank 1 and the other has rank 2, then a matrix-vector
+ product is performed.
+ If both tensors have rank 1, then a vector dot-product is performed.
+ (This behavior matches that of `numpy.dot`.)
+
+ Both tensors must share exactly one dimension in common, which is the
+ dimension the operation is summed along. The inputs will be automatically
+ transposed if necessary as part of the matmul op.
+
+ We intend to eventually support `matmul` on higher rank input, and also
+ eventually support summing over any number shared dimensions (via an `axis`
+ argument), but neither of these features has been implemented yet.
+
+ Args:
+ a: First LabeledTensor.
+ b: Second LabeledTensor.
+ name: Optional op name.
+
+ Returns:
+ LabeledTensor with the result of matrix multiplication. Axes are ordered by
+ the current axis_order_scope, if set, or in or order of appearance on the
+ inputs.
+
+ Raises:
+ NotImplementedError: If inputs have rank >2 or share multiple axes.
+ ValueError: If the inputs have rank 0 or do not share any axes.
+ """
+ with ops.name_scope(name, 'lt_matmul', [a, b]) as scope:
+
+ a = core.convert_to_labeled_tensor(a)
+ b = core.convert_to_labeled_tensor(b)
+
+ if len(a.axes) > 2 or len(b.axes) > 2:
+ # We could use tf.batch_matmul to make this work, but we would also need
+ # to use tf.tile and/or tf.transpose. These are more expensive than doing
+ # reshapes, so it's not clear if it's a good idea to do this
+ # automatically.
+ raise NotImplementedError(
+ 'matmul currently requires inputs with rank 2 or less, but '
+ 'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))
+
+ if not a.axes or not b.axes:
+ raise ValueError(
+ 'matmul currently requires inputs with at least rank 1, but '
+ 'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))
+
+ shared_axes = set(a.axes) & set(b.axes)
+ if len(shared_axes) > 1:
+ raise NotImplementedError(
+ 'matmul does not yet support summing over multiple shared axes: %r. '
+ 'Use transpose and reshape to create a single shared axis to sum '
+ 'over.' % shared_axes)
+ if not shared_axes:
+ raise ValueError('there must have exactly one axis in common between '
+ 'input to matmul: %r, %r' %
+ (a.axes.keys(), b.axes.keys()))
+ shared_axis, = shared_axes
+
+ if a.axes[shared_axis] != b.axes[shared_axis]:
+ raise ValueError('axis %r does not match on input arguments: %r vs %r' %
+ (shared_axis, a.axes[shared_axis].value,
+ b.axes[shared_axis].value))
+
+ result_axes = []
+ for axes in [a.axes, b.axes]:
+ for axis in axes.values():
+ if axis.name != shared_axis:
+ result_axes.append(axis)
+
+ axis_scope_order = core.get_axis_order()
+ if axis_scope_order is not None:
+ result_axis_names = [axis.name for axis in result_axes]
+ new_axis_names = [name for name in axis_scope_order
+ if name in result_axis_names]
+ if new_axis_names != result_axis_names:
+ # switch a and b
+ b, a = a, b
+ # result_axes is a list of length 1 or 2
+ result_axes = result_axes[::-1]
+
+ squeeze_dims = []
+
+ if len(a.axes) == 1:
+ a_tensor = array_ops.reshape(a.tensor, (1, -1))
+ squeeze_dims.append(0)
+ transpose_a = False
+ else:
+ a_tensor = a.tensor
+ transpose_a = list(a.axes.keys()).index(shared_axis) == 0
+
+ if len(b.axes) == 1:
+ b_tensor = array_ops.reshape(b.tensor, (-1, 1))
+ squeeze_dims.append(1)
+ transpose_b = False
+ else:
+ b_tensor = b.tensor
+ transpose_b = list(b.axes.keys()).index(shared_axis) == 1
+
+ result_op = math_ops.matmul(a_tensor,
+ b_tensor,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b)
+
+ if squeeze_dims:
+ result_op = array_ops.squeeze(result_op, squeeze_dims)
+ result_op = array_ops.identity(result_op, name=scope)
+
+ return core.LabeledTensor(result_op, result_axes)
+
+
+@tc.returns(types.FunctionType)
+@tc.accepts(string_types, collections.Callable)
+def define_reduce_op(op_name, reduce_fn):
+ """Define a reduction op for labeled tensors.
+
+ Args:
+ op_name: string name of the TensorFlow op.
+ reduce_fn: function to call to evaluate the op on a tf.Tensor.
+
+ Returns:
+ Function defining the given reduction op that acts on a LabeledTensor.
+ """
+
+ default_name = 'lt_%s' % op_name
+
+ @tc.returns(core.LabeledTensor)
+ @tc.accepts(core.LabeledTensorLike, ReduceAxes, tc.Optional(string_types))
+ def op(labeled_tensor, axes=None, name=None):
+ """Computes the given reduction across the given axes of a LabeledTensor.
+
+ See `tf.{op_name}` for full details.
+
+ Args:
+ labeled_tensor: The input tensor.
+ axes: A set of axes or None.
+ If None, all axes will be reduced.
+ Axes must all be strings, in which case those dimensions will be
+ removed, or pairs of (name, None) or (name, label), in which case those
+ dimensions will be kept.
+ name: Optional op name.
+
+ Returns:
+ The reduced LabeledTensor.
+
+ Raises:
+ ValueError: if any of the axes to reduce over are not found on
+ `labeled_tensor`.
+ """
+ with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ if axes is None:
+ axes = labeled_tensor.axes.keys()
+
+ if isinstance(axes, (string_types, tuple)):
+ axes = [axes]
+
+ reduction_axes = {}
+ axes_to_squeeze = []
+ for a in axes:
+ if isinstance(a, string_types):
+ # We squeeze out this axis.
+ reduction_axes[a] = a
+ axes_to_squeeze.append(a)
+ else:
+ # We keep this axis, with the user-provided labels.
+ (axis_name, label) = a
+ if label is not None:
+ # The input was a single label, so make it a list so it can be
+ # turned into an Axis.
+ label = [label]
+ reduction_axes[axis_name] = (axis_name, label)
+
+ for axis_name in reduction_axes:
+ if axis_name not in labeled_tensor.axes:
+ raise ValueError('Axis %s not in axes %s' %
+ (axis_name, labeled_tensor.axes))
+
+ intermediate_axes = []
+ reduction_dimensions = []
+ for i, axis in enumerate(labeled_tensor.axes.values()):
+ if axis.name in reduction_axes:
+ intermediate_axes.append(reduction_axes[axis.name])
+ reduction_dimensions.append(i)
+ else:
+ intermediate_axes.append(axis)
+
+ reduce_op = reduce_fn(labeled_tensor.tensor,
+ reduction_dimensions,
+ keep_dims=True)
+ reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes)
+
+ return squeeze(reduce_lt, axes_to_squeeze, name=scope)
+
+ op.__doc__ = op.__doc__.format(op_name=op_name)
+ op.__name__ = op_name
+
+ return op
+
+
+reduce_all = define_reduce_op('reduce_all', math_ops.reduce_all)
+reduce_any = define_reduce_op('reduce_any', math_ops.reduce_any)
+reduce_logsumexp = define_reduce_op('reduce_logsumexp',
+ math_ops.reduce_logsumexp)
+reduce_max = define_reduce_op('reduce_max', math_ops.reduce_max)
+reduce_mean = define_reduce_op('reduce_mean', math_ops.reduce_mean)
+reduce_min = define_reduce_op('reduce_min', math_ops.reduce_min)
+reduce_prod = define_reduce_op('reduce_prod', math_ops.reduce_prod)
+reduce_sum = define_reduce_op('reduce_sum', math_ops.reduce_sum)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Mapping(str, tc.Union(int, ops.Tensor)),
+ tc.Optional(string_types))
+def tile(labeled_tensor, multiples, name=None):
+ """Constructs a tensor by tiling a given tensor.
+
+ Only axes without tick-labels can be tiled. (Otherwise, axis labels on tiled
+ tensors would no longer be unique.)
+
+ See lt.tile.
+
+ Args:
+ labeled_tensor: The input tensor.
+ multiples: A mapping where the keys are axis names and the values are the
+ integer number of times to tile along that axis. Only axes with a multiple
+ different than 1 need be included.
+ name: Optional op name.
+
+ Returns:
+ A tensor with the indicated axes tiled.
+
+ Raises:
+ ValueError: If the tiled axes are not axes in the input tensor, or if any
+ axes in multiples have tick labels.
+ """
+ with ops.name_scope(name, 'lt_tile', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ if not set(multiples.keys()) <= set(labeled_tensor.axes.keys()):
+ raise ValueError('tile axes %r are not contained in the set of axis '
+ 'names %r on the input labeled tensor' %
+ (multiples.keys(), labeled_tensor.axes))
+
+ labeled_axes = [name for name in multiples
+ if labeled_tensor.axes[name].labels is not None]
+ if labeled_axes:
+ raise ValueError('cannot tile axes with tick labels: %r' % labeled_axes)
+
+ multiples_list = [multiples.get(name, 1) for name in labeled_tensor.axes]
+ tile_op = array_ops.tile(labeled_tensor.tensor, multiples_list, name=scope)
+
+ new_axes = [axis.name if axis.labels is None else axis
+ for axis in labeled_tensor.axes.values()]
+ return core.LabeledTensor(tile_op, new_axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike,
+ tc.Mapping(str, tc.Tuple(core.AxisValue, core.AxisValue)),
+ string_types, tc.Optional(string_types))
+def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
+ """Pads a tensor.
+
+ See tf.pad.
+
+ Args:
+ labeled_tensor: The input tensor.
+ paddings: A mapping where the keys are axis names and the values are
+ tuples where the first element is the padding to insert at the beginning
+ of the axis and the second is the padding to insert at the end of the
+ axis.
+ mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC".
+ name: Optional op name.
+
+ Returns:
+ A tensor with the indicated axes padded, optionally with those axes extended
+ with the provided labels.
+
+ Raises:
+ ValueError: If the padded axes are not axes in the input tensor.
+ """
+ with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()):
+ raise ValueError('pad axes %r are not contained in the set of axis '
+ 'names %r on the input labeled tensor' %
+ (paddings.keys(), labeled_tensor.axes))
+
+ new_axes = []
+ padding_pairs = []
+ for name, axis in labeled_tensor.axes.items():
+ if name in paddings:
+ padding_before, padding_after = paddings[name]
+ axis_before = core.Axis(name, padding_before)
+ axis_after = core.Axis(name, padding_after)
+ new_axes.append(core.concat_axes([axis_before, axis, axis_after]))
+ padding_pairs.append((len(axis_before), len(axis_after)))
+ else:
+ new_axes.append(axis)
+ padding_pairs.append((0, 0))
+
+ pad_op = array_ops.pad(
+ labeled_tensor.tensor, padding_pairs, mode, name=scope)
+
+ return core.LabeledTensor(pad_op, new_axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(tc.Union(np.ndarray, list, tuple, core.Scalar),
+ tc.Optional(dtypes.DType),
+ tc.Optional(tc.Union(
+ core.Axes,
+ tc.Collection(tc.Union(string_types, core.AxisLike)))),
+ tc.Optional(string_types))
+def constant(value, dtype=None, axes=None, name=None):
+ """Creates a constant tensor.
+
+ If `axes` includes any strings, shape is inferred from `value`. Otherwise,
+ the sizes of the given `axes` are used to set `shape` for `tf.constant`.
+
+ See tf.constant for more details.
+
+ Args:
+ value: The input tensor.
+ dtype: The type of the returned tensor.
+ axes: Optional Axes, list of strings or list of objects coercible to Axis
+ objects. By default, axes are assumed to be an empty list (i.e., `value`
+ is treated as a scalar).
+ name: Optional op name.
+
+ Returns:
+ The tensor with elements set to zero.
+ """
+ with ops.name_scope(name, 'lt_constant', [value]) as scope:
+
+ if axes is None:
+ axes = []
+
+ if isinstance(axes, core.Axes):
+ axes = axes.values()
+
+ if any(isinstance(ax, string_types) for ax in axes):
+ # need to infer shape
+ shape = None
+ else:
+ # axes already indicate shape
+ axes = [core.as_axis(a) for a in axes]
+ shape = [a.size for a in axes]
+
+ op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope)
+ return core.LabeledTensor(op, axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
+ tc.Optional(string_types))
+def zeros_like(labeled_tensor, dtype=None, name=None):
+ """Creates an identical tensor with all elements set to zero.
+
+ Args:
+ labeled_tensor: The input tensor.
+ dtype: The type of the returned tensor.
+ name: Optional op name.
+
+ Returns:
+ The tensor with elements set to zero.
+ """
+ with ops.name_scope(name, 'lt_zeros_like', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+ op = array_ops.zeros_like(labeled_tensor.tensor, dtype=dtype, name=scope)
+ return core.LabeledTensor(op, labeled_tensor.axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
+ tc.Optional(string_types))
+def ones_like(labeled_tensor, dtype=None, name=None):
+ """Creates an identical tensor with all elements set to one.
+
+ Args:
+ labeled_tensor: The input tensor.
+ dtype: The type of the returned tensor.
+ name: Optional op name.
+
+ Returns:
+ The tensor with elements set to one.
+ """
+ with ops.name_scope(name, 'lt_ones_like', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+ op = array_ops.ones_like(labeled_tensor.tensor, dtype=dtype, name=scope)
+ return core.LabeledTensor(op, labeled_tensor.axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
+ tc.Optional(string_types))
+def cast(labeled_tensor, dtype=None, name=None):
+ """Casts a labeled tensor to a new type.
+
+ Args:
+ labeled_tensor: The input tensor.
+ dtype: The type of the returned tensor.
+ name: Optional op name.
+
+ Returns:
+ A labeled tensor with the new dtype.
+ """
+ with ops.name_scope(name, 'lt_cast', [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+ op = math_ops.cast(labeled_tensor.tensor, dtype=dtype, name=scope)
+ return core.LabeledTensor(op, labeled_tensor.axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, string_types, tc.Optional(string_types))
+def verify_tensor_all_finite(labeled_tensor, message, name=None):
+ """Asserts a tensor doesn't contain NaNs or Infs.
+
+ See tf.verify_tensor_all_finite.
+
+ Args:
+ labeled_tensor: The input tensor.
+ message: Message to log on failure.
+ name: Optional op name.
+
+ Returns:
+ The input tensor.
+ """
+ with ops.name_scope(name, 'lt_verify_tensor_all_finite',
+ [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+ op = numerics.verify_tensor_all_finite(labeled_tensor.tensor,
+ msg=message,
+ name=scope)
+ return core.LabeledTensor(op, labeled_tensor.axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
+ tc.Optional(string_types))
+def boolean_mask(labeled_tensor, mask, name=None):
+ """Apply a boolean mask to a labeled tensor.
+
+ Unlike `tf.boolean_mask`, this currently only works on 1-dimensional masks.
+ The mask is applied to the first axis of `labeled_tensor`. Labels on the first
+ axis are removed, because True indices in `mask` may not be known dynamically.
+
+ Args:
+ labeled_tensor: The input tensor.
+ mask: The type of the returned tensor.
+ name: Optional op name.
+
+ Returns:
+ The masked labeled tensor.
+
+ Raises:
+ ValueError: if the first axis of the mask
+ """
+ with ops.name_scope(name, 'lt_boolean_mask', [labeled_tensor, mask]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+ mask = core.convert_to_labeled_tensor(mask)
+
+ if len(mask.axes) > 1:
+ raise NotImplementedError(
+ "LabeledTensor's boolean_mask currently only supports 1D masks")
+ mask_axis = list(mask.axes.values())[0]
+ lt_axis = list(labeled_tensor.axes.values())[0]
+ if mask_axis != lt_axis:
+ raise ValueError('the first axis of the labeled tensor and the mask '
+ 'are not equal:\n%r\n%r' % (lt_axis, mask_axis))
+ op = array_ops.boolean_mask(labeled_tensor.tensor, mask.tensor, name=scope)
+ # TODO(shoyer): attempt to infer labels for the masked values, by calling
+ # tf.contrib.util.constant_value on the mask?
+ axes = [lt_axis.name] + list(labeled_tensor.axes.values())[1:]
+ return core.LabeledTensor(op, axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
+ core.LabeledTensorLike, tc.Optional(string_types))
+def where(condition, x, y, name=None):
+ """Return elements from x or y depending on condition.
+
+ See `tf.where` for more details. This function currently only implements the
+ three argument version of where.
+
+ Args:
+ condition: LabeledTensor of type `bool`.
+ x: LabeledTensor for values where condition is true.
+ y: LabeledTensor for values where condition is false.
+ name: Optional op name.
+
+ Returns:
+ The labeled tensor with values according to condition.
+
+ Raises:
+ ValueError: if `x` and `y` have different axes, or if the axes of `x` do not
+ start with the axes of `condition`.
+ """
+ with ops.name_scope(name, 'lt_where', [condition, x, y]) as scope:
+ condition = core.convert_to_labeled_tensor(condition)
+ x = core.convert_to_labeled_tensor(x)
+ y = core.convert_to_labeled_tensor(y)
+
+ if not condition.axes == x.axes == y.axes:
+ raise ValueError('all inputs to `where` must have equal axes')
+
+ op = array_ops.where(condition.tensor, x.tensor, y.tensor, name=scope)
+ return core.LabeledTensor(op, x.axes)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
new file mode 100644
index 0000000000..c19fc09f93
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
@@ -0,0 +1,918 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import range # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.contrib.labeled_tensor.python.ops import ops
+from tensorflow.contrib.labeled_tensor.python.ops import test_util
+
+
+class Base(test_util.Base):
+
+ def setUp(self):
+ super(Base, self).setUp()
+
+ self.x_size = 7
+ self.channel_size = 3
+ self.z_size = 4
+ self.probs_size = 11
+
+ tensor = tf.range(0, self.x_size * self.channel_size * self.z_size *
+ self.probs_size)
+ tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size,
+ self.probs_size])
+ a0 = ('x', range(self.x_size))
+ a1 = ('channel', ['red', 'green', 'blue'])
+ a2 = 'z'
+ a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
+
+ self.tensor = tensor
+ self.a0 = a0
+ self.a1 = a1
+ self.a2 = a2
+ self.a2_resolved = ('z', self.z_size)
+ self.a3 = a3
+ self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
+
+ self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0})
+ self.x_probs_lt = ops.select(self.x_probs_lt, {'channel': 'red'})
+ self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3,
+ 'z': 0})
+
+
+class SelectTest(Base):
+
+ def test_name(self):
+ select_lt = ops.select(self.original_lt, {'channel': 'green'})
+ self.assertIn('lt_select', select_lt.name)
+
+ def test_scalar(self):
+ select_lt = ops.select(self.original_lt, {'channel': 'green'})
+ golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2,
+ self.a3])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_slice(self):
+ select_lt = ops.select(self.original_lt, {'channel': slice('red', 'green')})
+ a1_sliced = ('channel', ['red', 'green'])
+ golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
+ [self.a0, a1_sliced, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_slices(self):
+ select_lt = ops.select(self.original_lt, {'x': slice(1, 4),
+ 'channel': slice('green', None)})
+
+ a0_sliced = ('x', range(1, 5))
+ a1_sliced = ('channel', ['green', 'blue'])
+ golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
+ [a0_sliced, a1_sliced, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_list(self):
+ select_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
+ a1_sliced = ('channel', ['red', 'green'])
+ golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
+ [self.a0, a1_sliced, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_list_one_item(self):
+ select_lt = ops.select(self.original_lt, {'channel': ['red']})
+ a1_sliced = ('channel', ['red'])
+ golden_lt = core.LabeledTensor(self.tensor[:, :1, :, :],
+ [self.a0, a1_sliced, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_list_zero_items(self):
+ select_lt = ops.select(self.original_lt, {'channel': []})
+ golden_lt = core.LabeledTensor(self.tensor[:, :0, :, :],
+ [self.a0, 'channel', self.a2, self.a3])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_scalars(self):
+ select_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'})
+ golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :],
+ [self.a2, self.a3])
+ self.assertLabeledTensorsEqual(select_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ ops.select(self.original_lt, {'foo': 1})
+ with self.assertRaises(ValueError):
+ ops.select(self.original_lt, {'z': 1})
+ with self.assertRaises(KeyError):
+ ops.select(self.original_lt, {'channel': 'purple'})
+ with self.assertRaises(KeyError):
+ ops.select(self.original_lt, {'channel': ['red', 'purple']})
+ with self.assertRaises(NotImplementedError):
+ ops.select(self.original_lt, {'channel': ['red'], 'x': [1]})
+ with self.assertRaises(NotImplementedError):
+ ops.select(self.original_lt, {'channel': ['red'], 'x': 1})
+ with self.assertRaises(NotImplementedError):
+ ops.select(self.original_lt, {'channel': slice('red', 'green', 2)})
+
+
+class ConcatTest(Base):
+
+ def setUp(self):
+ super(ConcatTest, self).setUp()
+
+ self.red_lt = ops.select(self.original_lt, {'channel': ['red']})
+ self.green_lt = ops.select(self.original_lt, {'channel': ['green']})
+ self.blue_lt = ops.select(self.original_lt, {'channel': ['blue']})
+
+ def test_name(self):
+ concat_lt = ops.concat([self.red_lt, self.blue_lt], 'channel')
+ self.assertIn('lt_concat', concat_lt.name)
+
+ def test(self):
+ concat_lt = ops.concat([self.red_lt, self.green_lt], 'channel')
+ golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
+
+ self.assertLabeledTensorsEqual(concat_lt, golden_lt)
+
+ def test_transposed(self):
+ green_transposed = core.transpose(self.green_lt,
+ ['probs', 'channel', 'z', 'x'])
+ with self.assertRaises(ValueError):
+ ops.concat([self.red_lt, green_transposed], 'channel')
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ ops.concat([], 'channel')
+ with self.assertRaises(ValueError):
+ ops.concat([self.red_lt, self.red_lt], 'channel')
+ with self.assertRaises(ValueError):
+ ops.concat([self.red_lt, self.red_lt], 'foo')
+
+
+class PackTest(Base):
+
+ def test_name(self):
+ pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
+ self.assertIn('lt_pack', pack_lt.name)
+
+ def test(self):
+ pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
+ golden_lt = core.LabeledTensor(
+ tf.stack([self.original_lt.tensor, self.original_lt.tensor]),
+ ['batch', self.a0, self.a1, self.a2, self.a3])
+
+ self.assertLabeledTensorsEqual(pack_lt, golden_lt)
+
+ def test_axis(self):
+ pack_lt = ops.pack([self.original_lt, self.original_lt],
+ new_axis='batch',
+ axis_position=4)
+ golden_lt = core.LabeledTensor(
+ tf.stack(
+ [self.original_lt.tensor, self.original_lt.tensor], axis=4),
+ [self.a0, self.a1, self.a2, self.a3, 'batch'])
+
+ self.assertLabeledTensorsEqual(pack_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ ops.pack([self.original_lt, self.original_lt], 'channel')
+
+
+class UnpackTest(Base):
+
+ def test_name(self):
+ unpack_lts = ops.unpack(self.original_lt)
+ for t in unpack_lts:
+ self.assertIn('lt_unpack', t.name)
+
+ def test(self):
+ unpack_lt = ops.unpack(self.original_lt)[0]
+ golden_lt = core.LabeledTensor(
+ tf.unstack(self.original_lt.tensor)[0], [self.a1, self.a2, self.a3])
+
+ self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
+
+ def test_axis(self):
+ unpack_lt = ops.unpack(self.original_lt, axis_name='z')[0]
+ golden_lt = core.LabeledTensor(
+ tf.unstack(
+ self.original_lt.tensor, axis=2)[0], [self.a0, self.a1, self.a3])
+
+ self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ ops.unpack(self.original_lt, axis_name='not_found')
+
+
+class ReshapeTest(Base):
+
+ def test_name(self):
+ reshape_lt = ops.reshape(self.original_lt, ['channel'], ['foo'])
+ self.assertIn('lt_reshape', reshape_lt.name)
+
+ def test_identity(self):
+ reshape_lt = ops.reshape(self.original_lt, self.original_lt.axes.keys(),
+ self.original_lt.axes.values())
+ self.assertLabeledTensorsEqual(reshape_lt, self.original_lt)
+
+ def test_known_size(self):
+ new_dim_size = self.channel_size * self.z_size * self.probs_size
+ reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
+ [('new_dim', new_dim_size)])
+ golden_lt = core.LabeledTensor(
+ tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
+ [self.original_lt.axes['x'], 'new_dim'])
+ self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
+
+ def test_unknown_size(self):
+ reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
+ ['new_dim'])
+ golden_lt = core.LabeledTensor(
+ tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
+ [self.original_lt.axes['x'], 'new_dim'])
+ self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
+
+ def test_unknown_dimension(self):
+ orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x'])
+ reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])
+ self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)]))
+ with self.test_session() as sess:
+ result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]})
+ np.testing.assert_array_equal(result, [[1], [2]])
+
+ def test_with_labels(self):
+ new_dim_size = self.channel_size * self.z_size * self.probs_size
+ reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
+ [('new_dim', range(new_dim_size))])
+ golden_lt = core.LabeledTensor(
+ tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
+ [self.original_lt.axes['x'], ('new_dim', range(new_dim_size))])
+ self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
+ ops.reshape(self.original_lt, ['foo'], ['bar'])
+ with self.assertRaisesRegexp(core.AxisOrderError,
+ 'not a slice of axis names'):
+ ops.reshape(self.original_lt, ['probs', 'z'], ['bar'])
+ with self.assertRaisesRegexp(ValueError, 'at most one axis in new_axes'):
+ ops.reshape(self.original_lt, ['probs'], ['foo', 'bar'])
+
+
+class RenameAxisTest(Base):
+
+ def test_name(self):
+ rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
+ self.assertIn('lt_rename_axis', rename_axis_lt.name)
+
+ def test_identity(self):
+ rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'channel')
+ self.assertLabeledTensorsEqual(rename_axis_lt, self.original_lt)
+
+ def test_new_name(self):
+ rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
+ expected_axes = [(name if name != 'channel' else 'foo', axis.value)
+ for name, axis in self.original_lt.axes.items()]
+ expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes)
+ self.assertLabeledTensorsEqual(rename_axis_lt, expected_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
+ ops.rename_axis(self.original_lt, 'foo', 'bar')
+
+
+class BatchTest(Base):
+
+ def setUp(self):
+ super(BatchTest, self).setUp()
+
+ tensors = []
+ for i in range(10):
+ offset_lt = core.LabeledTensor(tf.constant(i), [])
+ tensors.append(core.add(self.original_lt, offset_lt))
+ self.pack_lt = ops.pack(tensors, 'batch')
+
+ def test_name(self):
+ batch_ops = ops.batch([self.pack_lt, self.pack_lt],
+ batch_size=2,
+ enqueue_many=True)
+ for bo in batch_ops:
+ self.assertIn('lt_batch', bo.name)
+
+ def test_enqueue_many(self):
+ [batch_2_op] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True)
+ self.assertEqual(len(batch_2_op.axes['batch']), 2)
+
+ [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
+
+ self.assertLabeledTensorsEqual(self.pack_lt, batch_10_op)
+
+ def test_no_enqueue_many(self):
+ [batch_2_op] = ops.batch([self.original_lt], batch_size=2)
+ self.assertEqual(len(batch_2_op.axes['batch']), 2)
+
+ [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
+
+ self.assertLabeledTensorsEqual(
+ ops.pack(10 * [self.original_lt], 'batch'), batch_10_op)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ ops.batch([self.original_lt], 3, enqueue_many=True)
+
+ def test_allow_smaller_final_batch(self):
+ [batch_2_op] = ops.batch([self.original_lt], batch_size=2,
+ allow_smaller_final_batch=True)
+ self.assertEqual(batch_2_op.axes['batch'].size, None)
+
+
+class ShuffleBatchTest(Base):
+
+ def setUp(self):
+ super(ShuffleBatchTest, self).setUp()
+
+ tensors = []
+ for i in range(10):
+ offset_lt = core.LabeledTensor(tf.constant(i), [])
+ tensors.append(core.add(self.original_lt, offset_lt))
+ self.pack_lt = ops.pack(tensors, 'batch')
+
+ def test_name(self):
+ batch_lts = ops.shuffle_batch([self.pack_lt, self.pack_lt],
+ batch_size=2,
+ enqueue_many=True)
+ for blt in batch_lts:
+ self.assertIn('lt_shuffle_batch', blt.name)
+
+ def test_enqueue_many(self):
+ [batch_2_lt] = ops.shuffle_batch([self.pack_lt],
+ batch_size=2,
+ enqueue_many=True,
+ min_after_dequeue=8,
+ seed=0)
+ self.assertEqual(len(batch_2_lt.axes['batch']), 2)
+
+ [batch_10_lt] = ops.batch([batch_2_lt], batch_size=10, enqueue_many=True)
+
+ self.assertEqual(batch_10_lt.axes, self.pack_lt.axes)
+ [batch_10, pack] = self.eval([batch_10_lt.tensor, self.pack_lt.tensor])
+ self.assertFalse((batch_10 == pack).all())
+
+ def test_allow_smaller_final_batch(self):
+ [batch_2_op] = ops.shuffle_batch([self.original_lt], batch_size=2,
+ allow_smaller_final_batch=True)
+ self.assertEqual(batch_2_op.axes['batch'].size, None)
+
+
+class RandomCropTest(Base):
+
+ def test_name(self):
+ crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
+ self.assertIn('lt_random_crop', crop_lt.name)
+
+ def test_single(self):
+ crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
+
+ self.assertEqual(
+ core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 3)]),
+ crop_lt.axes)
+
+ def test_double(self):
+ crop_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2})
+
+ self.assertEqual(
+ core.Axes([self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)]),
+ crop_lt.axes)
+
+ def test_size1(self):
+ crop_lt = ops.random_crop(self.original_lt, {'probs': 1})
+
+ self.assertEqual(
+ core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 1)]),
+ crop_lt.axes)
+
+ def test_different_seeds(self):
+ crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
+ 'channel': 2},
+ seed=0)
+ crop_1_lt = ops.random_crop(self.original_lt, {'probs': 3,
+ 'channel': 2},
+ seed=1)
+
+ self.assertEqual(crop_0_lt.axes, crop_1_lt.axes)
+ [crop_0, crop_1] = self.eval([crop_0_lt.tensor, crop_1_lt.tensor])
+ self.assertFalse((crop_0 == crop_1).all())
+
+ def test_identical_seeds(self):
+ crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
+ 'channel': 2},
+ seed=0)
+ crop_1_lt = ops.random_crop(self.original_lt, {'probs': 3,
+ 'channel': 2},
+ seed=0)
+
+ self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
+
+ def test_crop_idempotent(self):
+ crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
+ 'channel': 2},
+ seed=0)
+ crop_1_lt = ops.random_crop(crop_0_lt, {'probs': 3, 'channel': 2}, seed=1)
+
+ self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ ops.random_crop(self.original_lt, {'foobar': 2})
+
+
+class MapFnTest(Base):
+
+ def test_name(self):
+ map_lt = ops.map_fn(core.identity, self.original_lt)
+ self.assertIn('lt_map_fn', map_lt.name)
+
+ def test_identity(self):
+ map_lt = ops.map_fn(core.identity, self.original_lt)
+ self.assertLabeledTensorsEqual(map_lt, self.original_lt)
+
+ def test_callable_object(self):
+
+ class Identity(object):
+
+ def __call__(self, other):
+ return other
+
+ map_lt = ops.map_fn(Identity(), self.original_lt)
+ self.assertLabeledTensorsEqual(map_lt, self.original_lt)
+
+ def test_slice(self):
+ map_lt = ops.map_fn(lambda t: core.slice_function(t, {'channel': 1}),
+ self.original_lt)
+ slice_lt = core.slice_function(self.original_lt, {'channel': 1})
+ self.assertLabeledTensorsEqual(map_lt, slice_lt)
+
+
+class SqueezeTest(Base):
+
+ def setUp(self):
+ super(SqueezeTest, self).setUp()
+
+ self.squeezable_lt = core.slice_function(self.original_lt,
+ {'channel': slice(0, 1),
+ 'probs': slice(0, 1)})
+
+ def test_name(self):
+ squeeze_lt = ops.squeeze(self.squeezable_lt)
+ self.assertIn('lt_squeeze', squeeze_lt.name)
+
+ def test_none(self):
+ none_lt = ops.squeeze(self.squeezable_lt, None)
+ axes_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs'])
+ self.assertLabeledTensorsEqual(none_lt, axes_lt)
+
+ def test(self):
+ squeeze_lt = ops.squeeze(self.squeezable_lt, ['probs'])
+ golden_lt = core.slice_function(self.squeezable_lt, {'probs': 0})
+ self.assertLabeledTensorsEqual(squeeze_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ ops.squeeze(self.original_lt, ['channel'])
+ with self.assertRaises(ValueError):
+ ops.squeeze(self.squeezable_lt, ['foo'])
+
+
+class MatMulTest(Base):
+
+ def test_name(self):
+ x_lt = core.LabeledTensor(tf.ones((3,)), ['x'])
+ matmul_lt = ops.matmul(x_lt, x_lt)
+ self.assertIn('lt_matmul', matmul_lt.name)
+
+ def test_vector_vector(self):
+ x_lt = core.LabeledTensor(tf.range(3), ['x'])
+ matmul_lt = ops.matmul(x_lt, x_lt)
+ golden_lt = core.convert_to_labeled_tensor(5)
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ def test_matrix_vector(self):
+ xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
+ y_lt = core.LabeledTensor(tf.range(3), ['y'])
+
+ matmul_lt = ops.matmul(xy_lt, y_lt)
+ golden_lt = core.LabeledTensor(
+ tf.matmul(xy_lt.tensor, tf.reshape(y_lt.tensor, (-1, 1)))[:, 0], ['x'])
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ matmul_lt = ops.matmul(y_lt, xy_lt)
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ def test_matrix_matrix(self):
+ xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
+ yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z'])
+
+ matmul_lt = ops.matmul(xy_lt, yz_lt)
+ golden_lt = core.LabeledTensor(
+ tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ transpose = lambda x: core.transpose(x, list(x.axes.keys())[::-1])
+
+ matmul_lt = ops.matmul(xy_lt, transpose(yz_lt))
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ matmul_lt = ops.matmul(transpose(xy_lt), yz_lt)
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ matmul_lt = ops.matmul(transpose(xy_lt), transpose(yz_lt))
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ matmul_lt = ops.matmul(yz_lt, xy_lt)
+ self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt))
+
+ def test_matrix_matrix_axis_order(self):
+ xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
+ yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z'])
+
+ golden_lt = core.LabeledTensor(
+ tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
+
+ with core.axis_order_scope(['x', 'y', 'z']):
+
+ matmul_lt = ops.matmul(xy_lt, yz_lt)
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ matmul_lt = ops.matmul(yz_lt, xy_lt)
+ self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
+
+ def test_invalid(self):
+ scalar_lt = core.LabeledTensor(tf.ones(()), [])
+ x_lt = core.LabeledTensor(tf.ones((2,)), ['x'])
+ x2_lt = core.LabeledTensor(tf.ones((3,)), ['x'])
+ y_lt = core.LabeledTensor(tf.ones((3,)), ['y'])
+ xy_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'y'])
+ xyz_lt = core.LabeledTensor(tf.ones((2, 3, 1)), ['x', 'y', 'z'])
+
+ with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'):
+ ops.matmul(x_lt, scalar_lt)
+
+ with self.assertRaises(NotImplementedError):
+ ops.matmul(x_lt, xyz_lt)
+
+ with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'):
+ ops.matmul(x_lt, y_lt)
+
+ with self.assertRaises(NotImplementedError):
+ ops.matmul(xy_lt, xy_lt)
+
+ with self.assertRaisesRegexp(ValueError, 'does not match'):
+ ops.matmul(x_lt, x2_lt)
+
+
+class ReduceSumTest(Base):
+
+ def test_name(self):
+ sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
+ self.assertIn('lt_reduce_sum', sum_lt.name)
+
+ def test_drop_axis(self):
+ sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(sum_lt, golden_lt)
+
+ def test_drop_scalar_axis(self):
+ sum_lt = ops.reduce_sum(self.original_lt, 'channel')
+ golden_lt = core.LabeledTensor(
+ tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(sum_lt, golden_lt)
+
+ def test_keep_axis(self):
+ sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_sum(self.original_lt.tensor,
+ 1, keep_dims=True),
+ [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
+ self.assertLabeledTensorsEqual(sum_lt, golden_lt)
+
+ def test_keep_scalar_axis(self):
+ sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou'))
+ golden_lt = core.LabeledTensor(
+ tf.reduce_sum(self.original_lt.tensor,
+ 1, keep_dims=True),
+ [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
+ self.assertLabeledTensorsEqual(sum_lt, golden_lt)
+
+ def test_scalar(self):
+ scalar_lt = core.LabeledTensor(tf.constant(42), [])
+ reduce_lt = ops.reduce_sum(scalar_lt, [])
+ self.assertLabeledTensorsEqual(reduce_lt, scalar_lt)
+
+ def test_empty_list(self):
+ reduce_lt = ops.reduce_sum(self.original_lt, [])
+ self.assertLabeledTensorsEqual(reduce_lt, self.original_lt)
+
+ def test_none(self):
+ sum_lt = ops.reduce_sum(self.original_lt)
+ golden_lt = core.LabeledTensor(tf.reduce_sum(self.original_lt.tensor), [])
+ self.assertLabeledTensorsEqual(sum_lt, golden_lt)
+
+ def test_function_docstring_and_name(self):
+ self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__)
+ self.assertEqual('reduce_sum', ops.reduce_sum.__name__)
+
+
+class ReduceMeanTest(Base):
+
+ def test_name(self):
+ actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
+ self.assertIn('lt_reduce_mean', actual_lt.name)
+
+ def test(self):
+ actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_mean(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(actual_lt, golden_lt)
+
+
+class ReduceProdTest(Base):
+
+ def test_name(self):
+ result_lt = ops.reduce_prod(self.original_lt, {'channel'})
+ self.assertIn('lt_reduce_prod', result_lt.name)
+
+ def test(self):
+ result_lt = ops.reduce_prod(self.original_lt, {'channel'})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_prod(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(result_lt, golden_lt)
+
+
+class ReduceMinTest(Base):
+
+ def test_name(self):
+ result_lt = ops.reduce_min(self.original_lt, {'channel'})
+ self.assertIn('lt_reduce_min', result_lt.name)
+
+ def test(self):
+ result_lt = ops.reduce_min(self.original_lt, {'channel'})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_min(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(result_lt, golden_lt)
+
+
+class ReduceMaxTest(Base):
+
+ def test_name(self):
+ result_lt = ops.reduce_max(self.original_lt, {'channel'})
+ self.assertIn('lt_reduce_max', result_lt.name)
+
+ def test(self):
+ result_lt = ops.reduce_max(self.original_lt, {'channel'})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_max(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(result_lt, golden_lt)
+
+
+class BaseReduceBoolean(Base):
+
+ def setUp(self):
+ super(BaseReduceBoolean, self).setUp()
+ self.bool_tensor = tf.cast(self.original_lt.tensor > 5, tf.bool)
+ self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes)
+
+
+class ReduceAllTest(BaseReduceBoolean):
+
+ def test_name(self):
+ result_lt = ops.reduce_all(self.bool_lt, {'channel'})
+ self.assertIn('lt_reduce_all', result_lt.name)
+
+ def test(self):
+ result_lt = ops.reduce_all(self.bool_lt, {'channel'})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_all(self.bool_tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(result_lt, golden_lt)
+
+
+class ReduceAnyTest(BaseReduceBoolean):
+
+ def test_name(self):
+ result_lt = ops.reduce_any(self.bool_lt, {'channel'})
+ self.assertIn('lt_reduce_any', result_lt.name)
+
+ def test(self):
+ result_lt = ops.reduce_any(self.bool_lt, {'channel'})
+ golden_lt = core.LabeledTensor(
+ tf.reduce_any(self.bool_tensor, 1), [self.a0, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(result_lt, golden_lt)
+
+
+class TileTest(Base):
+
+ def test_name(self):
+ tile_lt = ops.tile(self.original_lt, {'z': 2})
+ self.assertIn('lt_tile', tile_lt.name)
+
+ def test(self):
+ for multiple in [2, tf.constant(2)]:
+ tile_lt = ops.tile(self.original_lt, {'z': multiple})
+ golden_op = tf.tile(self.original_lt.tensor, [1, 1, multiple, 1])
+ golden_axes = ['z' if axis.name == 'z' else axis
+ for axis in self.original_lt.axes.values()]
+ golden_lt = core.LabeledTensor(golden_op, golden_axes)
+ self.assertLabeledTensorsEqual(tile_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
+ ops.tile(self.original_lt, {'foo': 5})
+ with self.assertRaisesRegexp(ValueError, 'axes with tick labels'):
+ ops.tile(self.original_lt, {'x': 5})
+
+
+class PadTest(Base):
+
+ def test_name(self):
+ pad_lt = ops.pad(self.original_lt, {'x': (1, 1),
+ 'channel': ([], ['alpha'])})
+ self.assertIn('lt_pad', pad_lt.name)
+
+ def test(self):
+ pad_lt = ops.pad(self.original_lt, {'x': (1, 1),
+ 'channel': ([], ['alpha'])})
+
+ golden_op = tf.pad(self.original_lt.tensor, [[1, 1], [0, 1], [0, 0],
+ [0, 0]])
+ golden_axes = [('x', self.x_size + 2),
+ ('channel', ['red', 'green', 'blue', 'alpha']), self.a2,
+ self.a3]
+ golden_lt = core.LabeledTensor(golden_op, golden_axes)
+ self.assertLabeledTensorsEqual(pad_lt, golden_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
+ ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])})
+
+
+class ConstantTest(Base):
+
+ def test_name(self):
+ constant_lt = ops.constant(1)
+ self.assertIn('lt_constant', constant_lt.name)
+
+ def test_scalar(self):
+ constant_lt = ops.constant(1)
+ golden_lt = core.LabeledTensor(tf.constant(1), [])
+ self.assertLabeledTensorsEqual(constant_lt, golden_lt)
+
+ def test_infer_shape(self):
+ constant_lt = ops.constant([1, 2], axes=['x'])
+ golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x'])
+ self.assertLabeledTensorsEqual(constant_lt, golden_lt)
+
+ def test_specify_shape(self):
+ constant_lt = ops.constant(1, axes=[('x', 3)])
+ golden_lt = core.LabeledTensor(tf.constant(1, shape=(3,)), ['x'])
+ self.assertLabeledTensorsEqual(constant_lt, golden_lt)
+
+ def test_existing_axes(self):
+ golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x'])
+ constant_lt = ops.constant([1, 2], axes=golden_lt.axes)
+ self.assertLabeledTensorsEqual(constant_lt, golden_lt)
+
+
+class ZerosLikeTest(Base):
+
+ def test_name(self):
+ like_lt = ops.zeros_like(self.original_lt)
+ self.assertIn('lt_zeros_like', like_lt.name)
+
+ def test(self):
+ like_lt = ops.zeros_like(self.original_lt)
+ golden_lt = core.LabeledTensor(
+ tf.zeros_like(self.original_lt.tensor), self.original_lt.axes)
+ self.assertLabeledTensorsEqual(like_lt, golden_lt)
+
+
+class OnesLikeTest(Base):
+
+ def test_name(self):
+ like_lt = ops.ones_like(self.original_lt)
+ self.assertIn('lt_ones_like', like_lt.name)
+
+ def test(self):
+ like_lt = ops.ones_like(self.original_lt)
+ golden_lt = core.LabeledTensor(
+ tf.ones_like(self.original_lt.tensor), self.original_lt.axes)
+ self.assertLabeledTensorsEqual(like_lt, golden_lt)
+
+
+class CastTest(Base):
+
+ def test_name(self):
+ cast_lt = ops.cast(self.original_lt, tf.float16)
+ self.assertIn('lt_cast', cast_lt.name)
+
+ def test(self):
+ cast_lt = ops.cast(self.original_lt, tf.float16)
+ golden_lt = core.LabeledTensor(
+ tf.cast(self.original_lt.tensor, tf.float16), self.original_lt.axes)
+ self.assertLabeledTensorsEqual(cast_lt, golden_lt)
+
+
+class VerifyTensorAllFiniteTest(Base):
+
+ def setUp(self):
+ super(VerifyTensorAllFiniteTest, self).setUp()
+
+ self.finite_lt = core.LabeledTensor(tf.constant(42.0), [])
+ self.nan_lt = core.LabeledTensor(tf.constant(np.nan), [])
+
+ self.checked_finite_lt = ops.verify_tensor_all_finite(self.finite_lt, '')
+ self.checked_nan_lt = ops.verify_tensor_all_finite(self.nan_lt, '')
+
+ def test_name(self):
+ self.assertIn('lt_verify_tensor_all_finite', self.checked_finite_lt.name)
+ self.assertIn('lt_verify_tensor_all_finite', self.checked_nan_lt.name)
+
+ def test_finite(self):
+ self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt)
+
+ def test_nan(self):
+ with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+ 'Tensor had NaN values'):
+ self.eval([self.checked_nan_lt])
+
+
+class BooleanMaskTest(Base):
+
+ def test_name(self):
+ mask = core.LabeledTensor(tf.range(7) > 3, [self.a0])
+ masked_lt = ops.boolean_mask(self.original_lt, mask)
+ self.assertIn('lt_boolean_mask', masked_lt.name)
+
+ def test(self):
+ mask = core.LabeledTensor(tf.range(7) > 3, [self.a0])
+ masked_lt = ops.boolean_mask(self.original_lt, mask)
+ golden_lt = core.LabeledTensor(
+ tf.boolean_mask(self.original_lt.tensor, mask.tensor),
+ ['x', self.a1, self.a2, self.a3])
+ self.assertLabeledTensorsEqual(masked_lt, golden_lt)
+
+ def test_invalid_rank(self):
+ mask = core.LabeledTensor(tf.ones((7, 3)) > 3, [self.a0, self.a1])
+ with self.assertRaises(NotImplementedError):
+ ops.boolean_mask(self.original_lt, mask)
+
+ def test_mismatched_axis(self):
+ mask = core.LabeledTensor(tf.range(7) > 3, ['foo'])
+ with self.assertRaisesRegexp(ValueError, 'not equal'):
+ ops.boolean_mask(self.original_lt, mask)
+
+
+class WhereTest(Base):
+
+ def test_name(self):
+ condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
+ where_lt = ops.where(condition, condition, condition)
+ self.assertIn('lt_where', where_lt.name)
+
+ def test(self):
+ condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
+ x = core.LabeledTensor(tf.ones(5), ['x'])
+ y = core.LabeledTensor(tf.zeros(5), ['x'])
+ where_lt = ops.where(condition, x, y)
+
+ golden_lt = core.LabeledTensor(
+ tf.concat(0, [tf.ones(3), tf.zeros(2)]), ['x'])
+ self.assertLabeledTensorsEqual(where_lt, golden_lt)
+
+ def test_mismatched_axes(self):
+ condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
+ with self.assertRaisesRegexp(ValueError, 'equal axes'):
+ ops.where(condition, condition[:3], condition)
+ with self.assertRaisesRegexp(ValueError, 'equal axes'):
+ ops.where(condition, condition, condition[:3])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/sugar.py b/tensorflow/contrib/labeled_tensor/python/ops/sugar.py
new file mode 100644
index 0000000000..914493f473
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/sugar.py
@@ -0,0 +1,131 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tools to make it a bit easier to use LabeledTensor."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six import string_types
+
+from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.contrib.labeled_tensor.python.ops import ops
+from tensorflow.python.framework import ops as tf_ops
+
+
+class ReshapeCoder(object):
+ """Utility class for mapping to and from another shape.
+
+ For example, say you have a function `crop_center` which expects a
+ LabeledTensor with axes named ['batch', 'row', 'column', 'depth'], and
+ you have a LabeledTensor `masked_image_lt` with axes ['batch', 'row',
+ 'column', 'channel', 'mask'].
+
+ To call `crop_center` with `masked_image_lt` you'd normally have to write:
+
+ >>> reshape_lt = lt.reshape(masked_image_lt, ['channel', 'mask'], ['depth'])
+ >>> crop_lt = crop_center(reshape_lt)
+ >>> result_lt = lt.reshape(crop_lt, ['depth'],
+ ... [masked_image_lt.axes['channel'], masked_image_lt.axes['mask']])
+
+ ReshapeCoder takes care of this renaming logic for you, allowing you to
+ instead write:
+
+ >>> rc = ReshapeCoder(['channel', 'mask'], ['depth'])
+ >>> result_lt = rc.decode(crop_center(rc.encode(masked_image_lt)))
+
+ Here, `decode` restores the original axes 'channel' and 'mask', so
+ `crop_center` must not have modified the size of the 'depth' axis.
+ """
+
+ @tc.accepts(object, tc.Collection(str),
+ tc.Collection(tc.Union(str, core.AxisLike)), tc.Optional(str))
+ def __init__(self, existing_axis_names, new_axes, name=None):
+ self._name = name
+ self._existing_axis_names = existing_axis_names
+ self._new_axes = new_axes
+
+ self._existing_axes = None
+
+ @tc.returns(core.LabeledTensor)
+ @tc.accepts(object, core.LabeledTensorLike)
+ def encode(self, labeled_tensor):
+ """Reshape the input to the target shape.
+
+ If called several times, the axes named in existing_axis_names must be
+ identical.
+
+ Args:
+ labeled_tensor: The input tensor.
+
+ Returns:
+ The input reshaped to the target shape.
+
+ Raises:
+ ValueError: If the axes in existing_axis_names don't match the axes of
+ a tensor in a previous invocation of this method.
+ """
+ with tf_ops.name_scope(self._name, 'lt_reshape_encode',
+ [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ reshape_lt = ops.reshape(labeled_tensor,
+ self._existing_axis_names,
+ self._new_axes,
+ name=scope)
+
+ axes = [labeled_tensor.axes[n] for n in self._existing_axis_names]
+ if self._existing_axes is not None and self._existing_axes != axes:
+ raise ValueError(
+ 'input axes %r do not match axes from previous method call %r' %
+ (axes, self._existing_axes))
+ else:
+ self._existing_axes = axes
+
+ return reshape_lt
+
+ @tc.returns(core.LabeledTensor)
+ @tc.accepts(object, core.LabeledTensorLike)
+ def decode(self, labeled_tensor):
+ """Reshape the input to the original shape.
+
+ This is the inverse of encode.
+ Encode must have been called at least once prior to this method being
+ called.
+
+ Args:
+ labeled_tensor: The input tensor.
+
+ Returns:
+ The input reshaped to the original shape.
+
+ Raises:
+ ValueError: If this method was called before encode was called.
+ """
+ if self._existing_axes is None:
+ raise ValueError('decode called before encode')
+
+ with tf_ops.name_scope(self._name, 'lt_reshape_decode',
+ [labeled_tensor]) as scope:
+ labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+ new_axis_names = [axis if isinstance(axis, string_types) else
+ core.as_axis(axis).name for axis in self._new_axes]
+
+ return ops.reshape(labeled_tensor,
+ new_axis_names,
+ self._existing_axes,
+ name=scope)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py b/tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py
new file mode 100644
index 0000000000..3923f5a174
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py
@@ -0,0 +1,106 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import range # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.contrib.labeled_tensor.python.ops import core
+from tensorflow.contrib.labeled_tensor.python.ops import ops
+from tensorflow.contrib.labeled_tensor.python.ops import sugar
+from tensorflow.contrib.labeled_tensor.python.ops import test_util
+
+
+class Base(test_util.Base):
+
+ def setUp(self):
+ super(Base, self).setUp()
+
+ self.small_lt = core.LabeledTensor(tf.constant([1]), [('x', 1)])
+
+
+class ReshapeCoderTest(Base):
+
+ def setUp(self):
+ super(ReshapeCoderTest, self).setUp()
+
+ self.batch_size = 8
+ self.num_rows = 50
+ self.num_columns = 100
+ self.channels = ['red', 'green', 'blue']
+ self.masks = [False, True]
+
+ tensor = tf.range(0, self.batch_size * self.num_rows * self.num_columns *
+ len(self.channels) * len(self.masks))
+ tensor = tf.reshape(tensor, [self.batch_size, self.num_rows,
+ self.num_columns, len(self.channels),
+ len(self.masks)])
+
+ self.batch_axis = ('batch', range(self.batch_size))
+ self.row_axis = ('row', range(self.num_rows))
+ self.column_axis = ('column', range(self.num_columns))
+ self.channel_axis = ('channel', self.channels)
+ self.mask_axis = ('mask', self.masks)
+
+ axes = [self.batch_axis, self.row_axis, self.column_axis, self.channel_axis,
+ self.mask_axis]
+ self.masked_image_lt = core.LabeledTensor(tensor, axes)
+
+ def test_name(self):
+ rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
+ encode_lt = rc.encode(self.masked_image_lt)
+ decode_lt = rc.decode(encode_lt)
+ self.assertIn('lt_reshape_encode', encode_lt.name)
+ self.assertIn('lt_reshape_decode', decode_lt.name)
+
+ def test_bijection_flat(self):
+ rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
+
+ encode_lt = rc.encode(self.masked_image_lt)
+ golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis,
+ ('depth', len(self.channels) * len(self.masks))])
+ self.assertEqual(encode_lt.axes, golden_axes)
+
+ decode_lt = rc.decode(encode_lt)
+ self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
+
+ def test_bijection_with_labels(self):
+ depth_axis = core.Axis('depth', range(len(self.channels) * len(self.masks)))
+ rc = sugar.ReshapeCoder(['channel', 'mask'], [depth_axis,
+ ('other', ['label'])])
+
+ encode_lt = rc.encode(self.masked_image_lt)
+ golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis,
+ depth_axis, ('other', ['label'])])
+ self.assertEqual(encode_lt.axes, golden_axes)
+
+ decode_lt = rc.decode(encode_lt)
+ self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
+
+ def test_invalid_input(self):
+ with self.assertRaises(ValueError):
+ rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
+ rc.decode(self.masked_image_lt)
+ with self.assertRaises(ValueError):
+ rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
+ rc.encode(self.masked_image_lt)
+ rc.encode(ops.select(self.masked_image_lt, {'channel': 'red'}))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/test_util.py b/tensorflow/contrib/labeled_tensor/python/ops/test_util.py
new file mode 100644
index 0000000000..521314010e
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/test_util.py
@@ -0,0 +1,47 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utils for writing tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class Base(tf.test.TestCase):
+ """A class with some useful methods for testing."""
+
+ def eval(self, tensors):
+ with self.test_session() as sess:
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ try:
+ results = sess.run(tensors)
+ finally:
+ coord.request_stop()
+ coord.join(threads)
+
+ return results
+
+ def assertTensorsEqual(self, tensor_0, tensor_1):
+ [tensor_0_eval, tensor_1_eval] = self.eval([tensor_0, tensor_1])
+ self.assertAllEqual(tensor_0_eval, tensor_1_eval)
+
+ def assertLabeledTensorsEqual(self, tensor_0, tensor_1):
+ self.assertEqual(tensor_0.axes, tensor_1.axes)
+ self.assertTensorsEqual(tensor_0.tensor, tensor_1.tensor)