# Extending TF: Supporting new data formats

PREREQUISITES:

*   Some familiarity with C++.
*   Must have
    [downloaded TensorFlow source](../../get_started/os_setup.md#source), and be
    able to build it.

We divide the task of supporting a file format into two pieces:

*   File formats: We use a *Reader* Op to read a *record* (which can be any
    string) from a file.
*   Record formats: We use decoder or parsing Ops to turn a string record
    into tensors usable by TensorFlow.

For example, to read a
[CSV file](https://en.wikipedia.org/wiki/Comma-separated_values), we use
[a Reader for text files](../../api_docs/python/io_ops.md#TextLineReader)
followed by
[an Op that parses CSV data from a line of text](../../api_docs/python/io_ops.md#decode_csv).
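
The two-piece split can be modeled in plain Python without any TensorFlow calls. The sketch below is only an illustration under stated assumptions: `read_text_lines` and `decode_csv_record` are hypothetical names, the `filename:line_number` key format is chosen for the example, and every field is assumed to be a float.

```python
import csv
import io
import os
import tempfile


def read_text_lines(path):
    """File-format step: yield (key, record) pairs, one per line of the file."""
    with open(path) as f:
        for line_number, line in enumerate(f, start=1):
            # The key lets a human find the record again later.
            yield "%s:%d" % (path, line_number), line.rstrip("\n")


def decode_csv_record(record):
    """Record-format step: turn one string record into a list of floats."""
    row = next(csv.reader(io.StringIO(record)))
    return [float(field) for field in row]


# Write a small CSV file, then run both steps over it.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.csv")
    with open(path, "w") as f:
        f.write("1.0,2.5\n3.0,4.5\n")
    decoded = [(key, decode_csv_record(record))
               for key, record in read_text_lines(path)]
```

Keeping the steps separate means the same line reader can feed a different decoder, and the same decoder can consume records from a different kind of file.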

<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
* [Writing a Reader for a file format](#AUTOGENERATED-writing-a-reader-for-a-file-format)
* [Writing an Op for a record format](#AUTOGENERATED-writing-an-op-for-a-record-format)


<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->

## Writing a Reader for a file format <div class="md-anchor" id="AUTOGENERATED-writing-a-reader-for-a-file-format">{#AUTOGENERATED-writing-a-reader-for-a-file-format}</div>

A `Reader` is something that reads records from a file.  There are some examples
of Reader Ops already built into TensorFlow:

*   [`tf.TFRecordReader`](../../api_docs/python/io_ops.md#TFRecordReader)
    ([source in kernels/tf_record_reader_op.cc](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/tf_record_reader_op.cc))
*   [`tf.FixedLengthRecordReader`](../../api_docs/python/io_ops.md#FixedLengthRecordReader)
    ([source in kernels/fixed_length_record_reader_op.cc](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/fixed_length_record_reader_op.cc))
*   [`tf.TextLineReader`](../../api_docs/python/io_ops.md#TextLineReader)
    ([source in kernels/text_line_reader_op.cc](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/text_line_reader_op.cc))

These all expose the same interface; the only differences are in their
constructors.  The most important method is `read()`.
It takes a queue argument, from which it gets the next filename to
read whenever it needs one (e.g. when the `read` op first runs, or when
the previous `read` consumed the last record of a file).  It produces
two scalar tensors: a string key and a string value.

To create a new reader called `SomeReader`, you will need to:

1.  In C++, define a subclass of
    [`tensorflow::ReaderBase`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/reader_base.h)
    called `SomeReader`.
2.  In C++, register a new reader op and kernel with the name `"SomeReader"`.
3.  In Python, define a subclass of [`tf.ReaderBase`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/io_ops.py) called `SomeReader`.

You can put all the C++ code in a file in
`tensorflow/core/user_ops/some_reader_op.cc`.  The code to read a file will live
in a descendant of the C++ `ReaderBase` class, which is defined in
[tensorflow/core/kernels/reader_base.h](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/reader_base.h).
You will need to implement the following methods:

*   `OnWorkStartedLocked`: open the next file
*   `ReadLocked`: read a record or report EOF/error
*   `OnWorkFinishedLocked`: close the current file
*   `ResetLocked`: get a clean slate after, e.g., an error

These methods have names ending in "Locked" since `ReaderBase` makes sure
to acquire a mutex before calling any of these methods, so you generally don't
have to worry about thread safety (though that only protects the members of the
class, not global state).
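
To make the division of labor concrete, here is a rough Python model of how a base class might hold the lock and drive the four hooks. It is not the real C++ `ReaderBase`: `ToyReaderBase`, `ToyLineReader`, the in-memory `files` dict, and the `events` log are all hypothetical, and the filename queue is a plain `collections.deque`.

```python
import collections
import threading


class ToyReaderBase(object):
    """Hypothetical stand-in for ReaderBase: owns the lock, calls the hooks."""

    def __init__(self):
        self._mu = threading.Lock()
        self._work = None  # name of the file currently being read, if any

    def current_work(self):
        return self._work

    def read(self, filename_queue):
        """Return the next (key, value), opening and closing files as needed."""
        with self._mu:  # every *_locked hook runs with the lock held
            while True:
                if self._work is None:
                    # Raises IndexError once the queue is exhausted.
                    self._work = filename_queue.popleft()
                    self.on_work_started_locked()
                key, value, produced, at_end = self.read_locked()
                if at_end:
                    self.on_work_finished_locked()
                    self._work = None
                if produced:
                    return key, value

    def reset(self):
        with self._mu:
            self._work = None
            self.reset_locked()


class ToyLineReader(ToyReaderBase):
    """Reads lines from in-memory 'files' given as a name -> contents dict."""

    def __init__(self, files):
        super(ToyLineReader, self).__init__()
        self._files = files
        self._lines = None
        self._line_number = 0
        self.events = []  # records hook calls, for illustration only

    def on_work_started_locked(self):
        self._lines = self._files[self.current_work()].splitlines()
        self._line_number = 0
        self.events.append("open " + self.current_work())

    def read_locked(self):
        if self._line_number >= len(self._lines):
            return None, None, False, True  # nothing produced, at end of file
        value = self._lines[self._line_number]
        self._line_number += 1
        key = "%s:%d" % (self.current_work(), self._line_number)
        return key, value, True, False

    def on_work_finished_locked(self):
        self.events.append("close " + self.current_work())
        self._lines = None

    def reset_locked(self):
        self._lines = None
        self._line_number = 0


queue = collections.deque(["a.txt", "b.txt"])
reader = ToyLineReader({"a.txt": "x\ny", "b.txt": "z"})
records = [reader.read(queue) for _ in range(3)]
```

The subclass never touches the lock itself. It only fills in the hooks, which is the same split of responsibilities that the "Locked" suffix signals in the C++ class.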

For `OnWorkStartedLocked`, the name of the file to open is the value returned by
the `current_work()` method.  `ReadLocked()` has this signature:

```c++
Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
```

If `ReadLocked()` successfully reads a record from the file, it should fill in:

*   `*key`: with an identifier for the record that a human could use to find
    this record again.  You can include the filename from `current_work()`
    and append a record number or a similar position marker.
*   `*value`: with the contents of the record.
*   `*produced`: set to `true`.

If you hit the end of a file (EOF), set `*at_end` to `true`.  In either case,
return `Status::OK()`.  If there is an error, simply return it using one of the
helper functions from
[tensorflow/core/lib/core/errors.h](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/lib/core/errors.h)
without modifying any arguments.
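
The output contract can be sketched in Python for a fixed-length record format. This is an analogue rather than the C++ API: the hypothetical `read_fixed_length` returns a `(key, value, produced, at_end)` tuple instead of filling in pointers, and it raises an exception where the C++ method would return an error `Status`.

```python
def read_fixed_length(data, offset, record_bytes, work, record_number):
    """Try to read one record; return (key, value, produced, at_end)."""
    if offset >= len(data):
        # End of file: nothing produced, and this is not an error.
        return None, None, False, True
    chunk = data[offset:offset + record_bytes]
    if len(chunk) < record_bytes:
        # Error case: report it and leave the outputs untouched.
        raise ValueError(
            "truncated record %d in %s" % (record_number, work))
    key = "%s:%d" % (work, record_number)
    return key, chunk, True, False


# Drain one in-memory "file" of two 4-byte records.
data = b"abcdefgh"
results = []
offset = 0
record_number = 0
while True:
    key, value, produced, at_end = read_fixed_length(
        data, offset, 4, "f.bin", record_number)
    if at_end:
        break
    results.append((key, value))
    offset += len(value)
    record_number += 1
```

Note that reaching the end of the file and failing to read are reported differently: the first is a normal result with `at_end` set, the second is an error.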

Next you will create the actual Reader op.  It will help if you are familiar
with [the adding an op how-to](../adding_an_op/index.md).  The main steps
are:

*   Register the op.
*   Define and register an `OpKernel`.

To register the op, you will use a `REGISTER_OP()` call defined in
[tensorflow/core/framework/op.h](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op.h).
Reader ops never take any input and always have a single output with type
`Ref(string)`.  They should always call `SetIsStateful()` and have string attrs
named `container` and `shared_name`.  You may optionally define additional attrs
for configuration or include documentation in a `Doc()`.  For examples, see
[tensorflow/core/ops/io_ops.cc](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/ops/io_ops.cc),
e.g.:

```c++
#include "tensorflow/core/framework/op.h"

REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc");
```
    
To define an `OpKernel`, Readers can use the shortcut of descending from
`ReaderOpKernel`, defined in
[tensorflow/core/framework/reader_op_kernel.h](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/reader_op_kernel.h),
and implement a constructor that calls `SetReaderFactory()`.  After defining
your class, you will need to register it using `REGISTER_KERNEL_BUILDER(...)`.
An example with no attrs:

```c++
#include "tensorflow/core/framework/reader_op_kernel.h"

class TFRecordReaderOp : public ReaderOpKernel {
 public:
  explicit TFRecordReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    Env* env = context->env();
    SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
  }
};

REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
                        TFRecordReaderOp);
```

An example with attrs:

```c++
#include "tensorflow/core/framework/reader_op_kernel.h"

class TextLineReaderOp : public ReaderOpKernel {
 public:
  explicit TextLineReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int skip_header_lines = -1;
    OP_REQUIRES_OK(context,
                   context->GetAttr("skip_header_lines", &skip_header_lines));
    OP_REQUIRES(context, skip_header_lines >= 0,
                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
                                        skip_header_lines));
    Env* env = context->env();
    SetReaderFactory([this, skip_header_lines, env]() {
      return new TextLineReader(name(), skip_header_lines, env);
    });
  }
};

REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
                        TextLineReaderOp);
```

The last step is to add the Python wrapper.  You will import
`tensorflow.python.ops.io_ops` in
[tensorflow/python/user_ops/user_ops.py](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/user_ops/user_ops.py)
and add a descendant of [`io_ops.ReaderBase`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/io_ops.py).

```python
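from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
# Generated Python wrappers for the ops registered in tensorflow/core/user_ops.
from tensorflow.python.ops import gen_user_ops
from tensorflow.python.ops import io_ops


class SomeReader(io_ops.ReaderBase):
    """Python wrapper for the "SomeReader" op."""

    def __init__(self, name=None):
        # Create the reader op and hand its handle to ReaderBase.
        rr = gen_user_ops.some_reader(name=name)
        super(SomeReader, self).__init__(rr)


# The reader op has no gradient, and its single output is a scalar.
ops.NoGradient("SomeReader")
ops.RegisterShape("SomeReader")(common_shapes.scalar_shape)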
```

You can see some examples in
[`tensorflow/python/ops/io_ops.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/io_ops.py).

## Writing an Op for a record format <div class="md-anchor" id="AUTOGENERATED-writing-an-op-for-a-record-format">{#AUTOGENERATED-writing-an-op-for-a-record-format}</div>

Generally this is an ordinary op that takes a scalar string record as input, so
follow [the instructions to add an Op](../adding_an_op/index.md).  You may
optionally take a scalar string key as input, and include that in error messages
reporting improperly formatted data.  That way users can more easily track down
where the bad data came from.
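
As a sketch of why the key helps, the plain-Python decoder below includes it in every error message. The function `decode_point` and its two-field "x,y" record format are hypothetical, chosen only to keep the example short. A real decoding op would be written in C++ as described in the how-to linked above.

```python
def decode_point(record, key=""):
    """Parse an "x,y" record into two floats, naming the key on bad input."""
    fields = record.split(",")
    if len(fields) != 2:
        raise ValueError(
            "expected 2 fields but got %d in record %r (key: %s)"
            % (len(fields), record, key))
    try:
        return float(fields[0]), float(fields[1])
    except ValueError:
        raise ValueError(
            "non-numeric field in record %r (key: %s)" % (record, key))


point = decode_point("1.5,2", key="points.csv:1")

try:
    decode_point("1.5,oops", key="points.csv:7")
except ValueError as e:
    message = str(e)  # mentions points.csv:7, so the bad line is easy to find
```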

Examples of Ops useful for decoding records:

*   [`tf.parse_single_example`](../../api_docs/python/io_ops.md#parse_single_example)
    (and
    [`tf.parse_example`](../../api_docs/python/io_ops.md#parse_example))
*   [`tf.decode_csv`](../../api_docs/python/io_ops.md#decode_csv)
*   [`tf.decode_raw`](../../api_docs/python/io_ops.md#decode_raw)

Note that it can be useful to use multiple Ops to decode a particular record
format.  For example, you may have an image saved as a string in
[a tf.train.Example protocol buffer](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/example/example.proto).
Depending on the format of that image, you might take the corresponding output
from a
[`tf.parse_single_example`](../../api_docs/python/io_ops.md#parse_single_example)
op and call [`tf.decode_jpeg`](../../api_docs/python/image.md#decode_jpeg),
[`tf.decode_png`](../../api_docs/python/image.md#decode_png), or
[`tf.decode_raw`](../../api_docs/python/io_ops.md#decode_raw).  It is common to
take the output of `tf.decode_raw` and use
[`tf.slice`](../../api_docs/python/array_ops.md#slice) and
[`tf.reshape`](../../api_docs/python/array_ops.md#reshape) to extract pieces.
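
The decode, slice, and reshape pattern can be illustrated without TensorFlow. The sketch below is a plain-Python analogue, not the TensorFlow ops themselves: `decode_raw_uint8` and `reshape` are hypothetical helpers, and the record layout (one label byte followed by a 2x3 image) is an assumption made up for the example.

```python
def decode_raw_uint8(record):
    """Reinterpret a byte string as a flat list of uint8 values."""
    return list(bytearray(record))


def reshape(values, rows, cols):
    """Split a flat list into `rows` lists of `cols` values each."""
    if len(values) != rows * cols:
        raise ValueError(
            "cannot reshape %d values into %dx%d" % (len(values), rows, cols))
    return [values[r * cols:(r + 1) * cols] for r in range(rows)]


# Assumed layout: byte 0 is the label, bytes 1..6 are a 2x3 image.
record = bytes(bytearray([7, 0, 1, 2, 3, 4, 5]))
values = decode_raw_uint8(record)
label = values[0]                  # "slice" out the first piece
image = reshape(values[1:], 2, 3)  # "slice" the rest, then reshape it
```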