blob: 69fb11e344f07337955b4ab57960cb853d156a2d (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
|
# TPU support for TensorFlow #
This directory contains code required to re-target a TensorFlow model to run
on TPUs.
## Example usage - TPU Estimator
The example below shows how to use the TPU Estimator to train a simple
convolutional network.
```python
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
def model_fn(features, labels, mode, params):
  """Build the model and return an EstimatorSpec with loss and train op.

  Args:
    features: input features produced by `input_fn`.
    labels: one-hot labels produced by `input_fn` (passed as `onehot_labels`).
    mode: the estimator mode for this call; forwarded to the EstimatorSpec.
    params: passed in by the estimator; not used in this example.

  Returns:
    A `tf.estimator.EstimatorSpec` carrying the loss and training op.
  """
  # Placeholder: replace `...` with your network, computing logits from
  # `features`.
  logits = ...
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  # Wrap the ordinary optimizer in CrossShardOptimizer so it can be used when
  # the model is replicated across TPU shards.
  optimizer = tpu_optimizer.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate))
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def input_fn(params):
  """Input pipeline stub; fill in to supply features and labels.

  Args:
    params: passed in by the estimator when it calls this function;
      ignored by this stub.
  """
  # Build and return your input data here.
  return None

def main(unused_argv=None):
  """Configure a TPU run and train the estimator built from `model_fn`.

  Args:
    unused_argv: leftover command-line arguments. They are accepted only so
      that `main` can be used as the entry point for `tf.app.run()`, which
      calls `main(argv)`. Defaults to None so that a plain `main()` call
      still works.
  """
  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      # ... further RunConfig options
  )
  estimator = tpu_estimator.TpuEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      batch_size=FLAGS.batch_size)
  estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)
```
For a complete, executable example, see our open-source TPU models.
|