# TPU support for TensorFlow #

This directory contains code required to re-target a TensorFlow model to run
on TPUs.

## Example usage - TPU Estimator

The example below shows how to use the TPU Estimator to train a simple
convolutional network.

```python
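import tensorflow as tf

from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

# Command-line flags used below; the default values are illustrative only.
flags = tf.app.flags
flags.DEFINE_string('master', '', 'Address of the TensorFlow master.')
flags.DEFINE_boolean('use_tpu', True, 'Whether to run on a TPU.')
flags.DEFINE_integer('batch_size', 128, 'Training batch size.')
flags.DEFINE_integer('train_steps', 1000, 'Number of training steps.')
flags.DEFINE_float('learning_rate', 0.01, 'Learning rate for SGD.')
flags.DEFINE_integer('num_classes', 10, 'Number of output classes.')
FLAGS = flags.FLAGS

def model_fn(features, labels, mode, params):
  # Define the model to construct the logits. The small convolutional network
  # below is illustrative: `features` is assumed to be a batch of NHWC images
  # and `labels` one-hot vectors of length FLAGS.num_classes.
  net = tf.layers.conv2d(features, filters=32, kernel_size=3,
                         activation=tf.nn.relu)
  net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)
  net = tf.contrib.layers.flatten(net)
  logits = tf.layers.dense(net, units=FLAGS.num_classes)
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  # CrossShardOptimizer wraps a standard optimizer so that gradients are
  # aggregated across TPU shards before they are applied.
  optimizer = tpu_optimizer.CrossShardOptimizer(
    tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate))
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def input_fn(params):
  # Build the input pipeline and return a (features, labels) tuple.
  # ...
  pass

def main(unused_argv):
  run_config = tpu_config.RunConfig(
    master=FLAGS.master,
    # ...
  )
  estimator = tpu_estimator.TpuEstimator(
    model_fn=model_fn,
    use_tpu=FLAGS.use_tpu,
    config=run_config,
    batch_size=FLAGS.batch_size)
  estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)

if __name__ == '__main__':
  tf.app.run()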
```
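
To make the role of `CrossShardOptimizer` more concrete, here is a conceptual sketch in plain NumPy (not TensorFlow or TPU code) of data-parallel gradient aggregation. The batch is split into equal shards, each shard computes the gradient of its own mean loss, and averaging those per-shard gradients reproduces the full-batch gradient. The helpers `mse_grad` and `cross_shard_grad` are hypothetical names used only for illustration, and a linear model with squared-error loss stands in for the network.

```python
import numpy as np

def mse_grad(w, x, y):
  """Gradient of 0.5 * mean((x @ w - y) ** 2) with respect to w."""
  residual = x @ w - y
  return x.T @ residual / len(y)

def cross_shard_grad(w, x, y, num_shards):
  """Split the batch into equal shards, take per-shard gradients, average them."""
  x_shards = np.split(x, num_shards)
  y_shards = np.split(y, num_shards)
  shard_grads = [mse_grad(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]
  return np.mean(shard_grads, axis=0)

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3))  # 16 examples, 3 features
y = rng.normal(size=16)
w = np.zeros(3)

full = mse_grad(w, x, y)                         # gradient on the whole batch
sharded = cross_shard_grad(w, x, y, num_shards=8)  # 8 shards of 2 examples each
print(np.allclose(full, sharded))  # True
```

The sketch assumes averaging across shards. An implementation that sums shard gradients instead differs by a factor equal to the number of shards, which would need to be absorbed by scaling the loss or the learning rate.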

For a complete, executable example, see our open source TPU models.