# Reading custom file and record formats

PREREQUISITES:

*   Some familiarity with C++.
*   Must have
    @{$install_sources$downloaded TensorFlow source}, and be
    able to build it.

We divide the task of supporting a file format into two pieces:

*   File formats: We use a reader `tf.data.Dataset` to read raw *records* (which
    are typically represented by scalar string tensors, but can have more
    structure) from a file.
*   Record formats: We use decoder or parsing ops to turn a string record
    into tensors usable by TensorFlow.

For example, to re-implement the `tf.contrib.data.make_csv_dataset` function, we
could use `tf.data.TextLineDataset` to extract the records, and then
use `tf.data.Dataset.map` and `tf.decode_csv` to parse the CSV records from
each line of text in the dataset.
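
The same two-stage split can be illustrated without TensorFlow. The following is
a framework-free sketch in plain Python, not the TensorFlow API: the
hypothetical `read_records` function plays the role of the reader dataset (one
line of text per record), the hypothetical `decode_record` function plays the
role of the decoding op, and the three-column schema is an assumption made only
for this illustration.

```python
import csv
import io


def read_records(file_obj):
  """File format stage: yields one raw string record per line of text."""
  for line in file_obj:
    yield line.rstrip("\n")


def decode_record(record):
  """Record format stage: parses one CSV record into typed fields."""
  # Assumed schema for this illustration: (name: str, count: int, score: float).
  name, count, score = next(csv.reader([record]))
  return name, int(count), float(score)


# An in-memory file stands in for a file on disk.
csv_file = io.StringIO(u"a,1,0.5\nb,2,1.5\n")
decoded = [decode_record(record) for record in read_records(csv_file)]
print(decoded)
```

Keeping the stages separate means the line reader could be reused with a
different decoder, and the decoder with a different source of records.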

[TOC]

## Writing a `Dataset` for a file format

A `tf.data.Dataset` represents a sequence of *elements*, which can be the
individual records in a file. There are several examples of "reader" datasets
that are already built into TensorFlow:

*   `tf.data.TFRecordDataset`
    ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
*   `tf.data.FixedLengthRecordDataset`
    ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
*   `tf.data.TextLineDataset`
    ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))

Each of these implementations comprises three related classes:

* A `tensorflow::DatasetOpKernel` subclass (e.g. `TextLineDatasetOp`), which
  tells TensorFlow how to construct a dataset object from the inputs to and
  attrs of an op, in its `MakeDataset()` method.

* A `tensorflow::GraphDatasetBase` subclass (e.g. `TextLineDatasetOp::Dataset`),
  which represents the *immutable* definition of the dataset itself, and tells
  TensorFlow how to construct an iterator object over that dataset, in its
  `MakeIteratorInternal()` method.

* A `tensorflow::DatasetIterator<Dataset>` subclass (e.g.
  `TextLineDatasetOp::Dataset::Iterator`), which represents the *mutable* state
  of an iterator over a particular dataset, and tells TensorFlow how to get the
  next element from the iterator, in its `GetNextInternal()` method.

The most important method is the `GetNextInternal()` method, since it defines
how to actually read records from the file and represent them as one or more
`Tensor` objects.
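
The division of responsibility between the immutable dataset and the mutable
iterator can be modeled in a few lines of plain Python. This is only a
conceptual sketch, not TensorFlow code: `CountingDataset` and
`CountingIterator` are hypothetical names, and `get_next()` returns an
`(element, end_of_sequence)` pair to mimic the output parameters of
`GetNextInternal()`.

```python
class CountingDataset(object):
  """Immutable definition: only stores the arguments that define the dataset."""

  def __init__(self, record, count):
    self._record = record
    self._count = count

  @property
  def record(self):
    return self._record

  @property
  def count(self):
    return self._count

  def make_iterator(self):
    # Each call creates fresh, independent iteration state.
    return CountingIterator(self)


class CountingIterator(object):
  """Mutable state: tracks how far one traversal of the dataset has advanced."""

  def __init__(self, dataset):
    self._dataset = dataset
    self._i = 0

  def get_next(self):
    """Returns an `(element, end_of_sequence)` pair."""
    if self._i < self._dataset.count:
      self._i += 1
      return self._dataset.record, False
    return None, True


dataset = CountingDataset("MyReader!", 2)
first = dataset.make_iterator()
second = dataset.make_iterator()
results = [first.get_next(), first.get_next(), first.get_next()]
# `second` is unaffected by the calls made on `first`.
second_result = second.get_next()
```

Because the dataset object never changes after construction, any number of
iterators can traverse it independently.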

To create a new reader dataset called (for example) `MyReaderDataset`, you will
need to:

1. In C++, define subclasses of `tensorflow::DatasetOpKernel`,
   `tensorflow::GraphDatasetBase`, and `tensorflow::DatasetIterator<Dataset>`
   that implement the reading logic.
2. In C++, register a new reader op and kernel with the name
   `"MyReaderDataset"`.
3. In Python, define a subclass of `tf.data.Dataset` called `MyReaderDataset`.

You can put all the C++ code in a single file, such as
`my_reader_dataset_op.cc`. It will help if you are
familiar with @{$adding_an_op$the adding an op how-to}. The following skeleton
can be used as a starting point for your implementation:

```c++
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace myproject {
namespace {

using ::tensorflow::DT_STRING;
using ::tensorflow::PartialTensorShape;
using ::tensorflow::Status;
using ::tensorflow::int64;
using ::tensorflow::string;

class MyReaderDatasetOp : public tensorflow::DatasetOpKernel {
 public:

  MyReaderDatasetOp(tensorflow::OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {
    // Parse and validate any attrs that define the dataset using
    // `ctx->GetAttr()`, and store them in member variables.
  }

  void MakeDataset(tensorflow::OpKernelContext* ctx,
                   tensorflow::DatasetBase** output) override {
    // Parse and validate any input tensors that define the dataset using
    // `ctx->input()` or the utility function
    // `ParseScalarArgument<T>(ctx, &arg)`.

    // Create the dataset object, passing any (already-validated) arguments from
    // attrs or input tensors.
    *output = new Dataset(ctx);
  }

 private:
  class Dataset : public tensorflow::GraphDatasetBase {
   public:
    Dataset(tensorflow::OpKernelContext* ctx) : GraphDatasetBase(ctx) {}

    std::unique_ptr<tensorflow::IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<tensorflow::IteratorBase>(new Iterator(
          {this, tensorflow::strings::StrCat(prefix, "::MyReader")}));
    }

    // Record structure: Each record is represented by a scalar string tensor.
    //
    // Dataset elements can have a fixed number of components of different
    // types and shapes; replace the following two methods to customize this
    // aspect of the dataset.
    const tensorflow::DataTypeVector& output_dtypes() const override {
      static auto* const dtypes = new tensorflow::DataTypeVector({DT_STRING});
      return *dtypes;
    }
    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() const override { return "MyReaderDatasetOp::Dataset"; }

   protected:
    // Optional: Implementation of `GraphDef` serialization for this dataset.
    //
    // Implement this method if you want to be able to save and restore
    // instances of this dataset (and any iterators over it).
    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                              tensorflow::Node** output) const override {
      // Construct nodes to represent any of the input tensors from this
      // object's member variables using `b->AddScalar()` and `b->AddVector()`.
      std::vector<tensorflow::Node*> input_tensors;
      TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
      return Status::OK();
    }

   private:
    class Iterator : public tensorflow::DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params), i_(0) {}

      // Implementation of the reading logic.
      //
      // The example implementation in this file yields the string "MyReader!"
      // ten times. In general there are three cases:
      //
      // 1. If an element is successfully read, store it as one or more tensors
      //    in `*out_tensors`, set `*end_of_sequence = false` and return
      //    `Status::OK()`.
      // 2. If the end of input is reached, set `*end_of_sequence = true` and
      //    return `Status::OK()`.
      // 3. If an error occurs, return an error status using one of the helper
      //    functions from "tensorflow/core/lib/core/errors.h".
      Status GetNextInternal(tensorflow::IteratorContext* ctx,
                             std::vector<tensorflow::Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        // NOTE: `GetNextInternal()` may be called concurrently, so it is
        // recommended that you protect the iterator state with a mutex.
        tensorflow::mutex_lock l(mu_);
        if (i_ < 10) {
          // Create a scalar string tensor and add it to the output.
          tensorflow::Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
          record_tensor.scalar<string>()() = "MyReader!";
          out_tensors->emplace_back(std::move(record_tensor));
          ++i_;
          *end_of_sequence = false;
        } else {
          *end_of_sequence = true;
        }
        return Status::OK();
      }

     protected:
      // Optional: Implementation of iterator state serialization for this
      // iterator.
      //
      // Implement these two methods if you want to be able to save and restore
      // instances of this iterator.
      Status SaveInternal(tensorflow::IteratorStateWriter* writer) override {
        tensorflow::mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
        return Status::OK();
      }
      Status RestoreInternal(tensorflow::IteratorContext* ctx,
                             tensorflow::IteratorStateReader* reader) override {
        tensorflow::mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
        return Status::OK();
      }

     private:
      tensorflow::mutex mu_;
      int64 i_ GUARDED_BY(mu_);
    };
  };
};

// Register the op definition for MyReaderDataset.
//
// Dataset ops always have a single output, of type `variant`, which represents
// the constructed `Dataset` object.
//
// Add any attrs and input tensors that define the dataset here.
REGISTER_OP("MyReaderDataset")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

// Register the kernel implementation for MyReaderDataset.
REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(tensorflow::DEVICE_CPU),
                        MyReaderDatasetOp);

}  // namespace
}  // namespace myproject
```
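
The optional `SaveInternal()` and `RestoreInternal()` methods in the skeleton
make an iterator checkpointable by writing and reading its position under a
named key. The idea can be sketched in plain Python, assuming that a dictionary
stands in for the `IteratorStateWriter` and `IteratorStateReader`; the class
name `CheckpointableIterator` and the key `"i"` used here are hypothetical.

```python
import threading


class CheckpointableIterator(object):
  """Yields "MyReader!" `limit` times and can save and restore its position."""

  def __init__(self, limit):
    self._limit = limit
    self._lock = threading.Lock()
    self._i = 0  # Guarded by `_lock`.

  def get_next(self):
    """Returns an `(element, end_of_sequence)` pair."""
    with self._lock:
      if self._i < self._limit:
        self._i += 1
        return "MyReader!", False
      return None, True

  def save(self, state):
    # Write every piece of mutable state under a stable key.
    with self._lock:
      state["i"] = self._i

  def restore(self, state):
    # Read back exactly the keys that `save()` wrote.
    with self._lock:
      self._i = state["i"]


original = CheckpointableIterator(limit=3)
original.get_next()
original.get_next()
state = {}
original.save(state)

# A new iterator resumes where the original one stopped.
restored = CheckpointableIterator(limit=3)
restored.restore(state)
remaining = []
while True:
  element, end_of_sequence = restored.get_next()
  if end_of_sequence:
    break
  remaining.append(element)
```

In this sketch the only mutable state is the counter, so saving one integer is
enough. A reader over a real file would also need to record its position in the
file.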

The last step is to build the C++ code and add a Python wrapper. The easiest way
to do this is by @{$adding_an_op#build_the_op_library$compiling a dynamic
library} (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class
that subclasses `tf.data.Dataset` to wrap it. An example Python program is
given here:

```python
import tensorflow as tf

# Assumes the file is in the current working directory.
my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")

class MyReaderDataset(tf.data.Dataset):

  def __init__(self):
    super(MyReaderDataset, self).__init__()
    # Create any input attrs or tensors as members of this class.

  def _as_variant_tensor(self):
    # Actually construct the graph node for the dataset op.
    #
    # This method will be invoked when you create an iterator on this dataset
    # or a dataset derived from it.
    return my_reader_dataset_module.my_reader_dataset()

  # The following properties define the structure of each element: a scalar
  # `tf.string` tensor. Change these properties to match the `output_dtypes()`
  # and `output_shapes()` methods of `MyReaderDatasetOp::Dataset` if you modify
  # the structure of each element.
  @property
  def output_types(self):
    return tf.string

  @property
  def output_shapes(self):
    return tf.TensorShape([])

  @property
  def output_classes(self):
    return tf.Tensor

if __name__ == "__main__":
  # Create a MyReaderDataset and print its elements.
  with tf.Session() as sess:
    iterator = MyReaderDataset().make_one_shot_iterator()
    next_element = iterator.get_next()
    try:
      while True:
        print(sess.run(next_element))  # Prints "MyReader!" ten times.
    except tf.errors.OutOfRangeError:
      pass
```

You can see some examples of `Dataset` wrapper classes in
[`tensorflow/python/data/ops/dataset_ops.py`](https://www.tensorflow.org/code/tensorflow/python/data/ops/dataset_ops.py).

## Writing an Op for a record format

Generally, this is an ordinary op that takes a scalar string record as input, so
you can follow @{$adding_an_op$the instructions to add an Op}.
You may optionally take a scalar string key as input, and include that in error
messages reporting improperly formatted data.  That way users can more easily
track down where the bad data came from.
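
The benefit of the optional key can be seen in a small plain-Python sketch. It
is not a TensorFlow op: `decode_int_pair` is a hypothetical decoder for an
assumed toy record format (`"<int>,<int>"`), and the `file:line` keys are
likewise made up for illustration.

```python
def decode_int_pair(record, key=""):
  """Decodes a record of the form "<int>,<int>" into a pair of integers."""
  fields = record.split(",")
  if len(fields) != 2:
    # Including the key tells the user which record was malformed.
    raise ValueError(
        "Expected 2 fields but found %d in record %r (key: %r)"
        % (len(fields), record, key))
  try:
    return int(fields[0]), int(fields[1])
  except ValueError:
    raise ValueError(
        "Non-integer field in record %r (key: %r)" % (record, key))


pair = decode_int_pair("3,4", key="data.txt:1")

try:
  decode_int_pair("3,4,5", key="data.txt:2")
except ValueError as e:
  message = str(e)
  print(message)
```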

Examples of Ops useful for decoding records:

*   `tf.parse_single_example` (and `tf.parse_example`)
*   `tf.decode_csv`
*   `tf.decode_raw`

Note that it can be useful to use multiple Ops to decode a particular record
format.  For example, you may have an image saved as a string in
[a `tf.train.Example` protocol buffer](https://www.tensorflow.org/code/tensorflow/core/example/example.proto).
Depending on the format of that image, you might take the corresponding output
from a `tf.parse_single_example` op and call `tf.image.decode_jpeg`,
`tf.image.decode_png`, or `tf.decode_raw`.  It is common to take the output
of `tf.decode_raw` and use `tf.slice` and `tf.reshape` to extract pieces.
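
The decode, slice, and reshape pattern can be sketched in plain Python. This is
a framework-free illustration, not TensorFlow code: the record layout (one label
byte followed by `HEIGHT * WIDTH` pixel bytes) and the name
`decode_labeled_image` are assumptions made for this example.

```python
HEIGHT, WIDTH = 2, 3  # Assumed image dimensions for this illustration.


def decode_labeled_image(record):
  """Splits a raw byte record into a label and a HEIGHT x WIDTH image."""
  # Reinterpret the bytes as unsigned 8-bit integers (analogous to decoding
  # the raw string into a flat numeric tensor).
  values = list(bytearray(record))
  if len(values) != 1 + HEIGHT * WIDTH:
    raise ValueError("Expected %d bytes but found %d"
                     % (1 + HEIGHT * WIDTH, len(values)))
  # Slice out the pieces of the flat sequence.
  label = values[0]
  pixels = values[1:]
  # Reshape the flat pixel list into rows.
  image = [pixels[r * WIDTH:(r + 1) * WIDTH] for r in range(HEIGHT)]
  return label, image


record = bytes(bytearray([7, 0, 1, 2, 3, 4, 5]))
label, image = decode_labeled_image(record)
```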