author | Stephan Hoyer <shoyer@google.com> | 2016-11-14 17:24:44 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-14 17:44:53 -0800
commit | 9d20f4ea4b0b5792bf88ef886d0143b7aa780522 (patch)
tree | 7007220d84d18a058a7c5ed02a695af728e15a3e /tensorflow/contrib/labeled_tensor
parent | 887892a499590fd24a052074d5d32ae9393e3a35 (diff)
Initial version of tf.contrib.labeled_tensor
Change: 139143754
Diffstat (limited to 'tensorflow/contrib/labeled_tensor')
-rw-r--r-- | tensorflow/contrib/labeled_tensor/BUILD | 166
-rw-r--r-- | tensorflow/contrib/labeled_tensor/README.md | 8
-rw-r--r-- | tensorflow/contrib/labeled_tensor/__init__.py | 139
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py | 322
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/core.py | 1197
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/core_test.py | 842
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/io_ops.py | 178
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py | 106
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/nn.py | 42
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/nn_test.py | 70
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/ops.py | 1207
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/ops_test.py | 918
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/sugar.py | 131
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py | 106
-rw-r--r-- | tensorflow/contrib/labeled_tensor/python/ops/test_util.py | 47
15 files changed, 5479 insertions, 0 deletions
diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD new file mode 100644 index 0000000000..82d9dc9a45 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -0,0 +1,166 @@ +# Description: +# Labels for TensorFlow. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +py_library( + name = "labeled_tensor", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + ":io_ops", + ":nn", + ":ops", + ":sugar", + ], +) + +py_library( + name = "_typecheck", + srcs = ["python/ops/_typecheck.py"], + srcs_version = "PY2AND3", + visibility = [":__subpackages__"], +) + +py_library( + name = "core", + srcs = ["python/ops/core.py"], + srcs_version = "PY2AND3", + deps = [ + ":_typecheck", + ], +) + +py_library( + name = "test_util", + srcs = ["python/ops/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + ":_typecheck", + ":core", + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "core_test", + size = "small", + srcs = [ + "python/ops/core_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":core", + ":test_util", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "io_ops", + srcs = ["python/ops/io_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + ], +) + +py_test( + name = "io_ops_test", + size = "small", + srcs = [ + "python/ops/io_ops_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":io_ops", + ":ops", + ":test_util", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "nn", + srcs = ["python/ops/nn.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + ], +) + +py_test( + name = "nn_test", + size = "small", + srcs = [ + "python/ops/nn_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":nn", + ":test_util", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "ops", + srcs = ["python/ops/ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + ], +) + +py_test( + name = "ops_test", + srcs = [ + "python/ops/ops_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":ops", + ":test_util", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "sugar", + srcs = ["python/ops/sugar.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + ":ops", + ], +) + +py_test( + name = "sugar_test", + size = "small", + srcs = [ + "python/ops/sugar_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":sugar", + ":test_util", + "//tensorflow:tensorflow_py", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/labeled_tensor/README.md b/tensorflow/contrib/labeled_tensor/README.md new file mode 100644 index 0000000000..50c6750fd0 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/README.md @@ -0,0 +1,8 @@ +# Labels for TensorFlow + +LabeledTensor is a library for adding semantically meaningful dimension and +coordinate labels to tensors in Tensorflow. + +Maintainers: +- Stephan Hoyer (shoyer@google.com, github.com/shoyer) +- Eric Christiansen (ericmc@google.com, github.com/emchristiansen) diff --git a/tensorflow/contrib/labeled_tensor/__init__.py b/tensorflow/contrib/labeled_tensor/__init__.py new file mode 100644 index 0000000000..75299a3a0e --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/__init__.py @@ -0,0 +1,139 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Labels for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.labeled_tensor.python.ops import core as _core +from tensorflow.contrib.labeled_tensor.python.ops import io_ops as _io_ops +from tensorflow.contrib.labeled_tensor.python.ops import nn +from tensorflow.contrib.labeled_tensor.python.ops import ops as _ops +from tensorflow.contrib.labeled_tensor.python.ops import sugar as _sugar + +# pylint: disable=invalid-name + +# Core types. +Axis = _core.Axis +Axes = _core.Axes +LabeledTensor = _core.LabeledTensor + +as_axis = _core.as_axis +convert_to_labeled_tensor = _core.convert_to_labeled_tensor + +identity = _core.identity +slice = _core.slice_function # pylint: disable=redefined-builtin +transpose = _core.transpose +expand_dims = _core.expand_dims +align = _core.align + +axis_order_scope = _core.axis_order_scope +check_axis_order = _core.check_axis_order +impose_axis_order = _core.impose_axis_order +AxisOrderError = _core.AxisOrderError + +define_unary_op = _core.define_unary_op +define_binary_op = _core.define_binary_op +define_reduce_op = _ops.define_reduce_op + +abs = _core.abs_function # pylint: disable=redefined-builtin +neg = _core.neg +sign = _core.sign +inv = _core.inv +square = _core.square +round = _core.round_function # pylint: disable=redefined-builtin +sqrt = _core.sqrt +rsqrt = _core.rsqrt +exp = _core.exp +log = _core.log +ceil = _core.ceil +floor = _core.floor +cos = _core.cos +sin = _core.sin +tan = _core.tan +acos = _core.acos +asin = _core.asin +atan = _core.atan +lgamma = _core.lgamma +digamma = _core.digamma +erf = _core.erf +erfc = _core.erfc +logical_not = _core.logical_not + +add = _core.add +sub = _core.sub +mul = _core.mul +div = _core.div +mod = _core.mod +pow = _core.pow_function # pylint: disable=redefined-builtin + +equal = _core.equal +greater = _core.greater +greater_equal = _core.greater_equal +not_equal = _core.not_equal +less = _core.less +less_equal = _core.less_equal +logical_and = _core.logical_and +logical_or = _core.logical_or +logical_xor = _core.logical_xor + +maximum = _core.maximum +minimum = _core.minimum +squared_difference = _core.squared_difference +igamma = _core.igamma +igammac = _core.igammac +zeta = _core.zeta +polygamma = _core.polygamma + +select = _ops.select +concat = _ops.concat +pack = _ops.pack +unpack = _ops.unpack +reshape = _ops.reshape +rename_axis = _ops.rename_axis +random_crop = _ops.random_crop +map_fn = _ops.map_fn +squeeze = _ops.squeeze +matmul = _ops.matmul +tile = _ops.tile +pad = _ops.pad +constant = _ops.constant +zeros_like = _ops.zeros_like +ones_like = _ops.ones_like +cast = _ops.cast +verify_tensor_all_finite = _ops.verify_tensor_all_finite +boolean_mask = _ops.boolean_mask +where = _ops.where + +reduce_all = _ops.reduce_all +reduce_any = _ops.reduce_any +reduce_logsumexp = 
_ops.reduce_logsumexp +reduce_max = _ops.reduce_max +reduce_mean = _ops.reduce_mean +reduce_min = _ops.reduce_min +reduce_prod = _ops.reduce_prod +reduce_sum = _ops.reduce_sum + +batch = _ops.batch +shuffle_batch = _ops.shuffle_batch + +FixedLenFeature = _io_ops.FixedLenFeature +parse_example = _io_ops.parse_example +parse_single_example = _io_ops.parse_single_example +placeholder = _io_ops.placeholder + +ReshapeCoder = _sugar.ReshapeCoder diff --git a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py new file mode 100644 index 0000000000..4a939cb22c --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py @@ -0,0 +1,322 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Minimal runtime type checking library. + +This module should not be considered public API. +""" +# TODO(ericmc,shoyer): Delete this in favor of using pytype or mypy +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import inspect +import re + + +# used for register_type_abbreviation and _type_repr below. +_TYPE_ABBREVIATIONS = {} + + +class Type(object): + """Base class for type checker types. + + The custom types defined in this module are based on types in the standard + library's typing module (in Python 3.5): + https://docs.python.org/3/library/typing.html + + The only difference should be that we use actual instances of Type classes to + represent custom types rather than the metaclass magic typing uses to create + new class objects. In practice, all this should mean is that we use + `List(int)` rather than `List[int]`. + + Custom types should implement __instancecheck__ and inherit from Type. Every + argument in the constructor must be a type or Type instance, and these + arguments must be stored as a tuple on the `_types` attribute. + """ + + def __init__(self, *types): + self._types = types + + def __repr__(self): + args_repr = ", ".join(repr(t) for t in self._types) + return "typecheck.%s(%s)" % (type(self).__name__, args_repr) + + +class _SingleArgumentType(Type): + """Use this subclass for parametric types that accept only one argument.""" + + def __init__(self, tpe): + super(_SingleArgumentType, self).__init__(tpe) + + @property + def _type(self): + tpe, = self._types # pylint: disable=unbalanced-tuple-unpacking + return tpe + + +class _TwoArgumentType(Type): + """Use this subclass for parametric types that accept two arguments.""" + + def __init__(self, first_type, second_type): + super(_TwoArgumentType, self).__init__(first_type, second_type) + + +class Union(Type): + """A sum type. + + A correct type is any of the types provided. 
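+
+  For example, a value matches `Union(int, float)` if it is an instance of
+  either type (a usage sketch, relying on the `__instancecheck__` hook
+  defined below):
+
+    >>> isinstance(1.5, Union(int, float))
+    True
+    >>> isinstance('a', Union(int, float))
+    False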
+ """ + + def __instancecheck__(self, instance): + return isinstance(instance, self._types) + + +class Optional(_SingleArgumentType): + """An optional type. + + A correct type is either the provided type or NoneType. + """ + + def __instancecheck__(self, instance): + # types.NoneType does not exist in Python 3 + return isinstance(instance, (self._type, type(None))) + + +class List(_SingleArgumentType): + """A typed list. + + A correct type is a list where each element has the single provided type. + """ + + def __instancecheck__(self, instance): + return (isinstance(instance, list) + and all(isinstance(x, self._type) for x in instance)) + + +class Sequence(_SingleArgumentType): + """A typed sequence. + + A correct type is a sequence where each element has the single provided type. + """ + + def __instancecheck__(self, instance): + return (isinstance(instance, collections.Sequence) + and all(isinstance(x, self._type) for x in instance)) + + +class Collection(_SingleArgumentType): + """A sized, iterable container. + + A correct type is an iterable and container with known size where each element + has the single provided type. + + We use this in preference to Iterable because we check each instance of the + iterable at runtime, and hence need to avoid iterables that could be + exhausted. + """ + + def __instancecheck__(self, instance): + return (isinstance(instance, collections.Iterable) + and isinstance(instance, collections.Sized) + and isinstance(instance, collections.Container) + and all(isinstance(x, self._type) for x in instance)) + + +class Tuple(Type): + """A typed tuple. + + A correct type is a tuple with the correct length where each element has + the correct type. + """ + + def __instancecheck__(self, instance): + return (isinstance(instance, tuple) + and len(instance) == len(self._types) + and all(isinstance(x, t) for x, t in zip(instance, self._types))) + + +class Mapping(_TwoArgumentType): + """A typed mapping. + + A correct type has the correct parametric types for keys and values. + """ + + def __instancecheck__(self, instance): + key_type, value_type = self._types # pylint: disable=unbalanced-tuple-unpacking + return (isinstance(instance, collections.Mapping) + and all(isinstance(k, key_type) for k in instance.keys()) + and all(isinstance(k, value_type) for k in instance.values())) + + +class Dict(Mapping): + """A typed dict. + + A correct type has the correct parametric types for keys and values. + """ + + def __instancecheck__(self, instance): + return (isinstance(instance, dict) + and super(Dict, self).__instancecheck__(instance)) + + +def _replace_forward_references(t, context): + """Replace forward references in the given type.""" + if isinstance(t, str): + return context[t] + elif isinstance(t, Type): + return type(t)(*[_replace_forward_references(t, context) for t in t._types]) # pylint: disable=protected-access + else: + return t + + +def register_type_abbreviation(name, alias): + """Register an abbreviation for a type in typecheck tracebacks. + + This makes otherwise very long typecheck errors much more readable. + + Example: + typecheck.register_type_abbreviation(tf.Dimension, 'tf.Dimension') + + Args: + name: type or class to abbreviate. + alias: string alias to substitute. 
+ """ + _TYPE_ABBREVIATIONS[name] = alias + + +def _type_repr(t): + """A more succinct repr for typecheck tracebacks.""" + string = repr(t) + for type_, alias in _TYPE_ABBREVIATIONS.items(): + string = string.replace(repr(type_), alias) + string = re.sub(r"<(class|type) '([\w.]+)'>", r"\2", string) + string = re.sub(r"typecheck\.(\w+)", r"\1", string) + return string + + +class Error(TypeError): + """Exception for typecheck failures.""" + + +def accepts(*types): + """A decorator which checks the input types of a function. + + Based on: + http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments + The above draws from: + https://www.python.org/dev/peps/pep-0318/ + + Args: + *types: A list of Python types. + + Returns: + A function to use as a decorator. + """ + + def check_accepts(f): + """Check the types.""" + spec = inspect.getargspec(f) + + num_function_arguments = len(spec.args) + if len(types) != num_function_arguments: + raise Error( + "Function %r has %d arguments but only %d types were provided in the " + "annotation." % (f, num_function_arguments, len(types))) + + if spec.defaults: + num_defaults = len(spec.defaults) + for (name, a, t) in zip(spec.args[-num_defaults:], + spec.defaults, + types[-num_defaults:]): + allowed_type = _replace_forward_references(t, f.__globals__) + if not isinstance(a, allowed_type): + raise Error("default argument value %r of type %r is not an instance " + "of the allowed type %s for the %s argument to %r" + % (a, type(a), _type_repr(allowed_type), name, f)) + + @functools.wraps(f) + def new_f(*args, **kwds): + """A helper function.""" + for (a, t) in zip(args, types): + allowed_type = _replace_forward_references(t, f.__globals__) + if not isinstance(a, allowed_type): + raise Error("%r of type %r is not an instance of the allowed type %s " + "for %r" % (a, type(a), _type_repr(allowed_type), f)) + return f(*args, **kwds) + + return new_f + + return check_accepts + + +def returns(*types): + """A decorator which checks the return types of a function. + + Based on: + http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments + The above draws from: + https://www.python.org/dev/peps/pep-0318/ + + Args: + *types: A list of Python types. + A list of one element corresponds to a single return value. + A list of several elements corresponds to several return values. + Note that a function with no explicit return value has an implicit + NoneType return and should be annotated correspondingly. + + Returns: + A function to use as a decorator. + """ + + def check_returns(f): + """Check the types.""" + if not types: + raise TypeError("A return type annotation must contain at least one type") + + @functools.wraps(f) + def new_f(*args, **kwds): + """A helper function.""" + return_value = f(*args, **kwds) + + if len(types) == 1: + # The function has a single return value. + allowed_type = _replace_forward_references(types[0], f.__globals__) + if not isinstance(return_value, allowed_type): + raise Error("%r of type %r is not an instance of the allowed type %s " + "for %r" + % (return_value, type(return_value), + _type_repr(allowed_type), f)) + + else: + if len(return_value) != len(types): + raise Error( + "Function %r has %d return values but only %d types were " + "provided in the annotation." 
% + (f, len(return_value), len(types))) + + for (r, t) in zip(return_value, types): + allowed_type = _replace_forward_references(t, f.__globals__) + if not isinstance(r, allowed_type): + raise Error("%r of type %r is not an instance of allowed type %s " + "for %r" % (r, type(r), _type_repr(allowed_type), f)) + + return return_value + + return new_f + + return check_returns diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py new file mode 100644 index 0000000000..69fd06133b --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py @@ -0,0 +1,1197 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Core classes and core ops for LabeledTensor. + +Core ops are ops which will eventually be called by LabeledTensor methods, +and ops which a core op depends upon. +For example, `add` is a core op because we'll eventually support the `+` +operator. +Non-core ops should go in `ops.py`. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import contextlib +import numbers +import types + +import numpy as np +from six import binary_type +from six import string_types +from six import text_type +from six.moves import range # pylint: disable=redefined-builtin + +from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +# pylint: disable=invalid-name + +# Types coercible to Axis.labels +# We use this instead of collections.Sequence to exclude strings. +LabelsLike = tc.Union(np.ndarray, range, list, tuple) + +# Types coercible to a tf.Dimension +DimensionLike = tc.Optional(tc.Union(tensor_shape.Dimension, int)) + +# Types usable for axis values +AxisValue = tc.Union(LabelsLike, DimensionLike) + +# Valid scalar values for TensorFlow +Scalar = tc.Union(numbers.Number, bool, binary_type, text_type) + +# pylint: enable=invalid-name + + +class Axis(object): + """Size and label information for an axis. + + Axis contains either a tf.Dimension indicating the size of an axis, + or a tuple of tick labels for the axis. + + If tick labels are provided, they must be unique. + """ + + @tc.accepts(object, string_types, AxisValue) + def __init__(self, name, value): + """Construct an Axis. + + Args: + name: Name of the axis. + value: Either None, an int or tf.Dimension giving the size of the axis, + or a sequence that is not a string additionally providing coordinate + (tick) labels. + + Raises: + ValueError: If the user provides labels with duplicate values. 
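+
+    For illustration, both labeled and unlabeled axes can be built this way
+    (a sketch based on the constructor logic below):
+
+      >>> Axis('x', 3)                           # unlabeled, size 3
+      Axis('x', Dimension(3))
+      >>> Axis('rgb', ['red', 'green', 'blue'])  # labeled by ticks
+      Axis('rgb', ('red', 'green', 'blue'))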
+ """ + if isinstance(value, tensor_shape.Dimension): + dimension = value + labels = None + elif isinstance(value, int) or value is None: + dimension = tensor_shape.Dimension(value) + labels = None + else: + dimension = tensor_shape.Dimension(len(value)) + labels = tuple(value) + + if dimension.value == 0: + # Treat a zero-length axis as if it has labels. + labels = () + + if labels is not None: + index = dict(zip(labels, range(len(labels)))) + if len(index) != len(labels): + raise ValueError('Tick labels must be unique, but got {}' + .format(labels)) + else: + index = None + + self._name = name # type: string_types + self._dimension = dimension # type: tensor_shape.Dimension + self._labels = labels # type: Optional[tuple] + self._index = index # type: Optional[Dict[Any, int]] + + @property + @tc.returns(string_types) + def name(self): + return self._name + + @tc.returns(string_types) + def __repr__(self): + # Axis('x', Dimension(2)) + # TODO(shoyer): make very long reprs more succint? + return "%s('%s', %r)" % (type(self).__name__, self.name, self.value) + + @tc.returns(bool) + def __eq__(self, other): + return (isinstance(other, Axis) and + self.name == other.name and + self.size == other.size and + self.labels == other.labels) + + def __hash__(self): + return hash((self.name, self.size, self.labels)) + + @tc.returns(bool) + def __ne__(self, other): + return not self == other + + @tc.returns(int) + def __len__(self): + size = self.size + if size is None: + raise ValueError('axis %r has unknown length' % self.name) + return size + + @property + @tc.returns(tc.Optional(tensor_shape.Dimension)) + def dimension(self): + return self._dimension + + @property + @tc.returns(tc.Optional(int)) + def size(self): + return self._dimension.value + + @property + @tc.returns(tc.Union(tuple, tensor_shape.Dimension)) + def value(self): + """Returns the tf.Dimension or tuple specifying axis ticks.""" + if self.labels is None: + return self.dimension + else: + return self.labels + + @property + @tc.returns(tc.Optional(tuple)) + def labels(self): + """Returns the tuple containing coordinate labels, else None.""" + return self._labels + + def index(self, value): + """Returns the integer position of the given tick label.""" + if self._index is None: + raise ValueError('Axis does not have tick labels') + return self._index[value] + + +# tc class for anything that can be coerced into an Axis +# pylint: disable=invalid-name +AxisLike = tc.Union(Axis, tc.Tuple(string_types, AxisValue)) +# pylint: enable=invalid-name + + +@tc.returns(Axis) +@tc.accepts(AxisLike) +def as_axis(axis_data): + """Convert an AxisLike object into an Axis. + + Args: + axis_data: Axis object or tuple (axis_name, axis_value) describing an axis. + + Returns: + Axis object. This may be the original object if axis_data is an Axis. + """ + if isinstance(axis_data, Axis): + axis = axis_data + else: + axis = Axis(*axis_data) + return axis + + +class Axes(collections.Mapping): + """Axis names and indices for a tensor. + + It is an ordered mapping, with keys given by axis name and values given + by Axis objets. Duplicate axis names are not allowed. + """ + + @tc.accepts(object, tc.List(AxisLike)) + def __init__(self, axes): + """Construct an Axes. + + Args: + axes: A list of Axis objects or (axis_name, axis_value) tuples. + + Raises: + ValueError: If the user provides empty or duplicate axis names. 
+ """ + self._axes = collections.OrderedDict() + + for axis_data in axes: + axis = as_axis(axis_data) + + name = axis.name + if name in self._axes: + raise ValueError('Duplicate axis name: %s' % name) + + self._axes[name] = axis + + def __iter__(self): + return iter(self._axes) + + @tc.returns(string_types) + def __repr__(self): + # Axes([('x', Dimension(2)), + # ('y', ['a', 'b', 'c']), + # ('z', Dimension(4))]) + cls_name = type(self).__name__ + values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()] + values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values) + return '%s([%s])' % (cls_name, values_repr) + + @tc.returns(Axis) + @tc.accepts(object, string_types) + def __getitem__(self, name): + return self._axes[name] + + @tc.returns(bool) + def __contains__(self, name): + return name in self._axes + + @tc.returns(int) + def __len__(self): + return len(self._axes) + + def __hash__(self): + return hash(tuple(self.items())) + + @tc.accepts(object, string_types) + def remove(self, axis_name): + """Creates a new Axes object without the given axis.""" + if axis_name not in self: + raise KeyError(axis_name) + remaining_axes = [axis for axis in self.values() if axis.name != axis_name] + return Axes(remaining_axes) + + +class LabeledTensor(object): + """A tensor with annotated axes. + + It has the following invariants: + 1) The dimensionality of the tensor is equal to the number of elements + in axes. + 2) The number of coordinate values in the ith dimension is equal to the + size of the tensor in the ith dimension. + + Attributes: + tensor: tf.Tensor containing the data. + axes: lt.Axes containing axis names and coordinate labels. + """ + + @tc.accepts(object, ops.Output, + tc.Union(Axes, tc.Collection(tc.Union(string_types, AxisLike)))) + def __init__(self, tensor, axes): + """Construct a LabeledTenor. + + Args: + tensor: The underlying tensor containing the data. + axes: An Axes object, or a collection of strings, Axis objects or tuples + of (name, value) pairs indicating the axes. + + Raises: + ValueError: If the provided axes do not satisfy the class invariants. + """ + self._tensor = tensor + shape = tensor.get_shape() + + if isinstance(axes, Axes): + unvalidated_axes = axes + else: + mutable_axes = [] + + for position, axis_like in enumerate(axes): + if isinstance(axis_like, string_types): + # The coordinates for this axes are unlabeled. + # Infer the size of the axis. + value = shape[position] + axis_like = (axis_like, value) + + mutable_axes.append(axis_like) + + # Construct the Axis object, which will additionally validate the contents + # of the object. + unvalidated_axes = Axes(mutable_axes) + + # Check our invariants. + + # First, the rank of the tensor must be equal to the number of axes. + if len(shape) != len(unvalidated_axes): + raise ValueError('Tensor rank was not equal to the number of axes: %r, %r' + % (shape, unvalidated_axes)) + + # Second, the size of each tensor dimension must match the size of the + # corresponding indices. 
+ for (d, axis) in zip(shape, unvalidated_axes.values()): + if d != axis.size: + raise ValueError( + 'Provided axis size %d does not match tensor dimension size %d' % + (axis.size, d)) + + self._axes = unvalidated_axes + + def __repr__(self): + # <LabeledTensor 'foo' shape=(2, 3, 4) dtype=float32 + # axes=[('x', Dimension(2)), + # ('y', ('a', 'b', 'c'), + # ('z', Dimension(4))]> + axes = ["('%s', %r)" % (v.name, v.value) for v in self.axes.values()] + axes_repr = (',\n' + ' ' * len(' axes=[')).join(axes) + return ("<%s '%s' shape=%s dtype=%s\n axes=[%s]>" % + (type(self).__name__, self.tensor.name, self.tensor.get_shape(), + self.tensor.dtype.name, axes_repr)) + + @property + def tensor(self): + return self._tensor + + def _as_graph_element(self): + """Support tf.Graph.as_graph_element on LabeledTensor objects. + + This allows operations such as tf.name_scope to take labeled tensors. + + Returns: + self.tensor + """ + return self.tensor + + @property + def axes(self): + return self._axes + + # properties/methods directly borrowed from tf.Tensor: + + @property + def dtype(self): + return self._tensor.dtype + + @property + def name(self): + return self._tensor.name + + def get_shape(self): + """Returns the TensorShape that represents the shape of this tensor. + + See tf.Tensor.get_shape(). + + Returns: + A TensorShape representing the shape of this tensor. + """ + return self._tensor.get_shape() + + # TODO(shoyer): consider how/if to implement .eval(). Maybe it should return + # an xarray.DataArray? + + def __getitem__(self, key): + # This should work exactly like tf.Tensor.__getitem__, except it preserves + # labels. + if not isinstance(key, tuple): + key = (key,) + if len(key) != len(self.axes): + raise ValueError('indexer %r must have the same length as the Tensor ' + 'rank (%r)' % (key, len(self.axes))) + selection = {a: k for a, k in zip(self.axes.keys(), key)} + return slice_function(self, selection) + + # special methods for overloading arithmetic operations: + + def __abs__(self): + return abs_function(self) + + def __neg__(self): + return neg(self) + + def __pos__(self): + return self + + def __add__(self, other): + return add(self, other) + + def __radd__(self, other): + return add(other, self) + + def __sub__(self, other): + return sub(self, other) + + def __rsub__(self, other): + return sub(other, self) + + def __mul__(self, other): + return mul(self, other) + + def __rmul__(self, other): + return mul(other, self) + + def __truediv__(self, other): + return div(self, other) + + __div__ = __truediv__ + + def __rtruediv__(self, other): + return div(other, self) + + __rdiv__ = __rtruediv__ + + def __mod__(self, other): + return mod(self, other) + + def __rmod__(self, other): + return mod(other, self) + + def __pow__(self, other): + return pow_function(self, other) + + def __rpow__(self, other): + return pow_function(other, self) + + # logical operations: + + def __invert__(self): + return logical_not(self) + + def __and__(self, other): + return logical_and(self, other) + + def __or__(self, other): + return logical_or(self, other) + + def __xor__(self, other): + return logical_xor(self, other) + + # boolean operations: + + def __lt__(self, other): + return less(self, other) + + def __le__(self, other): + return less_equal(self, other) + + def __gt__(self, other): + return greater(self, other) + + def __ge__(self, other): + return greater_equal(self, other) + + def __eq__(self, other): + # for consistency with tf.Tensor + if not isinstance(other, LabeledTensor): + return False + + 
return self.tensor == other.tensor and self.axes == other.axes + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((self.tensor, self.axes)) + + +# typecheck type abbreviations: +# abbreviations for third-party types with very long reprs +tc.register_type_abbreviation(tensor_shape.Dimension, 'tensorflow.Dimension') +tc.register_type_abbreviation(ops.Output, 'tensorflow.Output') +tc.register_type_abbreviation(dtypes.DType, 'tensorflow.DType') +# core LabeledTensor types +tc.register_type_abbreviation(Axis, 'labeled_tensor.Axis') +tc.register_type_abbreviation(Axes, 'labeled_tensor.Axes') +tc.register_type_abbreviation(LabeledTensor, 'labeled_tensor.LabeledTensor') + + +@tc.returns(ops.Output) +@tc.accepts(LabeledTensor) +def _convert_labeled_tensor_to_tensor(value, *args, **kwargs): + # call ops.convert_to_tensor to handle optional arguments appropriately + return ops.convert_to_tensor(value.tensor, *args, **kwargs) + + +ops.register_tensor_conversion_function( + LabeledTensor, _convert_labeled_tensor_to_tensor) + + +# tc class for anything that can be coerced into a LabeledTensor +# pylint: disable=invalid-name +LabeledTensorLike = tc.Union(LabeledTensor, ops.Output, np.ndarray, Scalar) +# pylint: enable=invalid-name + + +@tc.returns(LabeledTensor) +@tc.accepts(LabeledTensorLike, object, tc.Optional(string_types)) +def convert_to_labeled_tensor(value, dtype=None, name=None): + """Converts the given `value` to a `LabeledTensor`. + + This function accepts `LabeledTensor` objects, 0-dimensional `Tensor` objects + and numpy arrays, and Python scalars. Higher dimensional unlabeled tensors + must use the `LabeledTensor` constructor explicitly. + + Args: + value: Object to convert. + dtype: Optional element type for the returned tensor. If missing, the type + is inferred from the type of value. + name: Optional name to use if a new Tensor is created. + + Returns: + `value` converted into a `LabeledTensor` object. + + Raises: + ValueError: If the output would have rank>0 but the input was not already a + `LabeledTensor`. + """ + # TODO(shoyer): consider extending to accept xarray.DataArray as input. + if isinstance(value, LabeledTensor): + axes = value.axes.values() + value = value.tensor + else: + axes = [] + + # We call convert_to_tensor even for LabeledTensor input because it also + # checks to make sure the dtype argument is compatible. + tensor = ops.convert_to_tensor(value, dtype=dtype, name=name) + if len(tensor.get_shape()) != len(axes): + raise ValueError('cannot automatically convert unlabeled arrays or tensors ' + 'with rank>0 into LabeledTensors: %r' % value) + return LabeledTensor(tensor, axes) + + +@tc.returns(Axis) +@tc.accepts(tc.Collection(Axis)) +def concat_axes(axes): + """Concatenate a list of Axes. + + Args: + axes: A collection of Axis objects. + + Returns: + The concatenation of the axes. + If all axes have labels, the result has the concatenation of the labels. + Else, the result has no labels, and its size is the sum of the sizes + of the axes. + + Raises: + ValueError: If `others` is not a collection of Axes or if it is empty. 
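+
+  For example, concatenating axes that all carry labels concatenates the
+  labels (a sketch following the rules above):
+
+    >>> concat_axes([Axis('x', ['a', 'b']), Axis('x', ['c'])])
+    Axis('x', ('a', 'b', 'c'))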
+ """ + if not axes: + raise ValueError('axes must not be empty') + for a in axes: + if not isinstance(a, Axis): + raise ValueError('Expected an Axis, but got %r of type %r' % (a, type(a))) + + names = set(a.name for a in axes) + if len(names) > 1: + raise ValueError('axes do not all have the same name: %r' % names) + name, = names + + all_have_labels = all(a.labels is not None for a in axes) + any_has_unknown_size = any(a.size is None for a in axes) + + if all_have_labels: + value = tuple(label for a in axes for label in a.labels) + elif any_has_unknown_size: + value = None + else: + value = sum(len(a) for a in axes) + return Axis(name, value) + + +@tc.returns(LabeledTensor) +@tc.accepts(LabeledTensorLike, tc.Optional(string_types)) +def identity(labeled_tensor, name=None): + """The identity op. + + See tf.identity. + + Args: + labeled_tensor: The input tensor. + name: Optional op name. + + Returns: + The tensor. + """ + with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope: + labeled_tensor = convert_to_labeled_tensor(labeled_tensor) + return LabeledTensor( + array_ops.identity(labeled_tensor.tensor, name=scope), + labeled_tensor.axes) + + +# We don't call this slice because that shadows a built-in. Instead, we alias +# this to lt.slice in __init__.py. +@tc.returns(LabeledTensor) +@tc.accepts(LabeledTensorLike, tc.Mapping(string_types, tc.Union(int, slice)), + tc.Optional(string_types)) +def slice_function(labeled_tensor, selection, name=None): + """Slice out a subset of the tensor. + + This is an analogue of tf.slice. + For example: + >>> tensor = tf.reshape(tf.range(0, 6), [3, 2]) + >>> labeled_tensor = lt.LabeledTensor(tensor, ['a', ('b', ['foo', 'bar'])]) + >>> lt.slice(labeled_tensor, {'a': slice(0, 2), 'b': 1}) + <LabeledTensor 'lt_slice:...' shape=(2,) dtype=int32 + axes=[('a', Dimension(2))]> + + Args: + labeled_tensor: The input tensor. + selection: A dictionary of type str -> Union(int, slice of int) mapping + axis names to sub-selections. + name: Optional op name. + + Returns: + The slice as a `LabeledTensor`. + """ + with ops.name_scope(name, 'lt_slice', [labeled_tensor]) as scope: + labeled_tensor = convert_to_labeled_tensor(labeled_tensor) + + slices = [] + + for axis_name in labeled_tensor.axes: + if axis_name not in selection: + # We're not sub-selecting this axis, so use the full slice. + slices.append(slice(None)) + else: + slices.append(selection[axis_name]) + + sliced_tensor = labeled_tensor.tensor[tuple(slices)] + + sliced_axes = [] + for axis, s in zip(labeled_tensor.axes.values(), slices): + # We sub-select this axis's index with the slice s. + + # `s` is either an int or a proper slice. + if isinstance(s, slice): + if axis.labels is None: + # We're not tracking coordinate names for this axis. + sliced_axes.append(axis.name) + else: + sliced_axes.append((axis.name, axis.labels[s])) + else: + # If the slice is an int this dimension now has size 1, so we remove it. + assert isinstance(s, int) + + return LabeledTensor(array_ops.identity(sliced_tensor, name=scope), + sliced_axes) + + +@tc.returns(LabeledTensor) +@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)), + tc.Optional(string_types)) +def transpose(labeled_tensor, axis_order=None, name=None): + """Permute a tensor's axes. + + See tf.transpose. + + Args: + labeled_tensor: The input tensor. + axis_order: Optional desired axis order, as a list of names. By default, the + order of axes is reversed. + name: Optional op name. + + Returns: + The permuted tensor. 
+ + Raises: + ValueError: If axis_order isn't a permutation of the existing axes. + """ + with ops.name_scope(name, 'lt_transpose', [labeled_tensor]) as scope: + labeled_tensor = convert_to_labeled_tensor(labeled_tensor) + + original_order = list(labeled_tensor.axes.keys()) + if axis_order is None: + axis_order = list(reversed(original_order)) + elif sorted(axis_order) != sorted(original_order): + raise ValueError( + 'The new axis order must have the same names as the original axes, ' + 'but the new order is %r while the original order is %r' % + (axis_order, original_order)) + + axis_names = list(labeled_tensor.axes.keys()) + permutation = [axis_names.index(n) for n in axis_order] + + # Note: TensorFlow doesn't copy data for the identity tranpose. + transpose_tensor = array_ops.transpose(labeled_tensor.tensor, + permutation, + name=scope) + + permuted_axes = [labeled_tensor.axes[n] for n in axis_order] + + return LabeledTensor(transpose_tensor, permuted_axes) + + +@tc.returns(LabeledTensor) +@tc.accepts(LabeledTensorLike, tc.Collection(tc.Union(string_types, tc.Tuple( + string_types, collections.Hashable))), tc.Optional(string_types)) +def expand_dims(labeled_tensor, axes, name=None): + """Insert dimensions of size 1. + + See tf.expand_dims. + + Args: + labeled_tensor: The input tensor. + axes: The desired axis names as strings or tuples of (name, label), + where `label` is the coordinate name for the new dimension `name`. + These must include the existing axis names, and the existing names must + appear in the same order in this list as they do in the input tensor. + name: Optional op name. + + Returns: + A tensor with an axis for each axis in axes. + New axes are created with size 1 and do not have labeled coordinates. + + Raises: + AxisOrderError: If axis names don't appear in the same order in axes + and the labeled tensor. + """ + with ops.name_scope(name, 'lt_expand_dims', [labeled_tensor]) as scope: + labeled_tensor = convert_to_labeled_tensor(labeled_tensor) + + axis_names = [a if isinstance(a, string_types) else a[0] for a in axes] + check_axis_order(labeled_tensor, axis_names) + + reshaped_axes = [] + shape = [] + for axis_spec in axes: + if axis_spec in labeled_tensor.axes: + axis = labeled_tensor.axes[axis_spec] + reshaped_axes.append(axis) + shape.append(-1 if axis.size is None else axis.size) + else: + if isinstance(axis_spec, string_types): + reshaped_axes.append((axis_spec, 1)) + else: + (name, label) = axis_spec + reshaped_axes.append((name, (label,))) + + shape.append(1) + + reshaped_tensor = array_ops.reshape(labeled_tensor.tensor, shape, + name=scope) + + return LabeledTensor(reshaped_tensor, reshaped_axes) + +# This should only be added to a graph collection once. +_AXIS_ORDER_KEY = ('__axis_order',) + + +@tc.returns(tc.Optional(tc.List(string_types))) +def get_axis_order(): + """Get the axis_order set by any containing axis_order_scope. + + Returns: + List of strings giving an order to use for axis names, or None, if no axis + order is set. + """ + # By storing axis_order in the graph, we can ensure that axis_order_scope is + # thread-safe. 
+ axis_order_list = ops.get_collection(_AXIS_ORDER_KEY) + if axis_order_list: + axis_order, = axis_order_list + else: + axis_order = None + return axis_order + + +@tc.accepts(tc.Optional(tc.List(string_types))) +def _set_axis_order(axis_order): + axis_order_list = ops.get_collection_ref(_AXIS_ORDER_KEY) + if axis_order_list: + axis_order_list[0] = axis_order + else: + axis_order_list.append(axis_order) + + +@contextlib.contextmanager +@tc.accepts(tc.Optional(tc.List(string_types))) +def axis_order_scope(axis_order=None): + """Set axis order for the result of broadcasting operations within a scope. + + This allows you to ensure that tensors resulting from arithmetic have a + predictable axis order. + + Example usage: + + with lt.axis_order_scope(['x', 'y', 'z']): + # result is guranteed to have the correct axis order + result = w + b + + You can nest scopes, in which case only the inner-most scope applies, e.g., + + with lt.axis_order(['x', 'y', 'z']): + with lt.axis_order(): + result = w + b # uses the default (left-most) axis ordering + + Args: + axis_order: optional list of strings providing axis names. By default, + creates a scope without axis order. + + Yields: + The provided axis_order or `None`. + """ + original_axis_order = get_axis_order() + _set_axis_order(axis_order) + try: + yield axis_order + finally: + _set_axis_order(original_axis_order) + + +@tc.returns(tc.List(string_types)) +def _get_valid_axis_order(): + axis_order = get_axis_order() + if axis_order is None: + raise AxisOrderError('an explicit axis order must be provided with the ' + 'axis_order argument or by using an axis_order_scope') + return axis_order + + +class AxisOrderError(ValueError): + """Error class for cases where there is no valid axis order.""" + + +# TODO(shoyer): should this function accept a list of labeled tensors instead? +@tc.returns(type(None)) +@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types))) +def check_axis_order(labeled_tensor, axis_order=None): + """Verify that the given tensor has a consistent axis order. + + Args: + labeled_tensor: The input tensor. All axes on this tensor must appear in + axis_order. + axis_order: Optional desired axis order, as a list of names. If not + provided, defaults to the current axis_order_scope (if set). + + Raises: + AxisOrderError: If the axis_order is unavailable, inconsistent or does not + include all existing axes. + """ + labeled_tensor = convert_to_labeled_tensor(labeled_tensor) + + if axis_order is None: + axis_order = _get_valid_axis_order() + + relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes] + + if len(relevant_axis_order) < len(labeled_tensor.axes): + raise AxisOrderError( + 'not all axis names appear in the required axis order %r: %r' % + (axis_order, labeled_tensor)) + + if relevant_axis_order != list(labeled_tensor.axes): + raise AxisOrderError( + 'axes on a labeled tensor do not appear in the same order as the ' + 'required axis order %r: %r' % (axis_order, labeled_tensor)) + + +@tc.returns(LabeledTensor) +@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)), + tc.Optional(string_types)) +def impose_axis_order(labeled_tensor, axis_order=None, name=None): + """Impose desired axis order on a labeled tensor. + + Args: + labeled_tensor: The input tensor. + axis_order: Optional desired axis order, as a list of names. If not + provided, defaults to the current axis_order_scope (if set). + name: Optional op name. + + Returns: + Labeled tensor with possibly transposed axes. 
+ + Raises: + AxisOrderError: If no axis_order is provided or axis_order does not contain + all axes on the input tensor. + """ + with ops.name_scope(name, 'lt_impose_axis_order', [labeled_tensor]) as scope: + labeled_tensor = convert_to_labeled_tensor(labeled_tensor) + + if axis_order is None: + axis_order = _get_valid_axis_order() + + relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes] + + return transpose(labeled_tensor, relevant_axis_order, name=scope) + + +@tc.returns(tc.Optional(list)) +@tc.accepts(list, list) +def _find_consistent_ordering(a, b): + """Find the left-most consistent ordering between two lists of unique items. + + A consistent ordering combines all elements in both a and b while keeping all + elements in their original order in both inputs. The left-most consistent + ordering orders elements from `a` not found in `b` before elements in `b` not + found in `a`. + + For example, given ['x', 'z'] and ['y', 'z'], both ['x', 'y', 'z'] and ['y', + 'x', 'z'] are consistent orderings because each of the inputs appears in + each consistent ordering in the same order, and ['x', 'y', 'z'] is the + left-most, because 'x' appears only in `a` and 'y' appears only in `b`. In + contrast, there is no consistent ordering between ['x', 'y'] and ['y', 'x']. + + Args: + a: list with unique elements. + b: list with unique elements. + + Returns: + List containing all elements in either a or b, or None, if no consistent + ordering exists. + """ + a_set = set(a) + b_set = set(b) + i = 0 + j = 0 + ordering = [] + while i < len(a) and j < len(b): + if a[i] not in b_set: + ordering.append(a[i]) + i += 1 + elif b[j] not in a_set: + ordering.append(b[j]) + j += 1 + elif a[i] == b[j]: + ordering.append(a[i]) + i += 1 + j += 1 + else: + return None + + ordering.extend(a[i:]) + ordering.extend(b[j:]) + + return ordering + + +@tc.returns(LabeledTensor, LabeledTensor, Axes) +@tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types)) +def align(labeled_tensor_0, labeled_tensor_1, name=None): + """Align the axes of two tensors so they may be broadcast to each other. + + Axes are ordered by the current axis order scope, if present, or by the left- + most consistent ordering. An exception is raised if it is impossible to align + the tensors without a transpose (align never copies the input data). + + Example usage: + + >>> a = lt.LabeledTensor(tf.ones((2, 4)), ['x', 'z']) + >>> b = lt.LabeledTensor(tf.ones((3, 4)), ['y', 'z']) + >>> a2, b2, axes = lt.align(a, b) + >>> a2 + <LabeledTensor 'lt_align_1/lt_align_1/0:...' shape=(2, 1, 4) dtype=float32 + axes=[('x', Dimension(2)), + ('y', Dimension(1)), + ('z', Dimension(4))]> + >>> b2 + <LabeledTensor 'lt_align_1/lt_align_1/1:...' shape=(1, 3, 4) dtype=float32 + axes=[('x', Dimension(1)), + ('y', Dimension(3)), + ('z', Dimension(4))]> + >>> axes + Axes([('x', Dimension(2)), + ('y', Dimension(3)), + ('z', Dimension(4))]) + + Args: + labeled_tensor_0: An input tensor. + labeled_tensor_1: An input tensor. + name: Optional op name. + + Returns: + The aligned tensors and the axes the resulting tensor would have if the two + aligned tensors were broadcast to each other. The aligned tensors have the + same rank but not necessarily the same shape, with axes in the same order. + + Raises: + ValueError: If axes with the same name on the inputs are not equal. + AxisOrderError: If there is no way to reshape the input tensors into the + output without a transpose. 
+ """ + with ops.name_scope(name, 'lt_align', + [labeled_tensor_0, labeled_tensor_1]) as scope: + + labeled_tensor_0 = convert_to_labeled_tensor(labeled_tensor_0) + labeled_tensor_1 = convert_to_labeled_tensor(labeled_tensor_1) + + axes_0 = labeled_tensor_0.axes + axes_1 = labeled_tensor_1.axes + for axis_name in axes_0: + if axis_name in axes_1: + if axes_0[axis_name] != axes_1[axis_name]: + raise ValueError('Mismatched %r axis on input tensors: %r and %r' % + (axis_name, axes_0[axis_name], axes_1[axis_name])) + + axis_scope_order = get_axis_order() + if axis_scope_order is not None: + # we are in an axis_order_scope + axis_names_set = set(axes_0) | set(axes_1) + new_axis_names = [a for a in axis_scope_order if a in axis_names_set] + + check_axis_order(labeled_tensor_0, axis_scope_order) + check_axis_order(labeled_tensor_1, axis_scope_order) + + else: + # attempt to find a consistent ordering + new_axis_names = _find_consistent_ordering(list(axes_0), list(axes_1)) + if new_axis_names is None: + raise AxisOrderError( + 'No consistent axis order allows for aligning tensors with axis ' + 'orders %r and %r without copying data. Use transpose or ' + 'impose_axis_order to reorder axes on one of more of the inputs.' % + (axes_0.keys(), axes_1.keys())) + + labeled_tensor_0 = expand_dims(labeled_tensor_0, + new_axis_names, + name=scope + '0') + labeled_tensor_1 = expand_dims(labeled_tensor_1, + new_axis_names, + name=scope + '1') + + broadcast_axes = [] + for axis_name in new_axis_names: + if axis_name in axes_0: + broadcast_axes.append(axes_0[axis_name]) + else: + broadcast_axes.append(axes_1[axis_name]) + + return labeled_tensor_0, labeled_tensor_1, Axes(broadcast_axes) + + +@tc.returns(types.FunctionType) +@tc.accepts(string_types, collections.Callable) +def define_unary_op(op_name, elementwise_function): + """Define a unary operation for labeled tensors. + + Args: + op_name: string name of the TensorFlow op. + elementwise_function: function to call to evaluate the op on a single + tf.Tensor object. This function must accept two arguments: a tf.Tensor + object, and an optional `name`. + + Returns: + Function defining the given op that acts on LabeledTensors. + """ + + default_name = 'lt_%s' % op_name + + @tc.returns(LabeledTensor) + @tc.accepts(LabeledTensorLike, tc.Optional(string_types)) + def op(labeled_tensor, name=None): + """LabeledTensor version of `tf.{op_name}`. + + See `tf.{op_name}` for full details. + + Args: + labeled_tensor: Input tensor. + name: Optional op name. + + Returns: + A LabeledTensor with result of applying `tf.{op_name}` elementwise. 
+ """ + with ops.name_scope(name, default_name, [labeled_tensor]) as scope: + labeled_tensor = convert_to_labeled_tensor(labeled_tensor) + result_tensor = elementwise_function(labeled_tensor.tensor, name=scope) + return LabeledTensor(result_tensor, labeled_tensor.axes) + + op.__doc__ = op.__doc__.format(op_name=op_name) + op.__name__ = op_name + + return op + + +abs_function = define_unary_op('abs', math_ops.abs) +neg = define_unary_op('neg', math_ops.neg) +sign = define_unary_op('sign', math_ops.sign) +inv = define_unary_op('inv', math_ops.inv) +square = define_unary_op('square', math_ops.square) +round_function = define_unary_op('round', math_ops.round) +sqrt = define_unary_op('sqrt', math_ops.sqrt) +rsqrt = define_unary_op('rsqrt', math_ops.rsqrt) +exp = define_unary_op('exp', math_ops.exp) +log = define_unary_op('log', math_ops.log) +ceil = define_unary_op('ceil', math_ops.ceil) +floor = define_unary_op('floor', math_ops.floor) +cos = define_unary_op('cos', math_ops.cos) +sin = define_unary_op('sin', math_ops.sin) +tan = define_unary_op('tan', math_ops.tan) +acos = define_unary_op('acos', math_ops.acos) +asin = define_unary_op('asin', math_ops.asin) +atan = define_unary_op('atan', math_ops.atan) +lgamma = define_unary_op('lgamma', math_ops.lgamma) +digamma = define_unary_op('digamma', math_ops.digamma) +erf = define_unary_op('erf', math_ops.erf) +erfc = define_unary_op('erfc', math_ops.erfc) +logical_not = define_unary_op('logical_not', math_ops.logical_not) +tanh = define_unary_op('tanh', math_ops.tanh) +sigmoid = define_unary_op('sigmoid', math_ops.sigmoid) + + +@tc.returns(types.FunctionType) +@tc.accepts(string_types, collections.Callable) +def define_binary_op(op_name, elementwise_function): + """Define a binary operation that broadcasts labeled tensors. + + Args: + op_name: string name of the TensorFlow op. + elementwise_function: function to call to evaluate the op on tf.Tensor + objects. This function must accept three arguments: two tf.Tensor objects, + and an optional `name`. + + Returns: + Function defining the given op that acts on LabeledTensors. + """ + + default_name = 'lt_%s' % op_name + + @tc.returns(LabeledTensor) + @tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types)) + def op(labeled_tensor_0, labeled_tensor_1, name=None): + """LabeledTensor version of `tf.{op_name}` with label based alignment. + + See `tf.{op_name}` for full details. + + Args: + labeled_tensor_0: Input tensor. + labeled_tensor_1: Input tensor. + name: Optional op name. + + Returns: + A LabeledTensor with result of applying `tf.{op_name}` elementwise. 
+ """ + with ops.name_scope(name, default_name, + [labeled_tensor_0, labeled_tensor_1]) as scope: + + align_0, align_1, broadcast_axes = align(labeled_tensor_0, + labeled_tensor_1) + + tensor = elementwise_function(align_0.tensor, align_1.tensor, name=scope) + + return LabeledTensor(tensor, broadcast_axes) + + op.__doc__ = op.__doc__.format(op_name=op_name) + op.__name__ = op_name + + return op + + +add = define_binary_op('add', math_ops.add) +sub = define_binary_op('sub', math_ops.sub) +mul = define_binary_op('mul', math_ops.mul) +div = define_binary_op('div', math_ops.div) +mod = define_binary_op('mod', math_ops.mod) +pow_function = define_binary_op('pow', math_ops.pow) + +equal = define_binary_op('equal', math_ops.equal) +greater = define_binary_op('greater', math_ops.greater) +greater_equal = define_binary_op('greater_equal', math_ops.greater_equal) +not_equal = define_binary_op('not_equal', math_ops.not_equal) +less = define_binary_op('less', math_ops.less) +less_equal = define_binary_op('less_equal', math_ops.less_equal) +logical_and = define_binary_op('logical_and', math_ops.logical_and) +logical_or = define_binary_op('logical_or', math_ops.logical_or) +logical_xor = define_binary_op('logical_xor', math_ops.logical_xor) + +maximum = define_binary_op('maximum', math_ops.maximum) +minimum = define_binary_op('minimum', math_ops.minimum) +squared_difference = define_binary_op( + 'squared_difference', math_ops.squared_difference) +igamma = define_binary_op('igamma', math_ops.igamma) +igammac = define_binary_op('igammac', math_ops.igammac) +zeta = define_binary_op('zeta', math_ops.zeta) +polygamma = define_binary_op('polygamma', math_ops.polygamma) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py new file mode 100644 index 0000000000..5710dc34e8 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py @@ -0,0 +1,842 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import operator +import re +import textwrap + +import numpy as np +from six.moves import range # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.contrib.labeled_tensor.python.ops import test_util + + +class AxisTest(tf.test.TestCase): + + def setUp(self): + d_7 = tf.Dimension(7) + p_rgb = ['red', 'green', 'blue'] + + self.i_7 = core.Axis('7', d_7) + self.i_7p = core.Axis('7prime', d_7) + self.i_rgb = core.Axis('rgb', p_rgb) + self.i_range = core.Axis('range', range(7)) + self.i_unknown = core.Axis('unknown', None) + + def test_equality(self): + + axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown] + for i, axis_0 in enumerate(axes): + for j, axis_1 in enumerate(axes): + if i == j: + self.assertEqual(axis_0, axis_1) + else: + self.assertNotEqual(axis_0, axis_1) + + def test_axis_value(self): + self.assertEqual(self.i_7.value, tf.Dimension(7)) + self.assertTrue(self.i_range.value == tuple(range(7))) + + def test_axis_input(self): + axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown] + for axis in axes: + self.assertEqual(axis, core.Axis(axis.name, axis.value)) + + def test_axis_value_input(self): + axis = self.i_range + for value in [range(7), list(range(7)), np.arange(7)]: + self.assertEqual(axis, core.Axis(axis.name, value)) + + def test_size(self): + self.assertEqual(len(self.i_7), 7) + self.assertEqual(len(self.i_rgb), 3) + self.assertEqual(len(self.i_range), 7) + self.assertEqual(self.i_unknown.size, None) + + def test_concat_single(self): + red = core.Axis('rgb', ['red']) + + self.assertEqual(core.concat_axes([red]), red) + + def test_concat_many(self): + red = core.Axis('rgb', ['red']) + green = core.Axis('rgb', ['green']) + blue = core.Axis('rgb', ['blue']) + red_green_blue = core.Axis('rgb', ['red', 'green', 'blue']) + + self.assertEqual(core.concat_axes([red, green, blue]), red_green_blue) + + def test_concat_different_names(self): + red = core.Axis('red', ['red']) + green = core.Axis('green', ['red']) + with self.assertRaises(ValueError): + core.concat_axes([red, green]) + + def test_concat_unknown(self): + red = core.Axis('rgb', None) + green = core.Axis('rgb', None) + self.assertEqual(core.concat_axes([red, green]), red) + + def test_repr(self): + self.assertEqual("Axis('7', Dimension(7))", repr(self.i_7)) + + def test_invalid_input(self): + with self.assertRaises(TypeError): + core.Axis('foo', [{}]) + with self.assertRaises(ValueError): + core.Axis('foo', [1, 2, 3, 1]) + red = core.Axis('foo', ['red']) + with self.assertRaises(tc.Error): + core.concat_axes([red, 1]) + + def test_as_axis(self): + self.assertEqual(self.i_7, core.as_axis(('7', 7))) + self.assertEqual(self.i_7, core.as_axis(self.i_7)) + + +class AxesTest(tf.test.TestCase): + + def setUp(self): + d_7 = tf.Dimension(7) + d_8 = tf.Dimension(8) + p_rgb = ['red', 'green', 'blue'] + p_range = range(7) + + self.i_8 = core.Axis('8', d_8) + + self.a0 = core.Axes([('d7', d_7)]) + self.a1 = core.Axes([('d7', d_7)]) + self.a2 = core.Axes([('d7', d_7), ('rgb', p_rgb)]) + self.a3 = core.Axes([('8', d_8), ('range', p_range)]) + + def test_equality(self): + self.assertEqual(self.a0, self.a0) + self.assertEqual(self.a0, self.a1) + 
self.assertNotEqual(self.a0, self.a2) + + def test_repr(self): + self.assertEqual("Axes([('d7', Dimension(7))])", repr(self.a0)) + + def test_remove(self): + a = self.a3.remove('range') + self.assertEqual(a, core.Axes([self.i_8])) + with self.assertRaises(KeyError): + self.a3.remove('foobar') + + def test_typecheck_error_message(self): + pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., ' + 'Union(Union(numpy.ndarray, %s, list, tuple), ' + 'Optional(Union(tensorflow.Dimension, int))))))' % + range.__name__) + regexp = re.escape(pattern).replace(re.escape('...'), '.*') + with self.assertRaisesRegexp(tc.Error, 'allowed type ' + regexp): + core.Axes(None) + + +class LabeledTensorTest(test_util.Base): + + def setUp(self): + tensor = tf.ones([7, 3, 8, 1]) + a0 = ('x', range(7)) + a1 = ('channel', ['red', 'green', 'blue']) + a2 = ('y', 8) + a3 = ('z', tf.Dimension(1)) + + self.lt = core.LabeledTensor(tensor, [a0, a1, a2, a3]) + + def test_repr(self): + pattern = textwrap.dedent("""\ + <LabeledTensor '...' shape=(7, 3, 8, 1) dtype=float32 + axes=[('x', ...), + ('channel', ...), + ('y', Dimension(8)), + ('z', Dimension(1))]>""") + regexp = re.escape(pattern).replace(re.escape('...'), '.*') + self.assertRegexpMatches(repr(self.lt), regexp) + + def test_reuse_existing_axes(self): + alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes) + self.assertLabeledTensorsEqual(alt_lt, self.lt) + + def test_reuse_existing_axis_objects(self): + alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes.values()) + self.assertLabeledTensorsEqual(alt_lt, self.lt) + + def test_indexing_scalars(self): + actual = self.lt[:, :, :, 0] + expected = core.LabeledTensor(self.lt.tensor[:, :, :, 0], + list(self.lt.axes.values())[:-1]) + self.assertLabeledTensorsEqual(actual, expected) + + actual = self.lt[1, :, :, 0] + expected = core.LabeledTensor(self.lt.tensor[1, :, :, 0], + list(self.lt.axes.values())[1:-1]) + self.assertLabeledTensorsEqual(actual, expected) + + actual = self.lt[1, 2, :, 0] + expected = core.LabeledTensor(self.lt.tensor[1, 2, :, 0], + list(self.lt.axes.values())[2:-1]) + self.assertLabeledTensorsEqual(actual, expected) + + def test_indexing_1d(self): + lt_1d = self.lt[1, 2, :, 0] + actual = lt_1d[3] + expected = core.LabeledTensor(lt_1d.tensor[3], []) + self.assertLabeledTensorsEqual(actual, expected) + + def test_indexing_slices(self): + actual = self.lt[:3, :, :, :] + axes = [('x', range(3))] + list(self.lt.axes.values())[1:] + expected = core.LabeledTensor(self.lt.tensor[:3, :, :, :], axes) + self.assertLabeledTensorsEqual(actual, expected) + + def test_invalid_indexing(self): + with self.assertRaises(ValueError): + self.lt[0] # pylint: disable=pointless-statement + with self.assertRaises(ValueError): + self.lt[:, :, :, :, 0] # pylint: disable=pointless-statement + + def test_unknown_size(self): + tensor = tf.placeholder(tf.string, [None]) + actual = core.LabeledTensor(tensor, ['x']) + self.assertIsNone(actual.axes['x'].size) + self.assertIs(actual.axes['x'].value, tensor.get_shape()[0]) + + def test_eq(self): + self.assertEqual(self.lt, self.lt) + self.assertNotEqual(self.lt, self.lt.tensor) + self.assertNotEqual(self.lt.tensor, self.lt) + + def test_hash(self): + lt1 = self.lt + lt2 = core.LabeledTensor(self.lt.tensor, self.lt.axes) + self.assertEqual(lt1, lt2) + self.assertEqual(hash(lt1), hash(lt2)) + + def test_name(self): + self.assertEqual(self.lt.name, self.lt.tensor.name) + + def test_dtype(self): + self.assertEqual(self.lt.dtype, self.lt.tensor.dtype) + + def test_get_shape(self): + 
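+    # get_shape() is expected to simply delegate to the underlying tf.Tensor,
+    # as the assertion below verifies.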
self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape()) + + def test_convert_to_tensor(self): + expected = self.lt.tensor + actual = tf.convert_to_tensor(self.lt) + self.assertIs(expected, actual) + + +class Base(test_util.Base): + + def setUp(self): + self.x_size = 7 + self.channel_size = 3 + self.z_size = 4 + self.probs_size = 11 + + tensor = tf.range(0, self.x_size * self.channel_size * self.z_size * + self.probs_size) + tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size, + self.probs_size]) + a0 = ('x', range(self.x_size)) + a1 = ('channel', ['red', 'green', 'blue']) + a2 = 'z' + a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size)) + + self.tensor = tensor + self.a0 = a0 + self.a1 = a1 + self.a2 = a2 + self.a3 = a3 + self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3]) + + self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0, + 'channel': 0}) + self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3, + 'z': 0}) + + +class IdentityTest(Base): + + def test_name(self): + identity_lt = core.identity(self.original_lt) + self.assertIn('lt_identity', identity_lt.name) + + +class SliceFunctionTest(Base): + + def test_name(self): + select_lt = core.slice_function(self.original_lt, {'channel': 1}) + self.assertIn('lt_slice', select_lt.name) + + def test_scalar(self): + select_lt = core.slice_function(self.original_lt, {'channel': 1}) + golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2, + self.a3]) + + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_slice(self): + select_lt = core.slice_function(self.original_lt, {'channel': slice(0, 2)}) + + a1_sliced = ('channel', ['red', 'green']) + golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :], + [self.a0, a1_sliced, self.a2, self.a3]) + + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_slices(self): + select_lt = core.slice_function(self.original_lt, {'x': slice(1, 5), + 'channel': slice(1, + None)}) + + a0_sliced = ('x', range(1, 5)) + a1_sliced = ('channel', ['green', 'blue']) + golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :], + [a0_sliced, a1_sliced, self.a2, self.a3]) + + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_slice_unlabeled(self): + select_lt = core.slice_function(self.original_lt, {'z': slice(1, 3)}) + + a2_sliced = 'z' + golden_lt = core.LabeledTensor(self.tensor[:, :, 1:3, :], + [self.a0, self.a1, a2_sliced, self.a3]) + + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_slice_unknown_shape(self): + lt = core.LabeledTensor(tf.placeholder(tf.float32, [None, 1]), ['x', 'y']) + sliced_lt = core.slice_function(lt, {'y': 0}) + self.assertEqual(list(sliced_lt.axes.values()), [lt.axes['x']]) + + +class TransposeTest(Base): + + def test_name(self): + transpose_lt = core.transpose(self.original_lt, + self.original_lt.axes.keys()) + self.assertIn('lt_transpose', transpose_lt.name) + + def test_identity(self): + transpose_lt = core.transpose(self.original_lt, + self.original_lt.axes.keys()) + golden_lt = self.original_lt + + self.assertLabeledTensorsEqual(transpose_lt, golden_lt) + + def test(self): + transpose_lt = core.transpose(self.original_lt, ['z', 'channel', 'x', + 'probs']) + golden_lt = core.LabeledTensor( + tf.transpose(self.tensor, [2, 1, 0, 3]), + [self.a2, self.a1, self.a0, self.a3]) + + self.assertLabeledTensorsEqual(transpose_lt, golden_lt) + + def test_default_axis_order(self): + transpose_lt = core.transpose(self.original_lt) + golden_lt = 
core.LabeledTensor( + tf.transpose(self.tensor, [3, 2, 1, 0]), + list(reversed(list(self.original_lt.axes.values())))) + + self.assertLabeledTensorsEqual(transpose_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaises(ValueError): + core.transpose(self.original_lt, ['channel', 'x', 'probs']) + with self.assertRaises(ValueError): + core.transpose(self.original_lt, ['z', 'foo', 'x', 'probs']) + + +class ExpandDimsTest(Base): + + def test_name(self): + expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys()) + self.assertIn('lt_expand', expand_lt.name) + + def test_identity(self): + expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys()) + golden_lt = self.original_lt + + self.assertLabeledTensorsEqual(expand_lt, golden_lt) + + def test(self): + expand_lt = core.expand_dims(self.original_lt, ['foo', 'x', 'bar', + 'channel', 'z', 'probs', + 'grok']) + golden_lt = core.LabeledTensor( + tf.reshape(self.tensor, [1, self.x_size, 1, self.channel_size, + self.z_size, self.probs_size, 1]), + ['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok']) + + self.assertLabeledTensorsEqual(expand_lt, golden_lt) + + def test_label(self): + expand_lt = core.expand_dims(self.original_lt, ['x', + 'channel', + ('foo', 'bar'), + 'z', + 'probs',]) + golden_lt = core.LabeledTensor( + tf.reshape(self.tensor, [self.x_size, self.channel_size, 1, self.z_size, + self.probs_size]), + [self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3]) + + self.assertLabeledTensorsEqual(expand_lt, golden_lt) + + def test_unknown_dimension(self): + orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x']) + expand_lt = core.expand_dims(orig_lt, ['x', 'y']) + self.assertEqual(expand_lt.axes, core.Axes([('x', None), ('y', 1)])) + + def test_invalid_input(self): + with self.assertRaises(core.AxisOrderError): + core.expand_dims(self.original_lt, ['foo', 'not_x', 'bar', 'channel', 'z', + 'probs', 'grok']) + with self.assertRaises(core.AxisOrderError): + core.expand_dims(self.original_lt, ['foo', 'z', 'bar', 'channel', 'x', + 'probs', 'grok']) + + +class AxisOrderScopeTest(Base): + + def test(self): + xyz = ['x', 'y', 'z'] + abc = ['a', 'b', 'c'] + + self.assertIsNone(core.get_axis_order()) + + with core.axis_order_scope(xyz): + self.assertEqual(core.get_axis_order(), xyz) + + with core.axis_order_scope(): + self.assertIsNone(core.get_axis_order()) + + with core.axis_order_scope(abc): + self.assertEqual(core.get_axis_order(), abc) + + self.assertIsNone(core.get_axis_order()) + + self.assertEqual(core.get_axis_order(), xyz) + + self.assertIsNone(core.get_axis_order()) + + +class CheckAxisOrderTest(Base): + + def test_passes(self): + axis_order = ['w', 'x', 'y', 'z'] + + lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order) + core.check_axis_order(lt, axis_order) + + lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[1:]) + core.check_axis_order(lt, axis_order) + + lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[:-1]) + core.check_axis_order(lt, axis_order) + + def test_invalid(self): + axis_order = ['w', 'x', 'y', 'z'] + lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order) + with self.assertRaises(core.AxisOrderError): + core.check_axis_order(lt) + with self.assertRaises(core.AxisOrderError): + core.check_axis_order(lt, axis_order[:-1]) + with self.assertRaises(core.AxisOrderError): + core.check_axis_order(lt, axis_order[::-1]) + + def test_scope(self): + axis_order = ['w', 'x', 'y', 'z'] + lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), 
axis_order) + with core.axis_order_scope(axis_order): + core.check_axis_order(lt) + + +class ImposeAxisOrderTest(Base): + + def test_identity(self): + axis_order = ['w', 'x', 'y', 'z'] + lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order) + actual = core.impose_axis_order(lt, axis_order) + self.assertLabeledTensorsEqual(lt, actual) + + lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3]) + actual = core.impose_axis_order(lt, axis_order) + self.assertLabeledTensorsEqual(lt, actual) + + def test_reverse(self): + axis_order = ['w', 'x', 'y', 'z'] + + lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order) + actual = core.impose_axis_order(lt, axis_order[::-1]) + expected = core.transpose(lt, axis_order[::-1]) + self.assertLabeledTensorsEqual(expected, actual) + + lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3]) + actual = core.impose_axis_order(lt, axis_order[::-1]) + expected = core.transpose(lt, ['y', 'x', 'w']) + self.assertLabeledTensorsEqual(expected, actual) + + def test_scope(self): + axis_order = ['w', 'x', 'y', 'z'] + + lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order) + expected = core.transpose(lt, axis_order[::-1]) + with core.axis_order_scope(axis_order[::-1]): + actual = core.impose_axis_order(lt) + self.assertLabeledTensorsEqual(expected, actual) + + def test_invalid(self): + lt = core.LabeledTensor(tf.reshape(tf.range(2), (1, 2)), ['x', 'y']) + with self.assertRaises(ValueError): + core.impose_axis_order(lt) + with self.assertRaises(ValueError): + core.impose_axis_order(lt, ['x']) + + +class FindConsistentOrderingTest(Base): + + def test(self): + cases = [ + ([], [], []), + (['x'], [], ['x']), + ([], ['x'], ['x']), + (['x'], ['x'], ['x']), + (['x'], ['y'], ['x', 'y']), + (['y'], ['x'], ['y', 'x']), + (['x', 'y'], ['x', 'y'], ['x', 'y']), + (['x', 'y'], ['y', 'x'], None), + (['x', 'y'], ['y', 'z'], ['x', 'y', 'z']), + (['x', 'z'], ['y', 'z'], ['x', 'y', 'z']), + (['x', 'y'], ['x', 'z'], ['x', 'y', 'z']), + (['w', 'x'], ['y', 'z'], ['w', 'x', 'y', 'z']), + (['x', 'y', 'z'], ['z', 'x'], None), + (['x', 'y', 'z'], ['x'], ['x', 'y', 'z']), + ([], ['x', 'y', 'z'], ['x', 'y', 'z']), + ] + for a, b, expected in cases: + actual = core._find_consistent_ordering(a, b) + msg = ('unexpected ordering between %r and %r:\nexpected: %r\nactual: %r' + % (a, b, expected, actual)) + self.assertEqual(expected, actual, msg=msg) + + +class AlignTest(Base): + + def test_name(self): + align_lt_0, align_lt_1, _ = core.align(self.original_lt, self.original_lt) + self.assertIn('lt_align', align_lt_0.name) + self.assertIn('/0', align_lt_0.name) + self.assertIn('lt_align', align_lt_1.name) + self.assertIn('/1', align_lt_1.name) + + def test_identical_shaped_inputs(self): + offset_tensor = self.original_lt.tensor + 1 + offset_lt = core.LabeledTensor(offset_tensor, self.original_lt.axes) + + align_lt, align_offset_lt, broadcast_axes = core.align(self.original_lt, + offset_lt) + + self.assertLabeledTensorsEqual(align_lt, self.original_lt) + self.assertLabeledTensorsEqual(align_offset_lt, offset_lt) + self.assertEqual(broadcast_axes, self.original_lt.axes) + + def test_different_inputs(self): + # The correct axis ordering is ['x', 'channel', 'probs']. 
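+    # x_probs_lt has axes ['x', 'probs'] and channel_probs_lt has axes
+    # ['channel', 'probs']; align() broadcasts each against the union of the
+    # two axis sets, inserting size-1 dimensions for any axis it lacks.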
+ align_x_probs_lt, align_channel_probs_lt, broadcast_axes = core.align( + self.x_probs_lt, self.channel_probs_lt) + + x_probs_golden_lt = core.LabeledTensor( + tf.reshape(self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size]), + [self.a0, 'channel', self.a3]) + + self.assertLabeledTensorsEqual(align_x_probs_lt, x_probs_golden_lt) + + channel_probs_golden_lt = core.LabeledTensor( + tf.reshape(self.channel_probs_lt.tensor, + [1, self.channel_size, self.probs_size]), + ['x', self.a1, self.a3]) + + self.assertLabeledTensorsEqual(align_channel_probs_lt, + channel_probs_golden_lt) + + self.assertEqual(broadcast_axes, core.Axes([self.a0, self.a1, self.a3])) + + def test_axis_order_scope(self): + xz_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'z']) + yz_lt = core.LabeledTensor(tf.ones((4, 3)), ['y', 'z']) + + _, _, broadcast_axes = core.align(xz_lt, yz_lt) + self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z']) + + _, _, broadcast_axes = core.align(yz_lt, xz_lt) + self.assertEqual(list(broadcast_axes.keys()), ['y', 'x', 'z']) + + with core.axis_order_scope(['x', 'y', 'z']): + _, _, broadcast_axes = core.align(yz_lt, xz_lt) + self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z']) + + with core.axis_order_scope(['x', 'y']): + with self.assertRaises(core.AxisOrderError): + core.align(xz_lt, yz_lt) + with self.assertRaises(core.AxisOrderError): + core.align(yz_lt, xz_lt) + + def test_invalid_input(self): + lt_0 = core.LabeledTensor(tf.zeros([5]), [('a', range(5))]) + lt_1 = core.LabeledTensor(tf.zeros([5]), [('a', range(1, 6))]) + with self.assertRaises(ValueError): + core.align(lt_0, lt_1) + + +class ConvertToLabeledTensorTest(Base): + + # TODO(shoyer): Simplify these tests once we can reuse labeled tensors in + # assertLabeledTensorsEqual. 
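+  # convert_to_labeled_tensor accepts LabeledTensors, scalars, and 0-d numpy
+  # arrays or Tensors; unlabeled inputs with ndim >= 1 raise ValueError, as
+  # test_invalid_input below verifies.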
+
+  def test_labeled_tensor(self):
+    actual = core.convert_to_labeled_tensor(self.original_lt)
+    self.assertLabeledTensorsEqual(actual, self.original_lt)
+
+  def test_python_scalar(self):
+    actual = core.convert_to_labeled_tensor(42)
+    golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
+    self.assertLabeledTensorsEqual(actual, golden_lt)
+
+  def test_numpy_array(self):
+    actual = core.convert_to_labeled_tensor(np.array(42))
+    golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
+    self.assertLabeledTensorsEqual(actual, golden_lt)
+
+  def test_tensor(self):
+    actual = core.convert_to_labeled_tensor(tf.constant(42))
+    golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
+    self.assertLabeledTensorsEqual(actual, golden_lt)
+
+  def test_invalid_input(self):
+    with self.assertRaises(ValueError):
+      core.convert_to_labeled_tensor(tf.range(5))
+    with self.assertRaises(ValueError):
+      core.convert_to_labeled_tensor(np.array([1, 2]))
+
+
+class DocStringCheckMixin(object):
+  # requires self.ops to be defined
+
+  def test_function_docstring_and_name(self):
+    for op_name, _, _, lt_op in self.ops:
+      if lt_op is not None:
+        self.assertIn('tf.%s' % op_name, lt_op.__doc__)
+        self.assertEqual(op_name, lt_op.__name__)
+
+
+class UnaryOpsTestsMixin(object):
+  # requires self.ops and self.test_lt to be defined
+
+  def test_core_op(self):
+    for op_name, _, tf_op, lt_op in self.ops:
+      if tf_op is not None:
+        golden_lt = core.LabeledTensor(tf_op(self.test_lt.tensor),
+                                       self.test_lt.axes)
+        actual_lt = lt_op(self.test_lt)
+        self.assertIn(op_name, actual_lt.name)
+        self.assertLabeledTensorsEqual(golden_lt, actual_lt)
+
+  def test_infix(self):
+    for op_name, infix_op, _, _ in self.ops:
+      if infix_op is not None:
+        expected_lt = core.LabeledTensor(infix_op(self.test_lt.tensor),
+                                         self.test_lt.axes)
+        actual_lt = infix_op(self.test_lt)
+        self.assertIn(op_name, actual_lt.name)
+        self.assertLabeledTensorsEqual(expected_lt, actual_lt)
+
+
+class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
+
+  def setUp(self):
+    super(CoreUnaryOpsTest, self).setUp()
+
+    self.ops = [
+        ('abs', operator.abs, tf.abs, core.abs_function),
+        ('neg', operator.neg, tf.neg, core.neg),
+        # TODO(shoyer): add unary + to core TensorFlow
+        ('pos', None, None, None),
+        ('sign', None, tf.sign, core.sign),
+        ('inv', None, tf.inv, core.inv),
+        ('square', None, tf.square, core.square),
+        ('round', None, tf.round, core.round_function),
+        ('sqrt', None, tf.sqrt, core.sqrt),
+        ('rsqrt', None, tf.rsqrt, core.rsqrt),
+        ('log', None, tf.log, core.log),
+        ('exp', None, tf.exp, core.exp),
+        ('ceil', None, tf.ceil, core.ceil),
+        ('floor', None, tf.floor, core.floor),
+        ('cos', None, tf.cos, core.cos),
+        ('sin', None, tf.sin, core.sin),
+        ('tan', None, tf.tan, core.tan),
+        ('acos', None, tf.acos, core.acos),
+        ('asin', None, tf.asin, core.asin),
+        ('atan', None, tf.atan, core.atan),
+        ('lgamma', None, tf.lgamma, core.lgamma),
+        ('digamma', None, tf.digamma, core.digamma),
+        ('erf', None, tf.erf, core.erf),
+        ('erfc', None, tf.erfc, core.erfc),
+    ]
+    total_size = np.prod([v.size for v in self.original_lt.axes.values()])
+    self.test_lt = core.LabeledTensor(
+        tf.cast(self.original_lt, tf.float32) / total_size,
+        self.original_lt.axes)
+
+
+class LogicalNotTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
+
+  def setUp(self):
+    super(LogicalNotTest, self).setUp()
+    self.ops = [
+        ('logical_not', operator.invert, tf.logical_not, core.logical_not),
+    ]
+    self.test_lt = self.original_lt < 10
+
+
+class BinaryOpsTestsMixin(object):
+  # requires self.ops, self.test_lt_1, self.test_lt_2, self.test_lt_1_broadcast
+  # and self.test_lt_2_broadcast to be defined
+
+  def test_core_op(self):
+    for op_name, _, tf_op, lt_op in self.ops:
+      golden_tensor = tf_op(self.test_lt_1_broadcast,
+                            self.test_lt_2_broadcast)
+      golden_lt = core.LabeledTensor(golden_tensor, self.broadcast_axes)
+      actual_lt = lt_op(self.test_lt_1, self.test_lt_2)
+      self.assertIn(op_name, actual_lt.name)
+      self.assertLabeledTensorsEqual(golden_lt, actual_lt)
+
+  def test_infix(self):
+    for op_name, infix_op, _, lt_op in self.ops:
+      if infix_op is not None:
+        expected_lt = lt_op(self.test_lt_1, self.test_lt_2)
+        actual_lt = infix_op(self.test_lt_1, self.test_lt_2)
+        self.assertIn(op_name, actual_lt.name)
+        self.assertLabeledTensorsEqual(expected_lt, actual_lt)
+
+
+class CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
+
+  def setUp(self):
+    super(CoreBinaryOpsTest, self).setUp()
+
+    self.x_probs_broadcast_tensor = tf.reshape(
+        self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])
+
+    self.channel_probs_broadcast_tensor = tf.reshape(
+        self.channel_probs_lt.tensor, [1, self.channel_size, self.probs_size])
+
+    # == and != are not element-wise for tf.Tensor, so they shouldn't be
+    # elementwise for LabeledTensor, either.
+    self.ops = [
+        ('add', operator.add, tf.add, core.add),
+        ('sub', operator.sub, tf.sub, core.sub),
+        ('mul', operator.mul, tf.mul, core.mul),
+        ('div', operator.truediv, tf.div, core.div),
+        ('mod', operator.mod, tf.mod, core.mod),
+        ('pow', operator.pow, tf.pow, core.pow_function),
+        ('equal', None, tf.equal, core.equal),
+        ('less', operator.lt, tf.less, core.less),
+        ('less_equal', operator.le, tf.less_equal, core.less_equal),
+        ('not_equal', None, tf.not_equal, core.not_equal),
+        ('greater', operator.gt, tf.greater, core.greater),
+        ('greater_equal', operator.ge, tf.greater_equal, core.greater_equal),
+    ]
+    self.test_lt_1 = self.x_probs_lt
+    self.test_lt_2 = self.channel_probs_lt
+    self.test_lt_1_broadcast = self.x_probs_broadcast_tensor
+    self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor
+    self.broadcast_axes = [self.a0, self.a1, self.a3]
+
+  def test_reflexive(self):
+    labeled_tensor = self.x_probs_lt + 1  # all elements must be >0 for division
+    for op_name, infix_op, _, lt_op in self.ops:
+      if infix_op is not None:
+        expected_lt = lt_op(2, labeled_tensor)
+        actual_lt = infix_op(2, labeled_tensor)
+        # Python uses greater for the reflexive version of less (and vice
+        # versa).
+        if 'less' in op_name:
+          op_name = op_name.replace('less', 'greater')
+        elif 'greater' in op_name:
+          op_name = op_name.replace('greater', 'less')
+        self.assertIn(op_name, actual_lt.name)
+        self.assertLabeledTensorsEqual(expected_lt, actual_lt)
+
+
+class LogicalBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
+
+  def setUp(self):
+    super(LogicalBinaryOpsTest, self).setUp()
+
+    self.ops = [
+        ('logical_and', operator.and_, tf.logical_and, core.logical_and),
+        ('logical_or', operator.or_, tf.logical_or, core.logical_or),
+        ('logical_xor', operator.xor, tf.logical_xor, core.logical_xor),
+    ]
+    self.test_lt_1 = self.original_lt < 10
+    self.test_lt_2 = self.original_lt < 5
+    self.test_lt_1_broadcast = self.test_lt_1.tensor
+    self.test_lt_2_broadcast = self.test_lt_2.tensor
+    self.broadcast_axes = self.test_lt_1.axes
+
+
+class FloatBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
+
+  def setUp(self):
+    super(FloatBinaryOpsTest,
self).setUp() + + self.ops = [ + ('igamma', None, tf.igamma, core.igamma), + ('igammac', None, tf.igammac, core.igammac), + ('zeta', None, tf.zeta, core.zeta), + ('polygamma', None, tf.polygamma, core.polygamma), + ('maximum', None, tf.maximum, core.maximum), + ('minimum', None, tf.minimum, core.minimum), + ('squared_difference', None, tf.squared_difference, + core.squared_difference), + ] + total_size = np.prod([v.size for v in self.original_lt.axes.values()]) + test_lt = core.LabeledTensor( + tf.cast(self.original_lt, tf.float32) / total_size, + self.original_lt.axes) + self.test_lt_1 = test_lt + self.test_lt_2 = 1.0 - test_lt + self.test_lt_1_broadcast = self.test_lt_1.tensor + self.test_lt_2_broadcast = self.test_lt_2.tensor + self.broadcast_axes = self.test_lt_1.axes + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py b/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py new file mode 100644 index 0000000000..3bb9c21c2e --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py @@ -0,0 +1,178 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Input parsing code for LabeledTensors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import string_types + +from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops + + +class FixedLenFeature(object): + """Configuration for parsing a fixed-length input feature. + + Fields: + axes: A list of Axis objects or tuples (axis_name, axis_value), + where `axis_name` is a string and `axis_value` is None (unknown size), an + integer or a list of tick labels. + dtype: Data type of input. + default_value: Value to be used if an example is missing this feature. It + must be compatible with `dtype`. 
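+
+  For example, `FixedLenFeature([('x', 3)], tf.int64)` describes an int64
+  feature with a single labeled axis 'x' of size 3.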
+ """ + + def __init__(self, axes, dtype, default_value=None): + self._axes = [core.as_axis(a) for a in axes] + self._dtype = dtype + self._default_value = default_value + + @property + def axes(self): + return self._axes + + @property + def dtype(self): + return self._dtype + + @property + def default_value(self): + return self._default_value + + +@tc.returns(tc.Dict(string_types, parsing_ops.FixedLenFeature)) +@tc.accepts(tc.Mapping(string_types, FixedLenFeature)) +def _labeled_to_unlabeled_features(features): + """Convert a dict of lt.FixedLenFeature into a dict of tf.FixedLenFeature.""" + unlabeled_features = {} + for name, labeled_feature in features.items(): + shape = [ax.size for ax in labeled_feature.axes] + if any(size is None for size in shape): + # This should be caught on the TensorFlow side, but it isn't yet: + # https://github.com/tensorflow/tensorflow/issues/2874 + raise ValueError('axes with unknown size are not supported') + dtype = labeled_feature.dtype + default_value = labeled_feature.default_value + unlabeled_features[name] = parsing_ops.FixedLenFeature( + shape, dtype, default_value) + return unlabeled_features + + +@tc.returns(tc.Dict(string_types, core.LabeledTensor)) +@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature), + tc.Optional(string_types), object) +def parse_example(serialized, features, name=None, example_names=None): + """Parse `Example` protos into a `dict` of labeled tensors. + + See tf.parse_example. + + Args: + serialized: A 1-D LabeledTensor of strings, a batch of binary serialized + `Example` protos. + features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature` + values. + name: A name for this operation (optional). + example_names: A vector (1-D Tensor) of strings (optional), the names of + the serialized protos in the batch. + + Returns: + A `dict` mapping feature keys to `LabeledTensor` values. The single axis + from `serialized` will be prepended to the axes provided by each feature. + + Raises: + ValueError: if any feature is invalid. + """ + serialized = core.convert_to_labeled_tensor(serialized) + unlabeled_features = _labeled_to_unlabeled_features(features) + + unlabeled_parsed = parsing_ops.parse_example( + serialized.tensor, unlabeled_features, name, example_names) + + parsed = {} + for name, parsed_feature in unlabeled_parsed.items(): + axes = list(serialized.axes.values()) + features[name].axes + parsed[name] = core.LabeledTensor(parsed_feature, axes) + + return parsed + + +@tc.returns(tc.Dict(string_types, core.LabeledTensor)) +@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature), + tc.Optional(string_types), object) +def parse_single_example(serialized, features, name=None, example_names=None): + """Parses a single `Example` proto. + + See tf.parse_single_example. + + Args: + serialized: A scalar string Tensor or LabeledTensor, a single serialized + Example. + features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature` + values. + name: A name for this operation (optional). + example_names: (Optional) A scalar string Tensor, the associated name. + + Returns: + A `dict` mapping feature keys to `LabeledTensor` values. + + Raises: + ValueError: if any feature is invalid. 
+ """ + serialized = core.convert_to_labeled_tensor(serialized) + unlabeled_features = _labeled_to_unlabeled_features(features) + + unlabeled_parsed = parsing_ops.parse_single_example( + serialized.tensor, unlabeled_features, name, example_names) + + parsed = {} + for name, parsed_feature in unlabeled_parsed.items(): + parsed[name] = core.LabeledTensor(parsed_feature, features[name].axes) + + return parsed + + +@tc.returns(core.LabeledTensor) +@tc.accepts(dtypes.DType, tc.Collection(tc.Union(string_types, core.AxisLike)), + tc.Optional(string_types)) +def placeholder(dtype, axes, name=None): + """Create a placeholder for a labeled tensor. + + For example: + + lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])]) + + See tf.placeholder for more details. + + Args: + dtype: The type of elements in the tensor to be fed. + axes: sequence of strings (denoting axes of unknown size) and/or objects + convertable to lt.Axis to label the result. + name: Optional op name. + + Returns: + Placeholder labeled tensor. + """ + with ops.name_scope(name, 'lt_placeholder', []) as scope: + axes = core.Axes([(axis, None) if isinstance(axis, string_types) else axis + for axis in axes]) + shape = [axis.size for axis in axes.values()] + tensor = array_ops.placeholder(dtype, shape, name=scope) + return core.LabeledTensor(tensor, axes) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py new file mode 100644 index 0000000000..b9d3d9cec2 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py @@ -0,0 +1,106 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.contrib.labeled_tensor.python.ops import io_ops +from tensorflow.contrib.labeled_tensor.python.ops import test_util + + +class ParseBase(test_util.Base): + + def setUp(self): + super(ParseBase, self).setUp() + examples = [ + tf.train.Example(features=tf.train.Features(feature={ + 'a': tf.train.Feature( + int64_list=tf.train.Int64List(value=[1])), + 'b': tf.train.Feature( + int64_list=tf.train.Int64List(value=[2, 3, 4])), + })), + tf.train.Example(features=tf.train.Features(feature={ + 'a': tf.train.Feature( + int64_list=tf.train.Int64List(value=[5])), + 'b': tf.train.Feature( + int64_list=tf.train.Int64List(value=[6, 7, 8])), + })), + ] + self.serialized = core.LabeledTensor( + tf.constant([ex.SerializeToString() for ex in examples]), ['batch']) + self.features = {'a': io_ops.FixedLenFeature([], tf.int64), + 'b': io_ops.FixedLenFeature([('x', 3)], tf.int64)} + + +class TestParseExample(ParseBase): + + def test(self): + expected_a = core.LabeledTensor(tf.constant([1, 5]), ['batch']) + expected_b = core.LabeledTensor(tf.constant([[2, 3, 4], [6, 7, 8]]), + ['batch', 'x']) + parsed = io_ops.parse_example(self.serialized, self.features) + self.assertLabeledTensorsEqual(expected_a, parsed['a']) + self.assertLabeledTensorsEqual(expected_b, parsed['b']) + + def test_placeholder(self): + serialized = core.LabeledTensor(tf.placeholder(tf.string, [None]), + ['batch']) + # should not raise + io_ops.parse_example(serialized, self.features) + + +class TestParseSingleExample(ParseBase): + + def test(self): + expected_a = core.LabeledTensor(tf.constant(1), []) + expected_b = core.LabeledTensor(tf.constant([2, 3, 4]), ['x']) + parsed = io_ops.parse_single_example(self.serialized[0], self.features) + self.assertLabeledTensorsEqual(expected_a, parsed['a']) + self.assertLabeledTensorsEqual(expected_b, parsed['b']) + + def test_unknown_size(self): + features = {'a': io_ops.FixedLenFeature([('x', None)], tf.int64)} + serialized = tf.placeholder(tf.string, []) + with self.assertRaisesRegexp(ValueError, 'unknown size'): + io_ops.parse_single_example(serialized, features) + + +class PlaceholderTest(test_util.Base): + + def test_name(self): + placeholder_lt = io_ops.placeholder(tf.float32, []) + self.assertIn('lt_placeholder', placeholder_lt.name) + + def test(self): + placeholder_lt = io_ops.placeholder(tf.float32, + ['batch', ('x', ['a', 'b'])]) + self.assertEqual(placeholder_lt.dtype, tf.float32) + self.assertEqual(placeholder_lt.axes, + core.Axes([('batch', None), ('x', ['a', 'b'])])) + + def test_feed(self): + sess = tf.Session() + placeholder_lt = io_ops.placeholder(tf.float32, []) + two_times = 2.0 * placeholder_lt + result = sess.run(two_times, {placeholder_lt.tensor: 1}) + self.assertEqual(result, 2.0) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/labeled_tensor/python/ops/nn.py b/tensorflow/contrib/labeled_tensor/python/ops/nn.py new file mode 100644 index 0000000000..dce16ccf27 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/nn.py @@ -0,0 +1,42 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Neural network ops for LabeledTensors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.python.ops import nn + + +relu = core.define_unary_op('relu', nn.relu) +relu6 = core.define_unary_op('relu6', nn.relu6) +crelu = core.define_unary_op('crelu', nn.crelu) +elu = core.define_unary_op('elu', nn.elu) +softplus = core.define_unary_op('softplus', nn.softplus) + +l2_loss = core.define_unary_op('l2_loss', nn.l2_loss) +sigmoid_cross_entropy_with_logits = core.define_binary_op( + 'sigmoid_cross_entropy_with_logits', + nn.sigmoid_cross_entropy_with_logits) +softmax = core.define_unary_op('softmax', nn.softmax) +log_softmax = core.define_unary_op('log_softmax', nn.log_softmax) +softmax_cross_entropy_with_logits = core.define_binary_op( + 'softmax_cross_entropy_with_logits', + nn.softmax_cross_entropy_with_logits) +sparse_softmax_cross_entropy_with_logits = core.define_binary_op( + 'sparse_softmax_cross_entropy_with_logits', + nn.sparse_softmax_cross_entropy_with_logits) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/nn_test.py b/tensorflow/contrib/labeled_tensor/python/ops/nn_test.py new file mode 100644 index 0000000000..18cbd8b4ed --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/nn_test.py @@ -0,0 +1,70 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.contrib.labeled_tensor.python.ops import nn +from tensorflow.contrib.labeled_tensor.python.ops import test_util + + +class NNTests(test_util.Base): + + def setUp(self): + super(NNTests, self).setUp() + self.axes = ['x'] + self.original_lt = core.LabeledTensor([0.0, 0.5, 1.0], self.axes) + self.other_lt = 1 - self.original_lt + + def test_unary_ops(self): + ops = [ + ('relu', tf.nn.relu, nn.relu), + ('relu6', tf.nn.relu6, nn.relu6), + ('crelu', tf.nn.crelu, nn.crelu), + ('elu', tf.nn.elu, nn.elu), + ('softplus', tf.nn.softplus, nn.softplus), + ('l2_loss', tf.nn.l2_loss, nn.l2_loss), + ('softmax', tf.nn.softmax, nn.softmax), + ('log_softmax', tf.nn.log_softmax, nn.log_softmax), + ] + for op_name, tf_op, lt_op in ops: + golden_tensor = tf_op(self.original_lt.tensor) + golden_lt = core.LabeledTensor(golden_tensor, self.axes) + actual_lt = lt_op(self.original_lt) + self.assertIn(op_name, actual_lt.name) + self.assertLabeledTensorsEqual(golden_lt, actual_lt) + + def test_binary_ops(self): + ops = [ + ('sigmoid_cross_entropy_with_logits', + tf.nn.sigmoid_cross_entropy_with_logits, + nn.sigmoid_cross_entropy_with_logits), + ('softmax_cross_entropy_with_logits', + tf.nn.softmax_cross_entropy_with_logits, + nn.softmax_cross_entropy_with_logits), + ('sparse_softmax_cross_entropy_with_logits', + tf.nn.sparse_softmax_cross_entropy_with_logits, + nn.sparse_softmax_cross_entropy_with_logits), + ] + for op_name, tf_op, lt_op in ops: + golden_tensor = tf_op(self.original_lt.tensor, self.other_lt.tensor) + golden_lt = core.LabeledTensor(golden_tensor, self.axes) + actual_lt = lt_op(self.original_lt, self.other_lt) + self.assertIn(op_name, actual_lt.name) + self.assertLabeledTensorsEqual(golden_lt, actual_lt) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py new file mode 100644 index 0000000000..a9ddbbd2cf --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -0,0 +1,1207 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Non-core ops for LabeledTensor.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import types + +import numpy as np +from six import string_types + +from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import numerics +from tensorflow.python.ops import random_ops +from tensorflow.python.training import input # pylint: disable=redefined-builtin + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensor, ops.Output, core.Axis, + tc.Optional(string_types)) +def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None): + with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope: + temp_axes = core.Axes( + [axis] + list(labeled_tensor.axes.remove(axis.name).values())) + transposed = core.transpose(labeled_tensor, temp_axes.keys()) + indexed = core.LabeledTensor(array_ops.gather(transposed.tensor, indexer), + temp_axes) + return core.transpose(indexed, labeled_tensor.axes.keys(), name=scope) + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensorLike, + tc.Mapping(string_types, tc.Union( + slice, collections.Hashable, collections.Sequence)), + tc.Optional(string_types)) +def select(labeled_tensor, selection, name=None): + """Slice out a subset of the tensor. + + Args: + labeled_tensor: The input tensor. + selection: A dictionary mapping an axis name to a scalar, slice or list of + values to select. Currently supports two types of selections: + (a) Any number of scalar and/or slice selections. + (b) Exactly one list selection, without any scalars or slices. + name: Optional op name. + + Returns: + The selection as a `LabeledTensor`. + + Raises: + ValueError: If the tensor doesn't have an axis in the selection or if + that axis lacks labels. + KeyError: If any labels in a selection are not found in the original axis. + NotImplementedError: If you attempt to combine a list selection with + scalar selection or another list selection. + """ + with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + + slices = {} + indexers = {} + for axis_name, value in selection.items(): + if axis_name not in labeled_tensor.axes: + raise ValueError( + 'The tensor does not have an axis named %s. Its axes are: %r' % + (axis_name, labeled_tensor.axes.keys())) + axis = labeled_tensor.axes[axis_name] + if axis.labels is None: + raise ValueError( + 'The axis named %s does not have labels. The axis is: %r' % + (axis_name, axis)) + + if isinstance(value, slice): + # TODO(shoyer): consider deprecating using slices in favor of lists + if value.start is None: + start = None + else: + start = axis.index(value.start) + + if value.stop is None: + stop = None + else: + # For now, follow the pandas convention of making labeled slices + # inclusive of both bounds. + stop = axis.index(value.stop) + 1 + + if value.step is not None: + raise NotImplementedError('slicing with a step is not yet supported') + + slices[axis_name] = slice(start, stop) + + else: + # We're allowing anything NumPy treats as a scalar or 1D array. 
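+        # For example, select(lt, {'channel': ['red', 'blue']}) takes the 1-D
+        # branch below: each label is mapped to an integer position with
+        # axis.index() and then gathered along that axis.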
+        value = np.asarray(value)
+        if value.ndim == 0:
+          slices[axis_name] = axis.index(value.item())
+        elif value.ndim == 1:
+          if indexers:
+            raise NotImplementedError(
+                'select does not yet support more than one list selection at '
+                'the same time')
+          indexer = [axis.index(v) for v in value.tolist()]
+          indexers[axis_name] = ops.convert_to_tensor(
+              indexer, dtype=dtypes.int64)
+        else:
+          raise NotImplementedError(
+              'select does not yet support selections with more than one '
+              'dimension: %s on axis %r' % (value, axis_name))
+
+    if indexers and slices:
+      raise NotImplementedError(
+          'select does not yet support combined scalar and list selection')
+
+    # For now, handle array selection separately, because tf.gather_nd does
+    # not support gradients yet. Later, using gather_nd will let us combine
+    # these paths.
+    if indexers:
+      (axis_name, indexer), = indexers.items()
+      axis = core.Axis(axis_name, selection[axis_name])
+      return _gather_1d_on_axis(labeled_tensor, indexer, axis, name=scope)
+    else:
+      return core.slice_function(labeled_tensor, slices, name=scope)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(tc.Collection(core.LabeledTensorLike), string_types,
+            tc.Optional(string_types))
+def concat(labeled_tensors, axis_name, name=None):
+  """Concatenate tensors along a dimension.
+
+  See tf.concat.
+
+  Args:
+    labeled_tensors: A list of input LabeledTensors.
+    axis_name: The name of the axis along which to concatenate.
+    name: Optional op name.
+
+  Returns:
+    The concatenated tensor.
+    The coordinate labels for the concatenation dimension are also
+    concatenated, if they are available for every tensor.
+
+  Raises:
+    ValueError: If fewer than one tensor input is provided, if the tensors
+      have incompatible axes, or if `axis_name` isn't the name of an axis.
+  """
+  with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
+    labeled_tensors = [core.convert_to_labeled_tensor(lt)
+                       for lt in labeled_tensors]
+
+    if len(labeled_tensors) < 1:
+      raise ValueError('concat expects at least 1 tensor, but received %s' %
+                       labeled_tensors)
+
+    # All tensors must have these axes.
+    axes_0 = labeled_tensors[0].axes
+    axis_names = list(axes_0.keys())
+
+    if axis_name not in axis_names:
+      raise ValueError('%s not in %s' % (axis_name, axis_names))
+
+    shared_axes = axes_0.remove(axis_name)
+
+    tensors = [labeled_tensors[0].tensor]
+    concat_axis_list = [axes_0[axis_name]]
+    for labeled_tensor in labeled_tensors[1:]:
+      current_shared_axes = labeled_tensor.axes.remove(axis_name)
+      if current_shared_axes != shared_axes:
+        # TODO(shoyer): add more specific checks about what went wrong,
+        # including raising AxisOrderError when appropriate
+        raise ValueError('Mismatched shared axes: the first tensor '
+                         'had axes %r but this tensor has axes %r.' %
+                         (shared_axes, current_shared_axes))
+
+      # Accumulate the axis labels, if they're available.
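+      # For example, concatenating an axis labeled ['red'] with one labeled
+      # ['green', 'blue'] along 'rgb' yields a single 'rgb' axis labeled
+      # ['red', 'green', 'blue'] via core.concat_axes below.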
+      concat_axis_list.append(labeled_tensor.axes[axis_name])
+      tensors.append(labeled_tensor.tensor)
+
+    concat_axis = core.concat_axes(concat_axis_list)
+    concat_dimension = axis_names.index(axis_name)
+    concat_tensor = array_ops.concat(concat_dimension, tensors, name=scope)
+    values = list(axes_0.values())
+    concat_axes = (values[:concat_dimension] + [concat_axis] +
+                   values[concat_dimension + 1:])
+
+    return core.LabeledTensor(concat_tensor, concat_axes)
+
+
+# TODO(shoyer): rename pack/unpack to stack/unstack
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(
+    tc.Collection(core.LabeledTensorLike),
+    tc.Union(string_types, core.AxisLike),
+    int, tc.Optional(string_types))
+def pack(labeled_tensors, new_axis, axis_position=0, name=None):
+  """Pack tensors along a new axis.
+
+  See tf.pack.
+
+  Args:
+    labeled_tensors: The input tensors, which must have identical axes.
+    new_axis: The name of the new axis, or a tuple containing the name
+      and coordinate labels.
+    axis_position: Optional integer position at which to insert the new axis.
+    name: Optional op name.
+
+  Returns:
+    The packed tensors as a single LabeledTensor, with `new_axis` in the given
+    `axis_position`.
+
+  Raises:
+    ValueError: If fewer than one input tensor is provided, or if the tensors
+      don't have identical axes.
+  """
+  with ops.name_scope(name, 'lt_pack', labeled_tensors) as scope:
+    labeled_tensors = [core.convert_to_labeled_tensor(lt)
+                       for lt in labeled_tensors]
+
+    if len(labeled_tensors) < 1:
+      raise ValueError('pack expects at least 1 tensor, but received %s' %
+                       labeled_tensors)
+
+    axes_0 = labeled_tensors[0].axes
+    for t in labeled_tensors:
+      if t.axes != axes_0:
+        raise ValueError('Non-identical axes. Expected %s but got %s' %
+                         (axes_0, t.axes))
+
+    pack_op = array_ops.stack(
+        [t.tensor for t in labeled_tensors], axis=axis_position, name=scope)
+    axes = list(axes_0.values())
+    axes.insert(axis_position, new_axis)
+    return core.LabeledTensor(pack_op, axes)
+
+
+@tc.returns(tc.List(core.LabeledTensor))
+@tc.accepts(core.LabeledTensorLike, tc.Optional(string_types),
+            tc.Optional(string_types))
+def unpack(labeled_tensor, axis_name=None, name=None):
+  """Unpack the tensor.
+
+  See tf.unpack.
+
+  Args:
+    labeled_tensor: The input tensor.
+    axis_name: Optional name of axis to unpack. By default, the first axis is
+      used.
+    name: Optional op name.
+
+  Returns:
+    The list of unpacked LabeledTensors.
+
+  Raises:
+    ValueError: If `axis_name` is not an axis on the input.
+  """
+  with ops.name_scope(name, 'lt_unpack', [labeled_tensor]) as scope:
+    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+    axis_names = list(labeled_tensor.axes.keys())
+    if axis_name is None:
+      axis_name = axis_names[0]
+
+    if axis_name not in axis_names:
+      raise ValueError('%s not in %s' % (axis_name, axis_names))
+    axis = axis_names.index(axis_name)
+
+    unpack_ops = array_ops.unstack(labeled_tensor.tensor, axis=axis,
+                                   name=scope)
+    axes = [a for i, a in enumerate(labeled_tensor.axes.values())
+            if i != axis]
+    return [core.LabeledTensor(t, axes) for t in unpack_ops]
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Collection(string_types),
+            tc.Collection(tc.Union(string_types, core.AxisLike)),
+            tc.Optional(string_types))
+def reshape(labeled_tensor, existing_axes, new_axes, name=None):
+  """Reshape specific axes of a LabeledTensor.
+
+  Non-indicated axes remain in their original locations.
+
+  Args:
+    labeled_tensor: The input tensor.
+    existing_axes: List of axis names found on the input tensor. These must
+      appear sequentially in the list of axis names on the input. In other
+      words, they must be a valid slice of `list(labeled_tensor.axes.keys())`.
+    new_axes: List of strings, tuples of (axis_name, axis_value) or Axis
+      objects providing new axes with which to replace `existing_axes` in the
+      reshaped result. At most one element of `new_axes` may be a string,
+      indicating an axis with unknown size.
+    name: Optional op name.
+
+  Returns:
+    The reshaped LabeledTensor.
+
+  Raises:
+    ValueError: If `existing_axes` are not all axes on the input, or if more
+      than one of `new_axes` has unknown size.
+    AxisOrderError: If `existing_axes` are not a slice of axis names on the
+      input.
+  """
+  with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope:
+    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+    original_axis_names = list(labeled_tensor.axes.keys())
+    existing_axes = list(existing_axes)
+    if not set(existing_axes) <= set(original_axis_names):
+      raise ValueError('existing_axes %r are not contained in the set of axis '
+                       'names %r on the input labeled tensor' %
+                       (existing_axes, original_axis_names))
+
+    start = original_axis_names.index(existing_axes[0])
+    stop = original_axis_names.index(existing_axes[-1]) + 1
+
+    if existing_axes != original_axis_names[start:stop]:
+      # We could support existing_axes that aren't a slice by using transpose,
+      # but that could lead to unpredictable performance consequences because
+      # transposes are not free in TensorFlow. If we did transpose
+      # automatically, the user might never realize that their data is being
+      # produced with the wrong order. (The latter will occur with some
+      # frequency because of how broadcasting automatically chooses axis
+      # order.) So for now we've taken the strict approach.
+      raise core.AxisOrderError(
+          'existing_axes %r are not a slice of axis names %r on the input '
+          'labeled tensor. Use `transpose` or `impose_axis_order` to reorder '
+          'axes on the input explicitly.' %
+          (existing_axes, original_axis_names))
+
+    if sum(isinstance(axis, string_types) for axis in new_axes) > 1:
+      raise ValueError(
+          'at most one axis in new_axes can have unknown size. All other '
+          'axes must have an indicated integer size or labels: %r' % new_axes)
+
+    original_values = list(labeled_tensor.axes.values())
+    axis_size = lambda axis: -1 if axis.size is None else axis.size
+    shape = [axis_size(axis) for axis in original_values[:start]]
+    for axis_ref in new_axes:
+      if isinstance(axis_ref, string_types):
+        shape.append(-1)
+      else:
+        axis = core.as_axis(axis_ref)
+        shape.append(axis_size(axis))
+    shape.extend(axis_size(axis) for axis in original_values[stop:])
+
+    reshaped_tensor = array_ops.reshape(
+        labeled_tensor.tensor, shape, name=scope)
+    axes = original_values[:start] + list(new_axes) + original_values[stop:]
+    return core.LabeledTensor(reshaped_tensor, axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, string_types, string_types,
+            tc.Optional(string_types))
+def rename_axis(labeled_tensor, existing_name, new_name, name=None):
+  """Rename an axis of LabeledTensor.
+
+  Args:
+    labeled_tensor: The input tensor.
+    existing_name: Name for an existing axis on the input.
+    new_name: Desired replacement name.
+    name: Optional op name.
+
+  Returns:
+    LabeledTensor with renamed axis.
+
+  Raises:
+    ValueError: If `existing_name` is not an axis on the input.
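+
+  For example, `rename_axis(lt, 'channel', 'color')` returns a LabeledTensor
+  whose 'channel' axis, including any tick labels, is renamed to 'color'.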
+ """ + with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope: + if existing_name not in labeled_tensor.axes: + raise ValueError('existing_name %r are not contained in the set of axis ' + 'names %r on the input labeled tensor' % + (existing_name, labeled_tensor.axes.keys())) + new_axis = core.Axis(new_name, labeled_tensor.axes[existing_name].value) + return reshape(labeled_tensor, [existing_name], [new_axis], name=scope) + + +@tc.returns(tc.List(core.LabeledTensor)) +@tc.accepts(string_types, collections.Callable, int, bool, + tc.Collection(core.LabeledTensorLike), bool, + tc.Optional(string_types)) +def _batch_helper(default_name, + batch_fn, + batch_size, + enqueue_many, + labeled_tensors, + allow_smaller_final_batch, + name=None): + with ops.name_scope(name, default_name, labeled_tensors) as scope: + labeled_tensors = [core.convert_to_labeled_tensor(lt) + for lt in labeled_tensors] + + batch_ops = batch_fn([t.tensor for t in labeled_tensors], scope) + # TODO(shoyer): Remove this when they sanitize the TF API. + if not isinstance(batch_ops, list): + assert isinstance(batch_ops, ops.Tensor) + batch_ops = [batch_ops] + + if allow_smaller_final_batch: + batch_size = None + + @tc.returns(core.Axes) + @tc.accepts(core.Axes) + def output_axes(axes): + if enqueue_many: + if 'batch' not in axes or list(axes.keys()).index('batch') != 0: + raise ValueError( + 'When enqueue_many is True, input tensors must have an axis ' + 'called "batch" as their first dimension, ' + 'but axes were %s' % axes) + culled_axes = axes.remove('batch') + return core.Axes([('batch', batch_size)] + list(culled_axes.values())) + else: + return core.Axes([('batch', batch_size)] + list(axes.values())) + + output_labeled_tensors = [] + for i, tensor in enumerate(batch_ops): + axes = output_axes(labeled_tensors[i].axes) + output_labeled_tensors.append(core.LabeledTensor(tensor, axes)) + + return output_labeled_tensors + + +@tc.returns(tc.List(core.LabeledTensor)) +@tc.accepts( + tc.Collection(core.LabeledTensorLike), int, int, int, bool, bool, + tc.Optional(string_types)) +def batch(labeled_tensors, + batch_size, + num_threads=1, + capacity=32, + enqueue_many=False, + allow_smaller_final_batch=False, + name=None): + """Rebatch a tensor. + + See tf.batch. + + Args: + labeled_tensors: The input tensors. + batch_size: The output batch size. + num_threads: See tf.batch. + capacity: See tf.batch. + enqueue_many: If true, the input tensors must contain a 'batch' axis as + their first axis. + If false, the input tensors must not contain a 'batch' axis. + See tf.batch. + allow_smaller_final_batch: See tf.batch. + name: Optional op name. + + Returns: + The rebatched tensors. + If enqueue_many is false, the output tensors will have a new 'batch' axis + as their first axis. + + Raises: + ValueError: If enqueue_many is True and the first axis of the tensors + isn't "batch". 
+ """ + + def fn(tensors, scope): + return input.batch(tensors, + batch_size=batch_size, + num_threads=num_threads, + capacity=capacity, + enqueue_many=enqueue_many, + allow_smaller_final_batch=allow_smaller_final_batch, + name=scope) + + return _batch_helper('lt_batch', fn, batch_size, enqueue_many, + labeled_tensors, allow_smaller_final_batch, name) + + +@tc.returns(tc.List(core.LabeledTensor)) +@tc.accepts( + tc.Collection(core.LabeledTensorLike), int, int, int, bool, int, + tc.Optional(int), bool, tc.Optional(string_types)) +def shuffle_batch(labeled_tensors, + batch_size, + num_threads=1, + capacity=32, + enqueue_many=False, + min_after_dequeue=0, + seed=None, + allow_smaller_final_batch=False, + name=None): + """Rebatch a tensor, with shuffling. + + See tf.batch. + + Args: + labeled_tensors: The input tensors. + batch_size: The output batch size. + num_threads: See tf.batch. + capacity: See tf.batch. + enqueue_many: If true, the input tensors must contain a 'batch' axis as + their first axis. + If false, the input tensors must not contain a 'batch' axis. + See tf.batch. + min_after_dequeue: Minimum number of elements in the queue after a dequeue, + used to ensure mixing. + seed: Optional random seed. + allow_smaller_final_batch: See tf.batch. + name: Optional op name. + + Returns: + The rebatched tensors. + If enqueue_many is false, the output tensors will have a new 'batch' axis + as their first axis. + + Raises: + ValueError: If enqueue_many is True and the first axis of the tensors + isn't "batch". + """ + + def fn(tensors, scope): + return input.shuffle_batch( + tensors, + batch_size=batch_size, + num_threads=num_threads, + capacity=capacity, + enqueue_many=enqueue_many, + min_after_dequeue=min_after_dequeue, + seed=seed, + allow_smaller_final_batch=allow_smaller_final_batch, + name=scope) + + return _batch_helper('lt_shuffle_batch', fn, batch_size, enqueue_many, + labeled_tensors, allow_smaller_final_batch, name) + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, int), + tc.Optional(int), tc.Optional(string_types)) +def random_crop(labeled_tensor, shape_map, seed=None, name=None): + """Randomly crops a tensor to a given size. + + See tf.random_crop. + + Args: + labeled_tensor: The input tensor. + shape_map: A dictionary mapping axis names to the size of the random crop + for that dimension. + seed: An optional random seed. + name: An optional op name. + + Returns: + A tensor of the same rank as `labeled_tensor`, cropped randomly in the + selected dimensions. + + Raises: + ValueError: If the shape map contains an axis name not in the input tensor. + """ + with ops.name_scope(name, 'lt_random_crop', [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + + for axis_name in shape_map: + if axis_name not in labeled_tensor.axes: + raise ValueError('Selection axis %s not in axes %s' % + (axis_name, labeled_tensor.axes)) + + shape = [] + axes = [] + for axis in labeled_tensor.axes.values(): + if axis.name in shape_map: + size = shape_map[axis.name] + shape.append(size) + # We lose labels for the axes we crop, leaving just the size. + axes.append((axis.name, size)) + else: + shape.append(len(axis)) + axes.append(axis) + + crop_op = random_ops.random_crop(labeled_tensor.tensor, + shape, + seed=seed, + name=scope) + + return core.LabeledTensor(crop_op, axes) + + +# TODO(shoyer): Allow the user to select the axis over which to map. 
+@tc.returns(core.LabeledTensor) +@tc.accepts(collections.Callable, core.LabeledTensorLike, + tc.Optional(string_types)) +def map_fn(fn, labeled_tensor, name=None): + """Map on the list of tensors unpacked from labeled_tensor. + + See tf.map_fn. + + Args: + fn: The function to apply to each unpacked LabeledTensor. + It should have type LabeledTensor -> LabeledTensor. + labeled_tensor: The input tensor. + name: Optional op name. + + Returns: + A tensor that packs the results of applying fn to the list of tensors + unpacked from labeled_tensor. + """ + with ops.name_scope(name, 'lt_map_fn', [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + unpack_lts = unpack(labeled_tensor) + map_lts = [fn(t) for t in unpack_lts] + return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope) + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensorLike, tc.Optional(tc.Collection(string_types)), + tc.Optional(string_types)) +def squeeze(labeled_tensor, axis_names=None, name=None): + """Remove size-1 dimensions. + + See tf.squeeze. + + Args: + labeled_tensor: The input tensor. + axis_names: The names of the dimensions to remove, or None to remove + all size-1 dimensions. + name: Optional op name. + + Returns: + A tensor with the specified dimensions removed. + + Raises: + ValueError: If the named axes are not in the tensor, or if they are + not size-1. + """ + with ops.name_scope(name, 'lt_squeeze', [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + + if axis_names is None: + axis_names = [a.name for a in labeled_tensor.axes.values() if len(a) == 1] + + for axis_name in axis_names: + if axis_name not in labeled_tensor.axes: + raise ValueError('axis %s is not in tensor axes %s' % + (axis_name, labeled_tensor.axes)) + elif len(labeled_tensor.axes[axis_name]) != 1: + raise ValueError( + 'cannot squeeze axis with size greater than 1: (%s, %s)' % + (axis_name, labeled_tensor.axes[axis_name])) + + squeeze_dimensions = [] + axes = [] + for i, axis in enumerate(labeled_tensor.axes.values()): + if axis.name in axis_names: + squeeze_dimensions.append(i) + else: + axes.append(axis) + + if squeeze_dimensions: + squeeze_op = array_ops.squeeze(labeled_tensor.tensor, + squeeze_dimensions, + name=scope) + else: + squeeze_op = array_ops.identity(labeled_tensor.tensor, name=scope) + + return core.LabeledTensor(squeeze_op, axes) + +# pylint: disable=invalid-name +ReduceAxis = tc.Union( + string_types, tc.Tuple(string_types, collections.Hashable)) +ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis))) +# pylint: enable=invalid-name + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, + tc.Optional(string_types)) +def matmul(a, b, name=None): + """Matrix multiply two tensors with rank 1 or 2. + + If both tensors have rank 2, a matrix-matrix product is performed. + If one tensor has rank 1 and the other has rank 2, then a matrix-vector + product is performed. + If both tensors have rank 1, then a vector dot-product is performed. + (This behavior matches that of `numpy.dot`.) + + Both tensors must share exactly one dimension in common, which is the + dimension the operation is summed along. The inputs will be automatically + transposed if necessary as part of the matmul op. 
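+
+  For example (an illustrative sketch, assuming this module is imported as
+  `lt`; the axis names are arbitrary):
+
+  >>> xy_lt = lt.LabeledTensor(tf.ones((2, 3)), ['x', 'y'])
+  >>> yz_lt = lt.LabeledTensor(tf.ones((3, 4)), ['y', 'z'])
+  >>> xz_lt = lt.matmul(xy_lt, yz_lt)  # LabeledTensor with axes ['x', 'z']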
+
+  We intend to eventually support `matmul` on higher rank input, and also
+  eventually support summing over any number of shared dimensions (via an
+  `axis` argument), but neither of these features has been implemented yet.
+
+  Args:
+    a: First LabeledTensor.
+    b: Second LabeledTensor.
+    name: Optional op name.
+
+  Returns:
+    LabeledTensor with the result of matrix multiplication. Axes are ordered by
+    the current axis_order_scope, if set, or in order of appearance on the
+    inputs.
+
+  Raises:
+    NotImplementedError: If inputs have rank >2 or share multiple axes.
+    ValueError: If the inputs have rank 0 or do not share any axes.
+  """
+  with ops.name_scope(name, 'lt_matmul', [a, b]) as scope:
+
+    a = core.convert_to_labeled_tensor(a)
+    b = core.convert_to_labeled_tensor(b)
+
+    if len(a.axes) > 2 or len(b.axes) > 2:
+      # We could use tf.batch_matmul to make this work, but we would also need
+      # to use tf.tile and/or tf.transpose. These are more expensive than doing
+      # reshapes, so it's not clear if it's a good idea to do this
+      # automatically.
+      raise NotImplementedError(
+          'matmul currently requires inputs with rank 2 or less, but '
+          'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))
+
+    if not a.axes or not b.axes:
+      raise ValueError(
+          'matmul currently requires inputs with at least rank 1, but '
+          'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))
+
+    shared_axes = set(a.axes) & set(b.axes)
+    if len(shared_axes) > 1:
+      raise NotImplementedError(
+          'matmul does not yet support summing over multiple shared axes: %r. '
+          'Use transpose and reshape to create a single shared axis to sum '
+          'over.' % shared_axes)
+    if not shared_axes:
+      raise ValueError('there must be exactly one axis in common between '
+                       'inputs to matmul: %r, %r' %
+                       (a.axes.keys(), b.axes.keys()))
+    shared_axis, = shared_axes
+
+    if a.axes[shared_axis] != b.axes[shared_axis]:
+      raise ValueError('axis %r does not match on input arguments: %r vs %r' %
+                       (shared_axis, a.axes[shared_axis].value,
+                        b.axes[shared_axis].value))
+
+    result_axes = []
+    for axes in [a.axes, b.axes]:
+      for axis in axes.values():
+        if axis.name != shared_axis:
+          result_axes.append(axis)
+
+    axis_scope_order = core.get_axis_order()
+    if axis_scope_order is not None:
+      result_axis_names = [axis.name for axis in result_axes]
+      new_axis_names = [name for name in axis_scope_order
+                        if name in result_axis_names]
+      if new_axis_names != result_axis_names:
+        # switch a and b
+        b, a = a, b
+        # result_axes is a list of length 1 or 2
+        result_axes = result_axes[::-1]
+
+    squeeze_dims = []
+
+    if len(a.axes) == 1:
+      a_tensor = array_ops.reshape(a.tensor, (1, -1))
+      squeeze_dims.append(0)
+      transpose_a = False
+    else:
+      a_tensor = a.tensor
+      transpose_a = list(a.axes.keys()).index(shared_axis) == 0
+
+    if len(b.axes) == 1:
+      b_tensor = array_ops.reshape(b.tensor, (-1, 1))
+      squeeze_dims.append(1)
+      transpose_b = False
+    else:
+      b_tensor = b.tensor
+      transpose_b = list(b.axes.keys()).index(shared_axis) == 1
+
+    result_op = math_ops.matmul(a_tensor,
+                                b_tensor,
+                                transpose_a=transpose_a,
+                                transpose_b=transpose_b)
+
+    if squeeze_dims:
+      result_op = array_ops.squeeze(result_op, squeeze_dims)
+    result_op = array_ops.identity(result_op, name=scope)
+
+    return core.LabeledTensor(result_op, result_axes)
+
+
+@tc.returns(types.FunctionType)
+@tc.accepts(string_types, collections.Callable)
+def define_reduce_op(op_name, reduce_fn):
+  """Define a reduction op for labeled tensors.
+
+  Args:
+    op_name: string name of the TensorFlow op.
+    reduce_fn: function to call to evaluate the op on a tf.Tensor.
+
+  Returns:
+    Function defining the given reduction op that acts on a LabeledTensor.
+  """
+
+  default_name = 'lt_%s' % op_name
+
+  @tc.returns(core.LabeledTensor)
+  @tc.accepts(core.LabeledTensorLike, ReduceAxes, tc.Optional(string_types))
+  def op(labeled_tensor, axes=None, name=None):
+    """Computes the given reduction across the given axes of a LabeledTensor.
+
+    See `tf.{op_name}` for full details.
+
+    Args:
+      labeled_tensor: The input tensor.
+      axes: A set of axes or None.
+        If None, all axes will be reduced.
+        Each axis may be a string, in which case that dimension will be
+        removed, or a pair of (name, None) or (name, label), in which case
+        that dimension will be kept.
+      name: Optional op name.
+
+    Returns:
+      The reduced LabeledTensor.
+
+    Raises:
+      ValueError: if any of the axes to reduce over are not found on
+        `labeled_tensor`.
+    """
+    with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
+      labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+
+      if axes is None:
+        axes = labeled_tensor.axes.keys()
+
+      if isinstance(axes, (string_types, tuple)):
+        axes = [axes]
+
+      reduction_axes = {}
+      axes_to_squeeze = []
+      for a in axes:
+        if isinstance(a, string_types):
+          # We squeeze out this axis.
+          reduction_axes[a] = a
+          axes_to_squeeze.append(a)
+        else:
+          # We keep this axis, with the user-provided labels.
+          (axis_name, label) = a
+          if label is not None:
+            # The input was a single label, so make it a list so it can be
+            # turned into an Axis.
+            label = [label]
+          reduction_axes[axis_name] = (axis_name, label)
+
+      for axis_name in reduction_axes:
+        if axis_name not in labeled_tensor.axes:
+          raise ValueError('Axis %s not in axes %s' %
+                           (axis_name, labeled_tensor.axes))
+
+      intermediate_axes = []
+      reduction_dimensions = []
+      for i, axis in enumerate(labeled_tensor.axes.values()):
+        if axis.name in reduction_axes:
+          intermediate_axes.append(reduction_axes[axis.name])
+          reduction_dimensions.append(i)
+        else:
+          intermediate_axes.append(axis)
+
+      reduce_op = reduce_fn(labeled_tensor.tensor,
+                            reduction_dimensions,
+                            keep_dims=True)
+      reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes)
+
+      return squeeze(reduce_lt, axes_to_squeeze, name=scope)
+
+  op.__doc__ = op.__doc__.format(op_name=op_name)
+  op.__name__ = op_name
+
+  return op
+
+
+reduce_all = define_reduce_op('reduce_all', math_ops.reduce_all)
+reduce_any = define_reduce_op('reduce_any', math_ops.reduce_any)
+reduce_logsumexp = define_reduce_op('reduce_logsumexp',
+                                    math_ops.reduce_logsumexp)
+reduce_max = define_reduce_op('reduce_max', math_ops.reduce_max)
+reduce_mean = define_reduce_op('reduce_mean', math_ops.reduce_mean)
+reduce_min = define_reduce_op('reduce_min', math_ops.reduce_min)
+reduce_prod = define_reduce_op('reduce_prod', math_ops.reduce_prod)
+reduce_sum = define_reduce_op('reduce_sum', math_ops.reduce_sum)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Mapping(str, tc.Union(int, ops.Tensor)),
+            tc.Optional(string_types))
+def tile(labeled_tensor, multiples, name=None):
+  """Constructs a tensor by tiling a given tensor.
+
+  Only axes without tick-labels can be tiled. (Otherwise, axis labels on tiled
+  tensors would no longer be unique.)
+
+  See tf.tile.
+
+  Args:
+    labeled_tensor: The input tensor.
+    multiples: A mapping where the keys are axis names and the values are the
+      integer number of times to tile along that axis. Only axes with a
+      multiple other than 1 need to be included.
+ name: Optional op name. + + Returns: + A tensor with the indicated axes tiled. + + Raises: + ValueError: If the tiled axes are not axes in the input tensor, or if any + axes in multiples have tick labels. + """ + with ops.name_scope(name, 'lt_tile', [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + + if not set(multiples.keys()) <= set(labeled_tensor.axes.keys()): + raise ValueError('tile axes %r are not contained in the set of axis ' + 'names %r on the input labeled tensor' % + (multiples.keys(), labeled_tensor.axes)) + + labeled_axes = [name for name in multiples + if labeled_tensor.axes[name].labels is not None] + if labeled_axes: + raise ValueError('cannot tile axes with tick labels: %r' % labeled_axes) + + multiples_list = [multiples.get(name, 1) for name in labeled_tensor.axes] + tile_op = array_ops.tile(labeled_tensor.tensor, multiples_list, name=scope) + + new_axes = [axis.name if axis.labels is None else axis + for axis in labeled_tensor.axes.values()] + return core.LabeledTensor(tile_op, new_axes) + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensorLike, + tc.Mapping(str, tc.Tuple(core.AxisValue, core.AxisValue)), + string_types, tc.Optional(string_types)) +def pad(labeled_tensor, paddings, mode='CONSTANT', name=None): + """Pads a tensor. + + See tf.pad. + + Args: + labeled_tensor: The input tensor. + paddings: A mapping where the keys are axis names and the values are + tuples where the first element is the padding to insert at the beginning + of the axis and the second is the padding to insert at the end of the + axis. + mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC". + name: Optional op name. + + Returns: + A tensor with the indicated axes padded, optionally with those axes extended + with the provided labels. + + Raises: + ValueError: If the padded axes are not axes in the input tensor. + """ + with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + + if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()): + raise ValueError('pad axes %r are not contained in the set of axis ' + 'names %r on the input labeled tensor' % + (paddings.keys(), labeled_tensor.axes)) + + new_axes = [] + padding_pairs = [] + for name, axis in labeled_tensor.axes.items(): + if name in paddings: + padding_before, padding_after = paddings[name] + axis_before = core.Axis(name, padding_before) + axis_after = core.Axis(name, padding_after) + new_axes.append(core.concat_axes([axis_before, axis, axis_after])) + padding_pairs.append((len(axis_before), len(axis_after))) + else: + new_axes.append(axis) + padding_pairs.append((0, 0)) + + pad_op = array_ops.pad( + labeled_tensor.tensor, padding_pairs, mode, name=scope) + + return core.LabeledTensor(pad_op, new_axes) + + +@tc.returns(core.LabeledTensor) +@tc.accepts(tc.Union(np.ndarray, list, tuple, core.Scalar), + tc.Optional(dtypes.DType), + tc.Optional(tc.Union( + core.Axes, + tc.Collection(tc.Union(string_types, core.AxisLike)))), + tc.Optional(string_types)) +def constant(value, dtype=None, axes=None, name=None): + """Creates a constant tensor. + + If `axes` includes any strings, shape is inferred from `value`. Otherwise, + the sizes of the given `axes` are used to set `shape` for `tf.constant`. + + See tf.constant for more details. + + Args: + value: The input tensor. + dtype: The type of the returned tensor. + axes: Optional Axes, list of strings or list of objects coercible to Axis + objects. 
By default, axes are assumed to be an empty list (i.e., `value`
+      is treated as a scalar).
+    name: Optional op name.
+
+  Returns:
+    The constant LabeledTensor.
+  """
+  with ops.name_scope(name, 'lt_constant', [value]) as scope:
+
+    if axes is None:
+      axes = []
+
+    if isinstance(axes, core.Axes):
+      axes = axes.values()
+
+    if any(isinstance(ax, string_types) for ax in axes):
+      # need to infer shape
+      shape = None
+    else:
+      # axes already indicate shape
+      axes = [core.as_axis(a) for a in axes]
+      shape = [a.size for a in axes]
+
+    op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope)
+    return core.LabeledTensor(op, axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
+            tc.Optional(string_types))
+def zeros_like(labeled_tensor, dtype=None, name=None):
+  """Creates an identical tensor with all elements set to zero.
+
+  Args:
+    labeled_tensor: The input tensor.
+    dtype: The type of the returned tensor.
+    name: Optional op name.
+
+  Returns:
+    The tensor with elements set to zero.
+  """
+  with ops.name_scope(name, 'lt_zeros_like', [labeled_tensor]) as scope:
+    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+    op = array_ops.zeros_like(labeled_tensor.tensor, dtype=dtype, name=scope)
+    return core.LabeledTensor(op, labeled_tensor.axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
+            tc.Optional(string_types))
+def ones_like(labeled_tensor, dtype=None, name=None):
+  """Creates an identical tensor with all elements set to one.
+
+  Args:
+    labeled_tensor: The input tensor.
+    dtype: The type of the returned tensor.
+    name: Optional op name.
+
+  Returns:
+    The tensor with elements set to one.
+  """
+  with ops.name_scope(name, 'lt_ones_like', [labeled_tensor]) as scope:
+    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+    op = array_ops.ones_like(labeled_tensor.tensor, dtype=dtype, name=scope)
+    return core.LabeledTensor(op, labeled_tensor.axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, tc.Optional(dtypes.DType),
+            tc.Optional(string_types))
+def cast(labeled_tensor, dtype=None, name=None):
+  """Casts a labeled tensor to a new type.
+
+  Args:
+    labeled_tensor: The input tensor.
+    dtype: The type of the returned tensor.
+    name: Optional op name.
+
+  Returns:
+    A labeled tensor with the new dtype.
+  """
+  with ops.name_scope(name, 'lt_cast', [labeled_tensor]) as scope:
+    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
+    op = math_ops.cast(labeled_tensor.tensor, dtype=dtype, name=scope)
+    return core.LabeledTensor(op, labeled_tensor.axes)
+
+
+@tc.returns(core.LabeledTensor)
+@tc.accepts(core.LabeledTensorLike, string_types, tc.Optional(string_types))
+def verify_tensor_all_finite(labeled_tensor, message, name=None):
+  """Asserts a tensor doesn't contain NaNs or Infs.
+
+  See tf.verify_tensor_all_finite.
+
+  Args:
+    labeled_tensor: The input tensor.
+    message: Message to log on failure.
+    name: Optional op name.
+
+  Returns:
+    The input tensor.
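+
+  For example (an illustrative sketch):
+
+  >>> x_lt = lt.LabeledTensor(tf.constant([1.0, 2.0]), ['x'])
+  >>> checked_lt = lt.verify_tensor_all_finite(x_lt, 'x had NaN or Inf values')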
+ """ + with ops.name_scope(name, 'lt_verify_tensor_all_finite', + [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + op = numerics.verify_tensor_all_finite(labeled_tensor.tensor, + msg=message, + name=scope) + return core.LabeledTensor(op, labeled_tensor.axes) + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, + tc.Optional(string_types)) +def boolean_mask(labeled_tensor, mask, name=None): + """Apply a boolean mask to a labeled tensor. + + Unlike `tf.boolean_mask`, this currently only works on 1-dimensional masks. + The mask is applied to the first axis of `labeled_tensor`. Labels on the first + axis are removed, because True indices in `mask` may not be known dynamically. + + Args: + labeled_tensor: The input tensor. + mask: The type of the returned tensor. + name: Optional op name. + + Returns: + The masked labeled tensor. + + Raises: + ValueError: if the first axis of the mask + """ + with ops.name_scope(name, 'lt_boolean_mask', [labeled_tensor, mask]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + mask = core.convert_to_labeled_tensor(mask) + + if len(mask.axes) > 1: + raise NotImplementedError( + "LabeledTensor's boolean_mask currently only supports 1D masks") + mask_axis = list(mask.axes.values())[0] + lt_axis = list(labeled_tensor.axes.values())[0] + if mask_axis != lt_axis: + raise ValueError('the first axis of the labeled tensor and the mask ' + 'are not equal:\n%r\n%r' % (lt_axis, mask_axis)) + op = array_ops.boolean_mask(labeled_tensor.tensor, mask.tensor, name=scope) + # TODO(shoyer): attempt to infer labels for the masked values, by calling + # tf.contrib.util.constant_value on the mask? + axes = [lt_axis.name] + list(labeled_tensor.axes.values())[1:] + return core.LabeledTensor(op, axes) + + +@tc.returns(core.LabeledTensor) +@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, + core.LabeledTensorLike, tc.Optional(string_types)) +def where(condition, x, y, name=None): + """Return elements from x or y depending on condition. + + See `tf.where` for more details. This function currently only implements the + three argument version of where. + + Args: + condition: LabeledTensor of type `bool`. + x: LabeledTensor for values where condition is true. + y: LabeledTensor for values where condition is false. + name: Optional op name. + + Returns: + The labeled tensor with values according to condition. + + Raises: + ValueError: if `x` and `y` have different axes, or if the axes of `x` do not + start with the axes of `condition`. + """ + with ops.name_scope(name, 'lt_where', [condition, x, y]) as scope: + condition = core.convert_to_labeled_tensor(condition) + x = core.convert_to_labeled_tensor(x) + y = core.convert_to_labeled_tensor(y) + + if not condition.axes == x.axes == y.axes: + raise ValueError('all inputs to `where` must have equal axes') + + op = array_ops.where(condition.tensor, x.tensor, y.tensor, name=scope) + return core.LabeledTensor(op, x.axes) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py new file mode 100644 index 0000000000..c19fc09f93 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py @@ -0,0 +1,918 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import range # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.contrib.labeled_tensor.python.ops import ops +from tensorflow.contrib.labeled_tensor.python.ops import test_util + + +class Base(test_util.Base): + + def setUp(self): + super(Base, self).setUp() + + self.x_size = 7 + self.channel_size = 3 + self.z_size = 4 + self.probs_size = 11 + + tensor = tf.range(0, self.x_size * self.channel_size * self.z_size * + self.probs_size) + tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size, + self.probs_size]) + a0 = ('x', range(self.x_size)) + a1 = ('channel', ['red', 'green', 'blue']) + a2 = 'z' + a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size)) + + self.tensor = tensor + self.a0 = a0 + self.a1 = a1 + self.a2 = a2 + self.a2_resolved = ('z', self.z_size) + self.a3 = a3 + self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3]) + + self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0}) + self.x_probs_lt = ops.select(self.x_probs_lt, {'channel': 'red'}) + self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3, + 'z': 0}) + + +class SelectTest(Base): + + def test_name(self): + select_lt = ops.select(self.original_lt, {'channel': 'green'}) + self.assertIn('lt_select', select_lt.name) + + def test_scalar(self): + select_lt = ops.select(self.original_lt, {'channel': 'green'}) + golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2, + self.a3]) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_slice(self): + select_lt = ops.select(self.original_lt, {'channel': slice('red', 'green')}) + a1_sliced = ('channel', ['red', 'green']) + golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :], + [self.a0, a1_sliced, self.a2, self.a3]) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_slices(self): + select_lt = ops.select(self.original_lt, {'x': slice(1, 4), + 'channel': slice('green', None)}) + + a0_sliced = ('x', range(1, 5)) + a1_sliced = ('channel', ['green', 'blue']) + golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :], + [a0_sliced, a1_sliced, self.a2, self.a3]) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_list(self): + select_lt = ops.select(self.original_lt, {'channel': ['red', 'green']}) + a1_sliced = ('channel', ['red', 'green']) + golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :], + [self.a0, a1_sliced, self.a2, self.a3]) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_list_one_item(self): + select_lt = ops.select(self.original_lt, {'channel': ['red']}) + a1_sliced = ('channel', ['red']) + golden_lt = core.LabeledTensor(self.tensor[:, :1, :, :], + [self.a0, a1_sliced, self.a2, self.a3]) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_list_zero_items(self): + select_lt = 
ops.select(self.original_lt, {'channel': []}) + golden_lt = core.LabeledTensor(self.tensor[:, :0, :, :], + [self.a0, 'channel', self.a2, self.a3]) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_scalars(self): + select_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'}) + golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :], + [self.a2, self.a3]) + self.assertLabeledTensorsEqual(select_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaises(ValueError): + ops.select(self.original_lt, {'foo': 1}) + with self.assertRaises(ValueError): + ops.select(self.original_lt, {'z': 1}) + with self.assertRaises(KeyError): + ops.select(self.original_lt, {'channel': 'purple'}) + with self.assertRaises(KeyError): + ops.select(self.original_lt, {'channel': ['red', 'purple']}) + with self.assertRaises(NotImplementedError): + ops.select(self.original_lt, {'channel': ['red'], 'x': [1]}) + with self.assertRaises(NotImplementedError): + ops.select(self.original_lt, {'channel': ['red'], 'x': 1}) + with self.assertRaises(NotImplementedError): + ops.select(self.original_lt, {'channel': slice('red', 'green', 2)}) + + +class ConcatTest(Base): + + def setUp(self): + super(ConcatTest, self).setUp() + + self.red_lt = ops.select(self.original_lt, {'channel': ['red']}) + self.green_lt = ops.select(self.original_lt, {'channel': ['green']}) + self.blue_lt = ops.select(self.original_lt, {'channel': ['blue']}) + + def test_name(self): + concat_lt = ops.concat([self.red_lt, self.blue_lt], 'channel') + self.assertIn('lt_concat', concat_lt.name) + + def test(self): + concat_lt = ops.concat([self.red_lt, self.green_lt], 'channel') + golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']}) + + self.assertLabeledTensorsEqual(concat_lt, golden_lt) + + def test_transposed(self): + green_transposed = core.transpose(self.green_lt, + ['probs', 'channel', 'z', 'x']) + with self.assertRaises(ValueError): + ops.concat([self.red_lt, green_transposed], 'channel') + + def test_invalid_input(self): + with self.assertRaises(ValueError): + ops.concat([], 'channel') + with self.assertRaises(ValueError): + ops.concat([self.red_lt, self.red_lt], 'channel') + with self.assertRaises(ValueError): + ops.concat([self.red_lt, self.red_lt], 'foo') + + +class PackTest(Base): + + def test_name(self): + pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch') + self.assertIn('lt_pack', pack_lt.name) + + def test(self): + pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch') + golden_lt = core.LabeledTensor( + tf.stack([self.original_lt.tensor, self.original_lt.tensor]), + ['batch', self.a0, self.a1, self.a2, self.a3]) + + self.assertLabeledTensorsEqual(pack_lt, golden_lt) + + def test_axis(self): + pack_lt = ops.pack([self.original_lt, self.original_lt], + new_axis='batch', + axis_position=4) + golden_lt = core.LabeledTensor( + tf.stack( + [self.original_lt.tensor, self.original_lt.tensor], axis=4), + [self.a0, self.a1, self.a2, self.a3, 'batch']) + + self.assertLabeledTensorsEqual(pack_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaises(ValueError): + ops.pack([self.original_lt, self.original_lt], 'channel') + + +class UnpackTest(Base): + + def test_name(self): + unpack_lts = ops.unpack(self.original_lt) + for t in unpack_lts: + self.assertIn('lt_unpack', t.name) + + def test(self): + unpack_lt = ops.unpack(self.original_lt)[0] + golden_lt = core.LabeledTensor( + tf.unstack(self.original_lt.tensor)[0], [self.a1, self.a2, self.a3]) + + 
self.assertLabeledTensorsEqual(unpack_lt, golden_lt) + + def test_axis(self): + unpack_lt = ops.unpack(self.original_lt, axis_name='z')[0] + golden_lt = core.LabeledTensor( + tf.unstack( + self.original_lt.tensor, axis=2)[0], [self.a0, self.a1, self.a3]) + + self.assertLabeledTensorsEqual(unpack_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaises(ValueError): + ops.unpack(self.original_lt, axis_name='not_found') + + +class ReshapeTest(Base): + + def test_name(self): + reshape_lt = ops.reshape(self.original_lt, ['channel'], ['foo']) + self.assertIn('lt_reshape', reshape_lt.name) + + def test_identity(self): + reshape_lt = ops.reshape(self.original_lt, self.original_lt.axes.keys(), + self.original_lt.axes.values()) + self.assertLabeledTensorsEqual(reshape_lt, self.original_lt) + + def test_known_size(self): + new_dim_size = self.channel_size * self.z_size * self.probs_size + reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'], + [('new_dim', new_dim_size)]) + golden_lt = core.LabeledTensor( + tf.reshape(self.original_lt.tensor, [self.x_size, -1]), + [self.original_lt.axes['x'], 'new_dim']) + self.assertLabeledTensorsEqual(reshape_lt, golden_lt) + + def test_unknown_size(self): + reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'], + ['new_dim']) + golden_lt = core.LabeledTensor( + tf.reshape(self.original_lt.tensor, [self.x_size, -1]), + [self.original_lt.axes['x'], 'new_dim']) + self.assertLabeledTensorsEqual(reshape_lt, golden_lt) + + def test_unknown_dimension(self): + orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x']) + reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)]) + self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)])) + with self.test_session() as sess: + result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]}) + np.testing.assert_array_equal(result, [[1], [2]]) + + def test_with_labels(self): + new_dim_size = self.channel_size * self.z_size * self.probs_size + reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'], + [('new_dim', range(new_dim_size))]) + golden_lt = core.LabeledTensor( + tf.reshape(self.original_lt.tensor, [self.x_size, -1]), + [self.original_lt.axes['x'], ('new_dim', range(new_dim_size))]) + self.assertLabeledTensorsEqual(reshape_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaisesRegexp(ValueError, 'not contained in the set'): + ops.reshape(self.original_lt, ['foo'], ['bar']) + with self.assertRaisesRegexp(core.AxisOrderError, + 'not a slice of axis names'): + ops.reshape(self.original_lt, ['probs', 'z'], ['bar']) + with self.assertRaisesRegexp(ValueError, 'at most one axis in new_axes'): + ops.reshape(self.original_lt, ['probs'], ['foo', 'bar']) + + +class RenameAxisTest(Base): + + def test_name(self): + rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo') + self.assertIn('lt_rename_axis', rename_axis_lt.name) + + def test_identity(self): + rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'channel') + self.assertLabeledTensorsEqual(rename_axis_lt, self.original_lt) + + def test_new_name(self): + rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo') + expected_axes = [(name if name != 'channel' else 'foo', axis.value) + for name, axis in self.original_lt.axes.items()] + expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes) + self.assertLabeledTensorsEqual(rename_axis_lt, expected_lt) + + def test_invalid_input(self): + with 
self.assertRaisesRegexp(ValueError, 'not contained in the set'): + ops.rename_axis(self.original_lt, 'foo', 'bar') + + +class BatchTest(Base): + + def setUp(self): + super(BatchTest, self).setUp() + + tensors = [] + for i in range(10): + offset_lt = core.LabeledTensor(tf.constant(i), []) + tensors.append(core.add(self.original_lt, offset_lt)) + self.pack_lt = ops.pack(tensors, 'batch') + + def test_name(self): + batch_ops = ops.batch([self.pack_lt, self.pack_lt], + batch_size=2, + enqueue_many=True) + for bo in batch_ops: + self.assertIn('lt_batch', bo.name) + + def test_enqueue_many(self): + [batch_2_op] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True) + self.assertEqual(len(batch_2_op.axes['batch']), 2) + + [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True) + + self.assertLabeledTensorsEqual(self.pack_lt, batch_10_op) + + def test_no_enqueue_many(self): + [batch_2_op] = ops.batch([self.original_lt], batch_size=2) + self.assertEqual(len(batch_2_op.axes['batch']), 2) + + [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True) + + self.assertLabeledTensorsEqual( + ops.pack(10 * [self.original_lt], 'batch'), batch_10_op) + + def test_invalid_input(self): + with self.assertRaises(ValueError): + ops.batch([self.original_lt], 3, enqueue_many=True) + + def test_allow_smaller_final_batch(self): + [batch_2_op] = ops.batch([self.original_lt], batch_size=2, + allow_smaller_final_batch=True) + self.assertEqual(batch_2_op.axes['batch'].size, None) + + +class ShuffleBatchTest(Base): + + def setUp(self): + super(ShuffleBatchTest, self).setUp() + + tensors = [] + for i in range(10): + offset_lt = core.LabeledTensor(tf.constant(i), []) + tensors.append(core.add(self.original_lt, offset_lt)) + self.pack_lt = ops.pack(tensors, 'batch') + + def test_name(self): + batch_lts = ops.shuffle_batch([self.pack_lt, self.pack_lt], + batch_size=2, + enqueue_many=True) + for blt in batch_lts: + self.assertIn('lt_shuffle_batch', blt.name) + + def test_enqueue_many(self): + [batch_2_lt] = ops.shuffle_batch([self.pack_lt], + batch_size=2, + enqueue_many=True, + min_after_dequeue=8, + seed=0) + self.assertEqual(len(batch_2_lt.axes['batch']), 2) + + [batch_10_lt] = ops.batch([batch_2_lt], batch_size=10, enqueue_many=True) + + self.assertEqual(batch_10_lt.axes, self.pack_lt.axes) + [batch_10, pack] = self.eval([batch_10_lt.tensor, self.pack_lt.tensor]) + self.assertFalse((batch_10 == pack).all()) + + def test_allow_smaller_final_batch(self): + [batch_2_op] = ops.shuffle_batch([self.original_lt], batch_size=2, + allow_smaller_final_batch=True) + self.assertEqual(batch_2_op.axes['batch'].size, None) + + +class RandomCropTest(Base): + + def test_name(self): + crop_lt = ops.random_crop(self.original_lt, {'probs': 3}) + self.assertIn('lt_random_crop', crop_lt.name) + + def test_single(self): + crop_lt = ops.random_crop(self.original_lt, {'probs': 3}) + + self.assertEqual( + core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 3)]), + crop_lt.axes) + + def test_double(self): + crop_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2}) + + self.assertEqual( + core.Axes([self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)]), + crop_lt.axes) + + def test_size1(self): + crop_lt = ops.random_crop(self.original_lt, {'probs': 1}) + + self.assertEqual( + core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 1)]), + crop_lt.axes) + + def test_different_seeds(self): + crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3, + 'channel': 2}, + seed=0) + crop_1_lt 
= ops.random_crop(self.original_lt, {'probs': 3, + 'channel': 2}, + seed=1) + + self.assertEqual(crop_0_lt.axes, crop_1_lt.axes) + [crop_0, crop_1] = self.eval([crop_0_lt.tensor, crop_1_lt.tensor]) + self.assertFalse((crop_0 == crop_1).all()) + + def test_identical_seeds(self): + crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3, + 'channel': 2}, + seed=0) + crop_1_lt = ops.random_crop(self.original_lt, {'probs': 3, + 'channel': 2}, + seed=0) + + self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt) + + def test_crop_idempotent(self): + crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3, + 'channel': 2}, + seed=0) + crop_1_lt = ops.random_crop(crop_0_lt, {'probs': 3, 'channel': 2}, seed=1) + + self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt) + + def test_invalid_input(self): + with self.assertRaises(ValueError): + ops.random_crop(self.original_lt, {'foobar': 2}) + + +class MapFnTest(Base): + + def test_name(self): + map_lt = ops.map_fn(core.identity, self.original_lt) + self.assertIn('lt_map_fn', map_lt.name) + + def test_identity(self): + map_lt = ops.map_fn(core.identity, self.original_lt) + self.assertLabeledTensorsEqual(map_lt, self.original_lt) + + def test_callable_object(self): + + class Identity(object): + + def __call__(self, other): + return other + + map_lt = ops.map_fn(Identity(), self.original_lt) + self.assertLabeledTensorsEqual(map_lt, self.original_lt) + + def test_slice(self): + map_lt = ops.map_fn(lambda t: core.slice_function(t, {'channel': 1}), + self.original_lt) + slice_lt = core.slice_function(self.original_lt, {'channel': 1}) + self.assertLabeledTensorsEqual(map_lt, slice_lt) + + +class SqueezeTest(Base): + + def setUp(self): + super(SqueezeTest, self).setUp() + + self.squeezable_lt = core.slice_function(self.original_lt, + {'channel': slice(0, 1), + 'probs': slice(0, 1)}) + + def test_name(self): + squeeze_lt = ops.squeeze(self.squeezable_lt) + self.assertIn('lt_squeeze', squeeze_lt.name) + + def test_none(self): + none_lt = ops.squeeze(self.squeezable_lt, None) + axes_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs']) + self.assertLabeledTensorsEqual(none_lt, axes_lt) + + def test(self): + squeeze_lt = ops.squeeze(self.squeezable_lt, ['probs']) + golden_lt = core.slice_function(self.squeezable_lt, {'probs': 0}) + self.assertLabeledTensorsEqual(squeeze_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaises(ValueError): + ops.squeeze(self.original_lt, ['channel']) + with self.assertRaises(ValueError): + ops.squeeze(self.squeezable_lt, ['foo']) + + +class MatMulTest(Base): + + def test_name(self): + x_lt = core.LabeledTensor(tf.ones((3,)), ['x']) + matmul_lt = ops.matmul(x_lt, x_lt) + self.assertIn('lt_matmul', matmul_lt.name) + + def test_vector_vector(self): + x_lt = core.LabeledTensor(tf.range(3), ['x']) + matmul_lt = ops.matmul(x_lt, x_lt) + golden_lt = core.convert_to_labeled_tensor(5) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + def test_matrix_vector(self): + xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y']) + y_lt = core.LabeledTensor(tf.range(3), ['y']) + + matmul_lt = ops.matmul(xy_lt, y_lt) + golden_lt = core.LabeledTensor( + tf.matmul(xy_lt.tensor, tf.reshape(y_lt.tensor, (-1, 1)))[:, 0], ['x']) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + matmul_lt = ops.matmul(y_lt, xy_lt) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + def test_matrix_matrix(self): + xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y']) + yz_lt = 
core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z']) + + matmul_lt = ops.matmul(xy_lt, yz_lt) + golden_lt = core.LabeledTensor( + tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z']) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + transpose = lambda x: core.transpose(x, list(x.axes.keys())[::-1]) + + matmul_lt = ops.matmul(xy_lt, transpose(yz_lt)) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + matmul_lt = ops.matmul(transpose(xy_lt), yz_lt) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + matmul_lt = ops.matmul(transpose(xy_lt), transpose(yz_lt)) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + matmul_lt = ops.matmul(yz_lt, xy_lt) + self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt)) + + def test_matrix_matrix_axis_order(self): + xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y']) + yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z']) + + golden_lt = core.LabeledTensor( + tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z']) + + with core.axis_order_scope(['x', 'y', 'z']): + + matmul_lt = ops.matmul(xy_lt, yz_lt) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + matmul_lt = ops.matmul(yz_lt, xy_lt) + self.assertLabeledTensorsEqual(matmul_lt, golden_lt) + + def test_invalid(self): + scalar_lt = core.LabeledTensor(tf.ones(()), []) + x_lt = core.LabeledTensor(tf.ones((2,)), ['x']) + x2_lt = core.LabeledTensor(tf.ones((3,)), ['x']) + y_lt = core.LabeledTensor(tf.ones((3,)), ['y']) + xy_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'y']) + xyz_lt = core.LabeledTensor(tf.ones((2, 3, 1)), ['x', 'y', 'z']) + + with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'): + ops.matmul(x_lt, scalar_lt) + + with self.assertRaises(NotImplementedError): + ops.matmul(x_lt, xyz_lt) + + with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'): + ops.matmul(x_lt, y_lt) + + with self.assertRaises(NotImplementedError): + ops.matmul(xy_lt, xy_lt) + + with self.assertRaisesRegexp(ValueError, 'does not match'): + ops.matmul(x_lt, x2_lt) + + +class ReduceSumTest(Base): + + def test_name(self): + sum_lt = ops.reduce_sum(self.original_lt, {'channel'}) + self.assertIn('lt_reduce_sum', sum_lt.name) + + def test_drop_axis(self): + sum_lt = ops.reduce_sum(self.original_lt, {'channel'}) + golden_lt = core.LabeledTensor( + tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(sum_lt, golden_lt) + + def test_drop_scalar_axis(self): + sum_lt = ops.reduce_sum(self.original_lt, 'channel') + golden_lt = core.LabeledTensor( + tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(sum_lt, golden_lt) + + def test_keep_axis(self): + sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')}) + golden_lt = core.LabeledTensor( + tf.reduce_sum(self.original_lt.tensor, + 1, keep_dims=True), + [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) + self.assertLabeledTensorsEqual(sum_lt, golden_lt) + + def test_keep_scalar_axis(self): + sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou')) + golden_lt = core.LabeledTensor( + tf.reduce_sum(self.original_lt.tensor, + 1, keep_dims=True), + [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) + self.assertLabeledTensorsEqual(sum_lt, golden_lt) + + def test_scalar(self): + scalar_lt = core.LabeledTensor(tf.constant(42), []) + reduce_lt = ops.reduce_sum(scalar_lt, []) + self.assertLabeledTensorsEqual(reduce_lt, 
scalar_lt) + + def test_empty_list(self): + reduce_lt = ops.reduce_sum(self.original_lt, []) + self.assertLabeledTensorsEqual(reduce_lt, self.original_lt) + + def test_none(self): + sum_lt = ops.reduce_sum(self.original_lt) + golden_lt = core.LabeledTensor(tf.reduce_sum(self.original_lt.tensor), []) + self.assertLabeledTensorsEqual(sum_lt, golden_lt) + + def test_function_docstring_and_name(self): + self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__) + self.assertEqual('reduce_sum', ops.reduce_sum.__name__) + + +class ReduceMeanTest(Base): + + def test_name(self): + actual_lt = ops.reduce_mean(self.original_lt, {'channel'}) + self.assertIn('lt_reduce_mean', actual_lt.name) + + def test(self): + actual_lt = ops.reduce_mean(self.original_lt, {'channel'}) + golden_lt = core.LabeledTensor( + tf.reduce_mean(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(actual_lt, golden_lt) + + +class ReduceProdTest(Base): + + def test_name(self): + result_lt = ops.reduce_prod(self.original_lt, {'channel'}) + self.assertIn('lt_reduce_prod', result_lt.name) + + def test(self): + result_lt = ops.reduce_prod(self.original_lt, {'channel'}) + golden_lt = core.LabeledTensor( + tf.reduce_prod(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(result_lt, golden_lt) + + +class ReduceMinTest(Base): + + def test_name(self): + result_lt = ops.reduce_min(self.original_lt, {'channel'}) + self.assertIn('lt_reduce_min', result_lt.name) + + def test(self): + result_lt = ops.reduce_min(self.original_lt, {'channel'}) + golden_lt = core.LabeledTensor( + tf.reduce_min(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(result_lt, golden_lt) + + +class ReduceMaxTest(Base): + + def test_name(self): + result_lt = ops.reduce_max(self.original_lt, {'channel'}) + self.assertIn('lt_reduce_max', result_lt.name) + + def test(self): + result_lt = ops.reduce_max(self.original_lt, {'channel'}) + golden_lt = core.LabeledTensor( + tf.reduce_max(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(result_lt, golden_lt) + + +class BaseReduceBoolean(Base): + + def setUp(self): + super(BaseReduceBoolean, self).setUp() + self.bool_tensor = tf.cast(self.original_lt.tensor > 5, tf.bool) + self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes) + + +class ReduceAllTest(BaseReduceBoolean): + + def test_name(self): + result_lt = ops.reduce_all(self.bool_lt, {'channel'}) + self.assertIn('lt_reduce_all', result_lt.name) + + def test(self): + result_lt = ops.reduce_all(self.bool_lt, {'channel'}) + golden_lt = core.LabeledTensor( + tf.reduce_all(self.bool_tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(result_lt, golden_lt) + + +class ReduceAnyTest(BaseReduceBoolean): + + def test_name(self): + result_lt = ops.reduce_any(self.bool_lt, {'channel'}) + self.assertIn('lt_reduce_any', result_lt.name) + + def test(self): + result_lt = ops.reduce_any(self.bool_lt, {'channel'}) + golden_lt = core.LabeledTensor( + tf.reduce_any(self.bool_tensor, 1), [self.a0, self.a2, self.a3]) + self.assertLabeledTensorsEqual(result_lt, golden_lt) + + +class TileTest(Base): + + def test_name(self): + tile_lt = ops.tile(self.original_lt, {'z': 2}) + self.assertIn('lt_tile', tile_lt.name) + + def test(self): + for multiple in [2, tf.constant(2)]: + tile_lt = ops.tile(self.original_lt, {'z': multiple}) + golden_op = tf.tile(self.original_lt.tensor, [1, 1, multiple, 1]) + 
golden_axes = ['z' if axis.name == 'z' else axis + for axis in self.original_lt.axes.values()] + golden_lt = core.LabeledTensor(golden_op, golden_axes) + self.assertLabeledTensorsEqual(tile_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaisesRegexp(ValueError, 'are not contained in the set'): + ops.tile(self.original_lt, {'foo': 5}) + with self.assertRaisesRegexp(ValueError, 'axes with tick labels'): + ops.tile(self.original_lt, {'x': 5}) + + +class PadTest(Base): + + def test_name(self): + pad_lt = ops.pad(self.original_lt, {'x': (1, 1), + 'channel': ([], ['alpha'])}) + self.assertIn('lt_pad', pad_lt.name) + + def test(self): + pad_lt = ops.pad(self.original_lt, {'x': (1, 1), + 'channel': ([], ['alpha'])}) + + golden_op = tf.pad(self.original_lt.tensor, [[1, 1], [0, 1], [0, 0], + [0, 0]]) + golden_axes = [('x', self.x_size + 2), + ('channel', ['red', 'green', 'blue', 'alpha']), self.a2, + self.a3] + golden_lt = core.LabeledTensor(golden_op, golden_axes) + self.assertLabeledTensorsEqual(pad_lt, golden_lt) + + def test_invalid_input(self): + with self.assertRaisesRegexp(ValueError, 'are not contained in the set'): + ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])}) + + +class ConstantTest(Base): + + def test_name(self): + constant_lt = ops.constant(1) + self.assertIn('lt_constant', constant_lt.name) + + def test_scalar(self): + constant_lt = ops.constant(1) + golden_lt = core.LabeledTensor(tf.constant(1), []) + self.assertLabeledTensorsEqual(constant_lt, golden_lt) + + def test_infer_shape(self): + constant_lt = ops.constant([1, 2], axes=['x']) + golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x']) + self.assertLabeledTensorsEqual(constant_lt, golden_lt) + + def test_specify_shape(self): + constant_lt = ops.constant(1, axes=[('x', 3)]) + golden_lt = core.LabeledTensor(tf.constant(1, shape=(3,)), ['x']) + self.assertLabeledTensorsEqual(constant_lt, golden_lt) + + def test_existing_axes(self): + golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x']) + constant_lt = ops.constant([1, 2], axes=golden_lt.axes) + self.assertLabeledTensorsEqual(constant_lt, golden_lt) + + +class ZerosLikeTest(Base): + + def test_name(self): + like_lt = ops.zeros_like(self.original_lt) + self.assertIn('lt_zeros_like', like_lt.name) + + def test(self): + like_lt = ops.zeros_like(self.original_lt) + golden_lt = core.LabeledTensor( + tf.zeros_like(self.original_lt.tensor), self.original_lt.axes) + self.assertLabeledTensorsEqual(like_lt, golden_lt) + + +class OnesLikeTest(Base): + + def test_name(self): + like_lt = ops.ones_like(self.original_lt) + self.assertIn('lt_ones_like', like_lt.name) + + def test(self): + like_lt = ops.ones_like(self.original_lt) + golden_lt = core.LabeledTensor( + tf.ones_like(self.original_lt.tensor), self.original_lt.axes) + self.assertLabeledTensorsEqual(like_lt, golden_lt) + + +class CastTest(Base): + + def test_name(self): + cast_lt = ops.cast(self.original_lt, tf.float16) + self.assertIn('lt_cast', cast_lt.name) + + def test(self): + cast_lt = ops.cast(self.original_lt, tf.float16) + golden_lt = core.LabeledTensor( + tf.cast(self.original_lt.tensor, tf.float16), self.original_lt.axes) + self.assertLabeledTensorsEqual(cast_lt, golden_lt) + + +class VerifyTensorAllFiniteTest(Base): + + def setUp(self): + super(VerifyTensorAllFiniteTest, self).setUp() + + self.finite_lt = core.LabeledTensor(tf.constant(42.0), []) + self.nan_lt = core.LabeledTensor(tf.constant(np.nan), []) + + self.checked_finite_lt = 
ops.verify_tensor_all_finite(self.finite_lt, '') + self.checked_nan_lt = ops.verify_tensor_all_finite(self.nan_lt, '') + + def test_name(self): + self.assertIn('lt_verify_tensor_all_finite', self.checked_finite_lt.name) + self.assertIn('lt_verify_tensor_all_finite', self.checked_nan_lt.name) + + def test_finite(self): + self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt) + + def test_nan(self): + with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, + 'Tensor had NaN values'): + self.eval([self.checked_nan_lt]) + + +class BooleanMaskTest(Base): + + def test_name(self): + mask = core.LabeledTensor(tf.range(7) > 3, [self.a0]) + masked_lt = ops.boolean_mask(self.original_lt, mask) + self.assertIn('lt_boolean_mask', masked_lt.name) + + def test(self): + mask = core.LabeledTensor(tf.range(7) > 3, [self.a0]) + masked_lt = ops.boolean_mask(self.original_lt, mask) + golden_lt = core.LabeledTensor( + tf.boolean_mask(self.original_lt.tensor, mask.tensor), + ['x', self.a1, self.a2, self.a3]) + self.assertLabeledTensorsEqual(masked_lt, golden_lt) + + def test_invalid_rank(self): + mask = core.LabeledTensor(tf.ones((7, 3)) > 3, [self.a0, self.a1]) + with self.assertRaises(NotImplementedError): + ops.boolean_mask(self.original_lt, mask) + + def test_mismatched_axis(self): + mask = core.LabeledTensor(tf.range(7) > 3, ['foo']) + with self.assertRaisesRegexp(ValueError, 'not equal'): + ops.boolean_mask(self.original_lt, mask) + + +class WhereTest(Base): + + def test_name(self): + condition = core.LabeledTensor(tf.range(5) < 3, ['x']) + where_lt = ops.where(condition, condition, condition) + self.assertIn('lt_where', where_lt.name) + + def test(self): + condition = core.LabeledTensor(tf.range(5) < 3, ['x']) + x = core.LabeledTensor(tf.ones(5), ['x']) + y = core.LabeledTensor(tf.zeros(5), ['x']) + where_lt = ops.where(condition, x, y) + + golden_lt = core.LabeledTensor( + tf.concat(0, [tf.ones(3), tf.zeros(2)]), ['x']) + self.assertLabeledTensorsEqual(where_lt, golden_lt) + + def test_mismatched_axes(self): + condition = core.LabeledTensor(tf.range(5) < 3, ['x']) + with self.assertRaisesRegexp(ValueError, 'equal axes'): + ops.where(condition, condition[:3], condition) + with self.assertRaisesRegexp(ValueError, 'equal axes'): + ops.where(condition, condition, condition[:3]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/labeled_tensor/python/ops/sugar.py b/tensorflow/contrib/labeled_tensor/python/ops/sugar.py new file mode 100644 index 0000000000..914493f473 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/sugar.py @@ -0,0 +1,131 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tools to make it a bit easier to use LabeledTensor.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import string_types + +from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.contrib.labeled_tensor.python.ops import ops +from tensorflow.python.framework import ops as tf_ops + + +class ReshapeCoder(object): + """Utility class for mapping to and from another shape. + + For example, say you have a function `crop_center` which expects a + LabeledTensor with axes named ['batch', 'row', 'column', 'depth'], and + you have a LabeledTensor `masked_image_lt` with axes ['batch', 'row', + 'column', 'channel', 'mask']. + + To call `crop_center` with `masked_image_lt` you'd normally have to write: + + >>> reshape_lt = lt.reshape(masked_image_lt, ['channel', 'mask'], ['depth']) + >>> crop_lt = crop_center(reshape_lt) + >>> result_lt = lt.reshape(crop_lt, ['depth'], + ... [masked_image_lt.axes['channel'], masked_image_lt.axes['mask']]) + + ReshapeCoder takes care of this renaming logic for you, allowing you to + instead write: + + >>> rc = ReshapeCoder(['channel', 'mask'], ['depth']) + >>> result_lt = rc.decode(crop_center(rc.encode(masked_image_lt))) + + Here, `decode` restores the original axes 'channel' and 'mask', so + `crop_center` must not have modified the size of the 'depth' axis. + """ + + @tc.accepts(object, tc.Collection(str), + tc.Collection(tc.Union(str, core.AxisLike)), tc.Optional(str)) + def __init__(self, existing_axis_names, new_axes, name=None): + self._name = name + self._existing_axis_names = existing_axis_names + self._new_axes = new_axes + + self._existing_axes = None + + @tc.returns(core.LabeledTensor) + @tc.accepts(object, core.LabeledTensorLike) + def encode(self, labeled_tensor): + """Reshape the input to the target shape. + + If called several times, the axes named in existing_axis_names must be + identical. + + Args: + labeled_tensor: The input tensor. + + Returns: + The input reshaped to the target shape. + + Raises: + ValueError: If the axes in existing_axis_names don't match the axes of + a tensor in a previous invocation of this method. + """ + with tf_ops.name_scope(self._name, 'lt_reshape_encode', + [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + + reshape_lt = ops.reshape(labeled_tensor, + self._existing_axis_names, + self._new_axes, + name=scope) + + axes = [labeled_tensor.axes[n] for n in self._existing_axis_names] + if self._existing_axes is not None and self._existing_axes != axes: + raise ValueError( + 'input axes %r do not match axes from previous method call %r' % + (axes, self._existing_axes)) + else: + self._existing_axes = axes + + return reshape_lt + + @tc.returns(core.LabeledTensor) + @tc.accepts(object, core.LabeledTensorLike) + def decode(self, labeled_tensor): + """Reshape the input to the original shape. + + This is the inverse of encode. + Encode must have been called at least once prior to this method being + called. + + Args: + labeled_tensor: The input tensor. + + Returns: + The input reshaped to the original shape. + + Raises: + ValueError: If this method was called before encode was called. 
+ """ + if self._existing_axes is None: + raise ValueError('decode called before encode') + + with tf_ops.name_scope(self._name, 'lt_reshape_decode', + [labeled_tensor]) as scope: + labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) + + new_axis_names = [axis if isinstance(axis, string_types) else + core.as_axis(axis).name for axis in self._new_axes] + + return ops.reshape(labeled_tensor, + new_axis_names, + self._existing_axes, + name=scope) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py b/tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py new file mode 100644 index 0000000000..3923f5a174 --- /dev/null +++ b/tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py @@ -0,0 +1,106 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import range # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.contrib.labeled_tensor.python.ops import core +from tensorflow.contrib.labeled_tensor.python.ops import ops +from tensorflow.contrib.labeled_tensor.python.ops import sugar +from tensorflow.contrib.labeled_tensor.python.ops import test_util + + +class Base(test_util.Base): + + def setUp(self): + super(Base, self).setUp() + + self.small_lt = core.LabeledTensor(tf.constant([1]), [('x', 1)]) + + +class ReshapeCoderTest(Base): + + def setUp(self): + super(ReshapeCoderTest, self).setUp() + + self.batch_size = 8 + self.num_rows = 50 + self.num_columns = 100 + self.channels = ['red', 'green', 'blue'] + self.masks = [False, True] + + tensor = tf.range(0, self.batch_size * self.num_rows * self.num_columns * + len(self.channels) * len(self.masks)) + tensor = tf.reshape(tensor, [self.batch_size, self.num_rows, + self.num_columns, len(self.channels), + len(self.masks)]) + + self.batch_axis = ('batch', range(self.batch_size)) + self.row_axis = ('row', range(self.num_rows)) + self.column_axis = ('column', range(self.num_columns)) + self.channel_axis = ('channel', self.channels) + self.mask_axis = ('mask', self.masks) + + axes = [self.batch_axis, self.row_axis, self.column_axis, self.channel_axis, + self.mask_axis] + self.masked_image_lt = core.LabeledTensor(tensor, axes) + + def test_name(self): + rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth']) + encode_lt = rc.encode(self.masked_image_lt) + decode_lt = rc.decode(encode_lt) + self.assertIn('lt_reshape_encode', encode_lt.name) + self.assertIn('lt_reshape_decode', decode_lt.name) + + def test_bijection_flat(self): + rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth']) + + encode_lt = rc.encode(self.masked_image_lt) + golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis, + ('depth', len(self.channels) * len(self.masks))]) + self.assertEqual(encode_lt.axes, golden_axes) + + decode_lt = 
rc.decode(encode_lt)
+    self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
+
+  def test_bijection_with_labels(self):
+    depth_axis = core.Axis('depth', range(len(self.channels) * len(self.masks)))
+    rc = sugar.ReshapeCoder(['channel', 'mask'], [depth_axis,
+                                                  ('other', ['label'])])
+
+    encode_lt = rc.encode(self.masked_image_lt)
+    golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis,
+                             depth_axis, ('other', ['label'])])
+    self.assertEqual(encode_lt.axes, golden_axes)
+
+    decode_lt = rc.decode(encode_lt)
+    self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
+
+  def test_invalid_input(self):
+    with self.assertRaises(ValueError):
+      rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
+      rc.decode(self.masked_image_lt)
+    with self.assertRaises(ValueError):
+      rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
+      rc.encode(self.masked_image_lt)
+      rc.encode(ops.select(self.masked_image_lt, {'channel': 'red'}))
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/test_util.py b/tensorflow/contrib/labeled_tensor/python/ops/test_util.py
new file mode 100644
index 0000000000..521314010e
--- /dev/null
+++ b/tensorflow/contrib/labeled_tensor/python/ops/test_util.py
@@ -0,0 +1,47 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utils for writing tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class Base(tf.test.TestCase):
+  """A class with some useful methods for testing."""
+
+  def eval(self, tensors):
+    with self.test_session() as sess:
+      coord = tf.train.Coordinator()
+      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+      try:
+        results = sess.run(tensors)
+      finally:
+        coord.request_stop()
+        coord.join(threads)
+
+      return results
+
+  def assertTensorsEqual(self, tensor_0, tensor_1):
+    [tensor_0_eval, tensor_1_eval] = self.eval([tensor_0, tensor_1])
+    self.assertAllEqual(tensor_0_eval, tensor_1_eval)
+
+  def assertLabeledTensorsEqual(self, tensor_0, tensor_1):
+    self.assertEqual(tensor_0.axes, tensor_1.axes)
+    self.assertTensorsEqual(tensor_0.tensor, tensor_1.tensor)