# Datasets for Estimators

The @{tf.data} module contains a collection of classes that allows you to
easily load data, manipulate it, and pipe it into your model. This document
introduces the API by walking through two simple examples:

* Reading in-memory data from NumPy arrays.
* Reading lines from a CSV file.

<!-- TODO(markdaoust): Add links to an example reading from multiple-files
(image_retraining), and a from_generator example. -->

## Basic input

Taking slices from an array is the simplest way to get started with `tf.data`.

The @{$premade_estimators$Premade Estimators} chapter describes
the following `train_input_fn`, from
[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py),
to pipe the data into the Estimator:

``` python
def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset
```

Let's look at this more closely.

### Arguments

This function expects three arguments. Arguments expecting an "array" can
accept nearly anything that can be converted to an array with `numpy.array`.
One exception is
[`tuple`](https://docs.python.org/3/tutorial/datastructures.html#tuples-and-sequences)
which, as we will see, has special meaning for `Datasets`.

* `features`: A `{'feature_name':array}` dictionary (or
  [`DataFrame`](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html))
  containing the raw input features.
* `labels` : An array containing the
  [label](https://developers.google.com/machine-learning/glossary/#label)
  for each example.
* `batch_size` : An integer indicating the desired batch size.
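
For concreteness, here is a minimal sketch (with made-up values) of the
structure these arguments take; the feature names are the Iris column names
used throughout this document:

``` python
import numpy as np

# A dict of equal-length feature arrays plus a matching label array.
features = {'SepalLength': np.array([6.4, 5.0]),
            'SepalWidth':  np.array([2.8, 2.3]),
            'PetalLength': np.array([5.6, 3.3]),
            'PetalWidth':  np.array([2.2, 1.0])}
labels = np.array([2, 1])

dataset = train_input_fn(features, labels, batch_size=2)
```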

In [`premade_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py)
we retrieved the Iris data using the `iris_data.load_data()` function.
You can run it, and unpack the results as follows:

``` python
import iris_data

# Fetch the data
train, test = iris_data.load_data()
features, labels = train
```

Then we passed this data to the input function, with a line similar to this:

``` python
batch_size = 100
iris_data.train_input_fn(features, labels, batch_size)
```

Let's walk through `train_input_fn()`.

### Slices

The function starts by using the @{tf.data.Dataset.from_tensor_slices} function
to create a @{tf.data.Dataset} representing slices of the array. The array is
sliced across the first dimension. For example, an array containing the
@{$tutorials/layers$mnist training data} has a shape of `(60000, 28, 28)`.
Passing this to `from_tensor_slices` returns a `Dataset` object containing
60000 slices, each one a 28x28 image.

The code that returns this `Dataset` is as follows:

``` python
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)
```

This will print the following line, showing the
@{$guide/tensors#shapes$shapes} and
@{$guide/tensors#data_types$types} of the items in
the dataset. Note that a `Dataset` does not know how many items it contains.

``` None
<TensorSliceDataset shapes: (28,28), types: tf.uint8>
```

The `Dataset` above represents a simple collection of arrays, but datasets are
much more powerful than this. A `Dataset` can transparently handle any nested
combination of dictionaries or tuples (or
[`namedtuple`](https://docs.python.org/2/library/collections.html#collections.namedtuple)
).

For example, after converting the iris `features` to a standard Python
dictionary, you can convert the dictionary of arrays to a `Dataset` of
dictionaries as follows:

``` python
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
```
``` None
<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>
```

Here we see that when a `Dataset` contains structured elements, the `shapes`
and `types` of the `Dataset` take on the same structure. This dataset contains
dictionaries of @{$guide/tensors#rank$scalars}, all of type
`tf.float64`.

The first line of the iris `train_input_fn` uses the same functionality, but
adds another level of structure. It creates a dataset containing
`(features_dict, label)` pairs.

The following code shows that the label is a scalar with type `int64`:

``` python
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
```
``` None
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>
```

### Manipulation

Currently the `Dataset` would iterate over the data once, in a fixed order, and
only produce a single element at a time. It needs further processing before it
can be used for training. Fortunately, the `tf.data.Dataset` class provides
methods to better prepare the data for training. The next line of the input
function takes advantage of several of these methods:

``` python
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
```

The @{tf.data.Dataset.shuffle$`shuffle`} method uses a fixed-size buffer to
shuffle the items as they pass through. In this case the `buffer_size` is
greater than the number of examples in the `Dataset`, ensuring that the data
is completely shuffled (the Iris data set contains only 150 examples).

The @{tf.data.Dataset.repeat$`repeat`} method restarts the `Dataset` when
it reaches the end. To limit the number of epochs, set the `count` argument.
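
For example, to make at most 10 passes over the training data you could
write:

``` python
dataset = dataset.repeat(count=10)  # limit training to 10 epochs
```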

The @{tf.data.Dataset.batch$`batch`} method collects a number of examples and
stacks them, to create batches. This adds a dimension to their shape. The new
dimension is added as the first dimension. The following code uses
the `batch` method on the MNIST `Dataset`, from earlier. This results in a
`Dataset` containing 3D arrays representing stacks of `(28,28)` images:

``` python
print(mnist_ds.batch(100))
```

``` none
<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>
```
Note that the dataset has an unknown batch size because the last batch may
have fewer elements than the others.
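
To make this concrete, here is a small sketch (TF 1.x style, pulling batches
through a one-shot iterator) showing how 150 examples batched by 100 yield a
smaller final batch:

``` python
import numpy as np
import tensorflow as tf

# 150 examples batched by 100: the first batch has 100 elements, the second
# (and last) has only 50, so the batch dimension prints as `?`.
ds = tf.data.Dataset.from_tensor_slices(np.arange(150)).batch(100)
next_batch = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(next_batch).shape)  # (100,)
    print(sess.run(next_batch).shape)  # (50,)
```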

In `train_input_fn`, after batching, the `Dataset` contains 1D vectors of
elements where previously there were scalars:

```python
print(dataset)
```
``` None
<BatchDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>
```


### Return

At this point the `Dataset` contains `(features_dict, labels)` pairs.
This is the format expected by the `train` and `evaluate` methods, so the
`input_fn` returns the dataset.

When using the `predict` method, the `labels` should be omitted.
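
For example, an input function used only for prediction might look like the
following sketch (an illustrative function, not one defined in
`iris_data.py`):

``` python
def predict_input_fn(features, batch_size):
    """An input function for prediction: features only, no labels."""
    # Convert the feature dict to a Dataset of feature dictionaries.
    dataset = tf.data.Dataset.from_tensor_slices(dict(features))

    # Batch the examples; no shuffle or repeat is needed for prediction.
    return dataset.batch(batch_size)
```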

<!--
  TODO(markdaoust): link to `input_fn` doc when it exists
-->


## Reading a CSV File

The most common real-world use case for the `Dataset` class is to stream data
from files on disk. The @{tf.data} module includes a variety of
file readers. Let's see how to parse the Iris dataset from the CSV file using
a `Dataset`.

The following call to the `iris_data.maybe_download` function downloads the
data if necessary, and returns the pathnames of the resulting files:

``` python
import iris_data
train_path, test_path = iris_data.maybe_download()
```

The [`iris_data.csv_input_fn`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py)
function contains an alternative implementation that parses the csv files using
a `Dataset`.

Let's look at how to build an Estimator-compatible input function that reads
from the local files.

### Build the `Dataset`

We start by building a @{tf.data.TextLineDataset$`TextLineDataset`} object to
read the file one line at a time. Then, we call the
@{tf.data.Dataset.skip$`skip`} method to skip over the first line of the file, which contains a header, not an example:

``` python
ds = tf.data.TextLineDataset(train_path).skip(1)
```
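
At this point each element of `ds` is a single scalar string holding one line
of the CSV file. As a quick check (a sketch; the exact class name in the
printed representation may vary), you can print the dataset:

``` python
print(ds)
# Prints something like:
# <SkipDataset shapes: (), types: tf.string>
```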

### Build a CSV line parser

We must parse each of the lines in the dataset in order to generate the
necessary `(features, label)` pairs.

The following `_parse_line` function, from `iris_data.py`, calls
@{tf.decode_csv} to parse a single line into its features and the label.
Since Estimators require that features be represented as a dictionary, we
rely on Python's built-in `dict` and `zip` functions to build that
dictionary. The feature names are the keys of that dictionary. We then call
the dictionary's `pop` method to remove the label field from the features
dictionary:

``` python
# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]

def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS, fields))

    # Separate the label from the features
    label = features.pop('label')

    return features, label
```
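
As a quick sanity check (a sketch, assuming TF 1.x graph mode and a
`tf.Session`), you can apply `_parse_line` to one made-up CSV line:

``` python
# One hypothetical Iris row: four float features followed by an integer label.
line = tf.constant('6.4,2.8,5.6,2.2,2')
features, label = _parse_line(line)

with tf.Session() as sess:
    print(sess.run(features))  # e.g. {'SepalLength': 6.4, 'SepalWidth': 2.8, ...}
    print(sess.run(label))     # 2
```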

### Parse the lines

Datasets have many methods for manipulating the data while it is being piped
to a model. The most heavily-used method is @{tf.data.Dataset.map$`map`}, which
applies a transformation to each element of the `Dataset`.

The `map` method takes a `map_func` argument that describes how each item in the
`Dataset` should be transformed.

<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/datasets/map.png">
</div>
<div style="text-align: center">
The @{tf.data.Dataset.map$`map`} method applies the `map_func` to
transform each item in the <code>Dataset</code>.
</div>
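
As a toy illustration of `map` (unrelated to the Iris pipeline), the following
sketch squares every element of a small numeric dataset:

``` python
squares = tf.data.Dataset.from_tensor_slices([1, 2, 3]).map(lambda x: x * x)
```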

So to parse the lines as they are streamed out of the csv file, we pass our
`_parse_line` function to the `map` method:

``` python
ds = ds.map(_parse_line)
print(ds)
```
``` None
<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>
```

Now instead of simple scalar strings, the dataset contains `(features, label)`
pairs.

The remainder of the `iris_data.csv_input_fn` function is identical to
`iris_data.train_input_fn`, which was covered in the
[Basic input](#basic_input) section.
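
Putting the pieces together, a complete CSV input function looks roughly like
the following sketch (the actual `iris_data.csv_input_fn` may differ in minor
details):

``` python
def csv_input_fn(csv_path, batch_size):
    # Create a dataset containing the text lines, skipping the header row.
    dataset = tf.data.TextLineDataset(csv_path).skip(1)

    # Parse each line into a (features, label) pair.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    return dataset
```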

### Try it out

This function can be used as a replacement for
`iris_data.train_input_fn`. Use it to feed an Estimator as follows:

``` python
train_path, test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[:-1]]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
                                    n_classes=3)
# Train the estimator
batch_size = 100
est.train(
    steps=1000,
    input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))
```

Estimators expect an `input_fn` to take no arguments. To work around this
restriction, we use `lambda` to capture the arguments and provide the expected
interface.
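
An equivalent approach is to bind the arguments with `functools.partial`
instead of a `lambda`; this is just an alternative sketch of the same idea:

``` python
from functools import partial

est.train(
    steps=1000,
    input_fn=partial(iris_data.csv_input_fn, train_path, batch_size))
```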

## Summary

The `tf.data` module provides a collection of classes and functions for easily
reading data from a variety of sources. Furthermore, `tf.data` has simple,
powerful methods for applying a wide variety of standard and custom
transformations.

Now you have the basic idea of how to efficiently load data into an
Estimator. Consider the following documents next:


* @{$custom_estimators}, which demonstrates how to build your own
  custom `Estimator` model.
* The @{$low_level_intro#datasets$Low Level Introduction}, which demonstrates
  how to experiment directly with `tf.data.Datasets` using TensorFlow's low
  level APIs.
* @{$guide/datasets}, which goes into great detail about additional
  functionality of `Datasets`.