aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/labeled_tensor/README.md
blob: adce979e2acd1516d19572e31af1fd7b1c7a225c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Labels for TensorFlow

LabeledTensor is a library for adding semantically meaningful dimension and
coordinate labels to tensors in Tensorflow.

LabeledTensor was inspired by [xarray](http://xarray.pydata.org) and
[pandas](http://pandas.pydata.org), projects that adds labels to NumPy array.

## Data model

`LabeledTensor` is an immutable object consisting of two components:

- `tensor`: the `tf.Tensor` object containing the labeled tensor's data.
- `axes`: an OrderedDict-like object with keys given by axis names (e.g.,
  ``"channel"``) and values given by `Axis` objects.

`Axis` objects keep track of the size of a dimension and, optionally, coordinate
labels along that axis (e.g., `("red", "green", "blue")`) in the form of a
tuple stored in `Axis.labels`.

Operations on `LabeledTensors` use, preserve and transform axis names and
labels.

## Quick start

Try out the following snippet in a script or Jupyter notebook:

    import tensorflow as tf

    lt = tf.contrib.labeled_tensor

    # Create two LabeledTensors:
    raw_image = tf.ones((299, 299, 3))
    axes = ['row', 'column', ('channel', ['red', 'green', 'blue'])]
    image = lt.LabeledTensor(raw_image, axes)
    assert image.tensor is raw_image
    weights = lt.LabeledTensor(tf.constant([0.1, 0.3, 0.6]),
                               [image.axes['channel']])

    # Examples of valid operations:
    lt.transpose(image, ['column', 'row', 'channel'])
    lt.reshape(image, ['row', 'column'], ['pixel'])
    lt.concat([image, image], 'row')
    lt.reduce_sum(image, ['channel'])
    lt.select(image, {'channel': 'red'})
    lt.cast(image / 256.0, tf.uint8)
    image * weights
    lt.matmul(image[0, :, :], weights)
    tf.cos(image)  # automatically converts to tf.Tensor

## Adding a custom op

LabeledTensor has wrappers for [quite a
few](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/labeled_tensor/__init__.py)
TensorFlow ops.

To easily add your own, you can use the `define_unary_op`, `define_binary_op`
and `define_reduce_op` functions, e.g.,

    log = lt.define_unary_op('log', tf.log)

## Questions

Please reach out to the authors:

- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
- Eric Christiansen (ericmc@google.com, github.com/emchristiansen)