aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Xuechen Li <lxuechen@google.com>2018-07-31 14:19:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 14:26:21 -0700
commit3fda31fe7d17d808c18e53186beb54b457088587 (patch)
treefb27ab6284ebbe5d7e987da3d9813402d3b2e541 /tensorflow
parent48f486b4b4ad23dbfddbd9527f184ef25bdc2421 (diff)
Update TPU estimator script to work with ImageNet.
PiperOrigin-RevId: 206825954
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py335
1 files changed, 206 insertions, 129 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
index f0aad9b110..6300021824 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
@@ -12,22 +12,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Cloud TPU Estimator workflow with RevNet train on CIFAR-10."""
+"""Cloud TPU Estimator workflow with RevNet train on ImageNet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import time
from absl import flags
import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.revnet import cifar_input
-from tensorflow.contrib.eager.python.examples.revnet import main as main_
+from tensorflow.contrib import summary
+from tensorflow.contrib.eager.python.examples.revnet import config as config_
+from tensorflow.contrib.eager.python.examples.revnet import imagenet_input
from tensorflow.contrib.eager.python.examples.revnet import revnet
from tensorflow.contrib.training.python.training import evaluation
-from tensorflow.python.estimator import estimator as estimator_
+from tensorflow.python.estimator import estimator
+
+MEAN_RGB = [0.485, 0.456, 0.406]
+STDDEV_RGB = [0.229, 0.224, 0.225]
+
+
+def _host_call_fn(gs, loss, lr):
+ """Training host call.
+
+ Creates scalar summaries for training metrics.
+
+ This function is executed on the CPU and should not directly reference
+ any Tensors in the rest of the `model_fn`. To pass Tensors from the
+ model to the `metric_fn`, provide as part of the `host_call`. See
+ https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
+ for more information.
+
+ Arguments should match the list of `Tensor` objects passed as the second
+ element in the tuple passed to `host_call`.
+
+ Args:
+ gs: `Tensor with shape `[batch]` for the global_step
+ loss: `Tensor` with shape `[batch]` for the training loss.
+ lr: `Tensor` with shape `[batch]` for the learning_rate.
+
+ Returns:
+ List of summary ops to run on the CPU host.
+ """
+ # Host call fns are executed FLAGS.iterations_per_loop times after one
+ # TPU loop is finished, setting max_queue value to the same as number of
+ # iterations will make the summary writer only flush the data to storage
+ # once per loop.
+ gs = gs[0]
+ with summary.create_file_writer(
+ FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default():
+ with summary.always_record_summaries():
+ summary.scalar("loss", loss[0], step=gs)
+ summary.scalar("learning_rate", lr[0], step=gs)
+ return summary.all_summary_ops()
+
+
+def _metric_fn(labels, logits):
+ """Evaluation metric function. Evaluates accuracy.
+
+ This function is executed on the CPU and should not directly reference
+ any Tensors in the rest of the `model_fn`. To pass Tensors from the model
+ to the `metric_fn`, provide as part of the `eval_metrics`. See
+ https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
+ for more information.
+
+ Arguments should match the list of `Tensor` objects passed as the second
+ element in the tuple passed to `eval_metrics`.
+
+ Args:
+ labels: `Tensor` with shape `[batch]`.
+ logits: `Tensor` with shape `[batch, num_classes]`.
+
+ Returns:
+ A dict of the metrics to return from evaluation.
+ """
+ predictions = tf.argmax(logits, axis=1)
+ top_1_accuracy = tf.metrics.accuracy(labels, predictions)
+ in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
+ top_5_accuracy = tf.metrics.mean(in_top_5)
+
+ return {
+ "top_1_accuracy": top_1_accuracy,
+ "top_5_accuracy": top_5_accuracy,
+ }
def model_fn(features, labels, mode, params):
@@ -42,45 +110,60 @@ def model_fn(features, labels, mode, params):
Returns:
An instance of `tf.contrib.tpu.TPUEstimatorSpec`
"""
+ tf.logging.info("features: {}".format(features.dtype))
+ tf.logging.info("labels: {}".format(labels.dtype))
+ revnet_config = params["revnet_config"]
+ model = revnet.RevNet(config=revnet_config)
inputs = features
if isinstance(inputs, dict):
inputs = features["image"]
- config = params["config"]
- model = revnet.RevNet(config=config)
+ if revnet_config.data_format == "channels_first":
+ assert not FLAGS.transpose_input # channels_first only for GPU
+ inputs = tf.transpose(inputs, [0, 3, 1, 2])
+
+ if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
+ inputs = tf.transpose(inputs, [3, 0, 1, 2]) # HWCN to NHWC
+
+ # Normalize the image to zero mean and unit variance.
+ inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype)
+ inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype)
if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.train.get_or_create_global_step()
learning_rate = tf.train.piecewise_constant(
- global_step, config.lr_decay_steps, config.lr_list)
- optimizer = tf.train.MomentumOptimizer(
- learning_rate, momentum=config.momentum)
-
+ global_step, revnet_config.lr_decay_steps, revnet_config.lr_list)
+ optimizer = tf.train.MomentumOptimizer(learning_rate,
+ revnet_config.momentum)
if FLAGS.use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
logits, saved_hidden = model(inputs, training=True)
grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
- train_op = optimizer.apply_gradients(
- zip(grads, model.trainable_variables), global_step=global_step)
+ with tf.control_dependencies(model.get_updates_for(inputs)):
+ train_op = optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
+ if not FLAGS.skip_host_call:
+ # To log the loss, current learning rate, and epoch for Tensorboard, the
+ # summary op needs to be run on the host CPU via host_call. host_call
+ # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
+ # dimension. These Tensors are implicitly concatenated to
+ # [params['batch_size']].
+ gs_t = tf.reshape(global_step, [1])
+ loss_t = tf.reshape(loss, [1])
+ lr_t = tf.reshape(learning_rate, [1])
+ host_call = (_host_call_fn, [gs_t, loss_t, lr_t])
return tf.contrib.tpu.TPUEstimatorSpec(
- mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)
+ mode=mode, loss=loss, train_op=train_op, host_call=host_call)
elif mode == tf.estimator.ModeKeys.EVAL:
logits, _ = model(inputs, training=False)
loss = model.compute_loss(labels=labels, logits=logits)
- def metric_fn(labels, logits):
- predictions = tf.argmax(logits, axis=1)
- accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
- return {
- "accuracy": accuracy,
- }
-
return tf.contrib.tpu.TPUEstimatorSpec(
- mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
+ mode=mode, loss=loss, eval_metrics=(_metric_fn, [labels, logits]))
else: # Predict or export
logits, _ = model(inputs, training=False)
@@ -97,113 +180,75 @@ def model_fn(features, labels, mode, params):
})
-def get_input_fn(config, data_dir, split):
- """Get the input function required by the `tf.contrib.tpu.TPUEstimator` API.
-
- Args:
- config: Customized hyperparameters
- data_dir: Directory where the data is stored
- split: One of `train`, `validation`, `train_all`, and `test`
-
- Returns:
- Input function required by the `tf.contrib.tpu.TPUEstimator` API
- """
-
- data_dir = os.path.join(data_dir, config.dataset)
- # Fix split-dependent hyperparameters
- if split == "train_all" or split == "train":
- data_aug = True
- epochs = config.tpu_epochs
- shuffle = True
- else:
- data_aug = False
- epochs = 1
- shuffle = False
-
- def input_fn(params):
- """Input function required by the `tf.contrib.tpu.TPUEstimator` API."""
- batch_size = params["batch_size"]
- return cifar_input.get_ds_from_tfrecords(
- data_dir=data_dir,
- split=split,
- data_aug=data_aug,
- batch_size=batch_size, # per-shard batch size
- epochs=epochs,
- shuffle=shuffle,
- prefetch=batch_size, # per-shard batch size
- data_format=config.data_format)
-
- return input_fn
-
-
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# RevNet specific configuration
- config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
+ revnet_config = {
+ "revnet-56": config_.get_hparams_imagenet_56(),
+ "revnet-104": config_.get_hparams_imagenet_104()
+ }[FLAGS.revnet_config]
if FLAGS.use_tpu:
- tf.logging.info("Using TPU.")
- tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
- FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
- else:
- tpu_cluster_resolver = None
-
- # TPU specific configuration
- tpu_config = tf.contrib.tpu.TPUConfig(
- # Recommended to be set as number of global steps for next checkpoint
- iterations_per_loop=FLAGS.iterations_per_loop,
- num_shards=FLAGS.num_shards)
+ revnet_config.data_format = "channels_last"
+
+ tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
+ FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
# Estimator specific configuration
- run_config = tf.contrib.tpu.RunConfig(
+ config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto(
- allow_soft_placement=True, log_device_placement=False),
- tpu_config=tpu_config,
+ allow_soft_placement=True, log_device_placement=True),
+ tpu_config=tf.contrib.tpu.TPUConfig(
+ iterations_per_loop=FLAGS.iterations_per_loop,
+ num_shards=FLAGS.num_shards,
+ per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.
+ PER_HOST_V2),
)
- # Construct TPU Estimator
- estimator = tf.contrib.tpu.TPUEstimator(
+ # Input pipelines are slightly different (with regards to shuffling and
+ # preprocessing) between training and evaluation.
+ imagenet_train, imagenet_eval = [
+ imagenet_input.ImageNetInput(
+ is_training=is_training,
+ data_dir=FLAGS.data_dir,
+ transpose_input=FLAGS.transpose_input,
+ use_bfloat16=False) for is_training in [True, False]
+ ]
+
+ revnet_classifier = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
use_tpu=FLAGS.use_tpu,
- train_batch_size=config.tpu_batch_size,
- eval_batch_size=config.tpu_eval_batch_size,
- config=run_config,
- params={"config": config})
-
- # Construct input functions
- train_input_fn = get_input_fn(
- config=config, data_dir=FLAGS.data_dir, split="train_all")
- eval_input_fn = get_input_fn(
- config=config, data_dir=FLAGS.data_dir, split="test")
-
- # Disabling a range within an else block currently doesn't work
- # due to https://github.com/PyCQA/pylint/issues/872
+ train_batch_size=revnet_config.tpu_batch_size,
+ eval_batch_size=revnet_config.tpu_eval_batch_size,
+ config=config,
+ export_to_tpu=False,
+ params={"revnet_config": revnet_config})
+
+ steps_per_epoch = revnet_config.tpu_iters_per_epoch
+ eval_steps = revnet_config.tpu_eval_steps
+
# pylint: disable=protected-access
if FLAGS.mode == "eval":
- # TPUEstimator.evaluate *requires* a steps argument.
- # Note that the number of examples used during evaluation is
- # --eval_steps * --batch_size.
- # So if you change --batch_size then change --eval_steps too.
- eval_steps = 10000 // config.tpu_eval_batch_size
-
# Run evaluation when there's a new checkpoint
for ckpt in evaluation.checkpoints_iterator(
FLAGS.model_dir, timeout=FLAGS.eval_timeout):
tf.logging.info("Starting to evaluate.")
try:
start_timestamp = time.time() # This time will include compilation time
- eval_results = estimator.evaluate(
- input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt)
+ eval_results = revnet_classifier.evaluate(
+ input_fn=imagenet_eval.input_fn,
+ steps=eval_steps,
+ checkpoint_path=ckpt)
elapsed_time = int(time.time() - start_timestamp)
tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
(eval_results, elapsed_time))
# Terminate eval job when final checkpoint is reached
current_step = int(os.path.basename(ckpt).split("-")[1])
- if current_step >= config.max_train_iter:
+ if current_step >= revnet_config.max_train_iter:
tf.logging.info(
"Evaluation finished after training step %d" % current_step)
break
@@ -217,37 +262,56 @@ def main(_):
"Checkpoint %s no longer exists, skipping checkpoint" % ckpt)
else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
- current_step = estimator_._load_global_step_from_checkpoint_dir(
+ current_step = estimator._load_global_step_from_checkpoint_dir(
FLAGS.model_dir)
- tf.logging.info("Training for %d steps . Current"
- " step %d." % (config.max_train_iter, current_step))
+
+ tf.logging.info(
+ "Training for %d steps (%.2f epochs in total). Current"
+ " step %d." % (revnet_config.max_train_iter,
+ revnet_config.max_train_iter / steps_per_epoch,
+ current_step))
start_timestamp = time.time() # This time will include compilation time
+
if FLAGS.mode == "train":
- estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter)
+ revnet_classifier.train(
+ input_fn=imagenet_train.input_fn,
+ max_steps=revnet_config.max_train_iter)
+
else:
- eval_steps = 10000 // config.tpu_eval_batch_size
assert FLAGS.mode == "train_and_eval"
- while current_step < config.max_train_iter:
+ while current_step < revnet_config.max_train_iter:
# Train for up to steps_per_eval number of steps.
# At the end of training, a checkpoint will be written to --model_dir.
next_checkpoint = min(current_step + FLAGS.steps_per_eval,
- config.max_train_iter)
- estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
+ revnet_config.max_train_iter)
+ revnet_classifier.train(
+ input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
current_step = next_checkpoint
+ tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
+ (next_checkpoint, int(time.time() - start_timestamp)))
+
# Evaluate the model on the most recent model in --model_dir.
# Since evaluation happens in batches of --eval_batch_size, some images
- # may be consistently excluded modulo the batch size.
+ # may be excluded modulo the batch size. As long as the batch size is
+ # consistent, the evaluated images are also consistent.
tf.logging.info("Starting to evaluate.")
- eval_results = estimator.evaluate(
- input_fn=eval_input_fn, steps=eval_steps)
+ eval_results = revnet_classifier.evaluate(
+ input_fn=imagenet_eval.input_fn, steps=eval_steps)
tf.logging.info("Eval results: %s" % eval_results)
- elapsed_time = int(time.time() - start_timestamp)
- tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
- (config.max_train_iter, elapsed_time))
- # pylint: enable=protected-access
+ elapsed_time = int(time.time() - start_timestamp)
+ tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
+ (revnet_config.max_train_iter, elapsed_time))
+
+ if FLAGS.export_dir is not None:
+ # The guide to serve an exported TensorFlow model is at:
+ # https://www.tensorflow.org/serving/serving_basic
+ tf.logging.info("Starting to export model.")
+ revnet_classifier.export_savedmodel(
+ export_dir_base=FLAGS.export_dir,
+ serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
if __name__ == "__main__":
@@ -279,14 +343,10 @@ if __name__ == "__main__":
default=None,
help="[Optional] Directory to store the model information")
flags.DEFINE_string(
- "dataset",
- default="cifar-10",
- help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
- flags.DEFINE_string(
- "config",
- default="revnet-38",
+ "revnet_config",
+ default="revnet-56",
help="[Optional] Architecture of network. "
- "Other options include `revnet-110` and `revnet-164`")
+ "Other options include `revnet-104`")
flags.DEFINE_boolean(
"use_tpu", default=True, help="[Optional] Whether to use TPU")
flags.DEFINE_integer(
@@ -300,20 +360,37 @@ if __name__ == "__main__":
" train steps, the loop will exit before reaching"
" --iterations_per_loop. The larger this value is, the higher the"
" utilization on the TPU."))
- flags.DEFINE_string(
- "mode",
- default="train_and_eval",
- help="[Optional] Mode to run: train, eval, train_and_eval")
flags.DEFINE_integer(
- "eval_timeout", 60 * 60 * 24,
- "Maximum seconds between checkpoints before evaluation terminates.")
+ "eval_timeout",
+ default=None,
+ help="Maximum seconds between checkpoints before evaluation terminates.")
flags.DEFINE_integer(
"steps_per_eval",
- default=1000,
+ default=5000,
help=(
"Controls how often evaluation is performed. Since evaluation is"
" fairly expensive, it is advised to evaluate as infrequently as"
" possible (i.e. up to --train_steps, which evaluates the model only"
" after finishing the entire training regime)."))
+ flags.DEFINE_bool(
+ "transpose_input",
+ default=True,
+ help="Use TPU double transpose optimization")
+ flags.DEFINE_string(
+ "export_dir",
+ default=None,
+ help=("The directory where the exported SavedModel will be stored."))
+ flags.DEFINE_bool(
+ "skip_host_call",
+ default=False,
+ help=("Skip the host_call which is executed every training step. This is"
+ " generally used for generating training summaries (train loss,"
+ " learning rate, etc...). When --skip_host_call=false, there could"
+ " be a performance drop if host_call function is slow and cannot"
+ " keep up with the TPU-side computation."))
+ flags.DEFINE_string(
+ "mode",
+ default="train_and_eval",
+ help='One of {"train_and_eval", "train", "eval"}.')
FLAGS = flags.FLAGS
tf.app.run()