diff options
author | 2018-07-31 14:19:22 -0700 | |
---|---|---|
committer | 2018-07-31 14:26:21 -0700 | |
commit | 3fda31fe7d17d808c18e53186beb54b457088587 (patch) | |
tree | fb27ab6284ebbe5d7e987da3d9813402d3b2e541 /tensorflow | |
parent | 48f486b4b4ad23dbfddbd9527f184ef25bdc2421 (diff) |
Update TPU estimator script to work with ImageNet.
PiperOrigin-RevId: 206825954
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py | 335 |
1 files changed, 206 insertions, 129 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py index f0aad9b110..6300021824 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py @@ -12,22 +12,90 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Cloud TPU Estimator workflow with RevNet train on CIFAR-10.""" +"""Cloud TPU Estimator workflow with RevNet train on ImageNet.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import time from absl import flags import tensorflow as tf -from tensorflow.contrib.eager.python.examples.revnet import cifar_input -from tensorflow.contrib.eager.python.examples.revnet import main as main_ +from tensorflow.contrib import summary +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import imagenet_input from tensorflow.contrib.eager.python.examples.revnet import revnet from tensorflow.contrib.training.python.training import evaluation -from tensorflow.python.estimator import estimator as estimator_ +from tensorflow.python.estimator import estimator + +MEAN_RGB = [0.485, 0.456, 0.406] +STDDEV_RGB = [0.229, 0.224, 0.225] + + +def _host_call_fn(gs, loss, lr): + """Training host call. + + Creates scalar summaries for training metrics. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the + model to the `metric_fn`, provide as part of the `host_call`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. 
+ + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `host_call`. + + Args: + gs: `Tensor` with shape `[batch]` for the global_step + loss: `Tensor` with shape `[batch]` for the training loss. + lr: `Tensor` with shape `[batch]` for the learning_rate. + + Returns: + List of summary ops to run on the CPU host. + """ + # Host call fns are executed FLAGS.iterations_per_loop times after one + # TPU loop is finished, setting max_queue value to the same as number of + # iterations will make the summary writer only flush the data to storage + # once per loop. + gs = gs[0] + with summary.create_file_writer( + FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default(): + with summary.always_record_summaries(): + summary.scalar("loss", loss[0], step=gs) + summary.scalar("learning_rate", lr[0], step=gs) + return summary.all_summary_ops() + + +def _metric_fn(labels, logits): + """Evaluation metric function. Evaluates accuracy. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the model + to the `metric_fn`, provide as part of the `eval_metrics`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `eval_metrics`. + + Args: + labels: `Tensor` with shape `[batch]`. + logits: `Tensor` with shape `[batch, num_classes]`. + + Returns: + A dict of the metrics to return from evaluation. 
+ """ + predictions = tf.argmax(logits, axis=1) + top_1_accuracy = tf.metrics.accuracy(labels, predictions) + in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) + top_5_accuracy = tf.metrics.mean(in_top_5) + + return { + "top_1_accuracy": top_1_accuracy, + "top_5_accuracy": top_5_accuracy, + } def model_fn(features, labels, mode, params): @@ -42,45 +110,60 @@ def model_fn(features, labels, mode, params): Returns: An instance of `tf.contrib.tpu.TPUEstimatorSpec` """ + tf.logging.info("features: {}".format(features.dtype)) + tf.logging.info("labels: {}".format(labels.dtype)) + revnet_config = params["revnet_config"] + model = revnet.RevNet(config=revnet_config) inputs = features if isinstance(inputs, dict): inputs = features["image"] - config = params["config"] - model = revnet.RevNet(config=config) + if revnet_config.data_format == "channels_first": + assert not FLAGS.transpose_input # channels_first only for GPU + inputs = tf.transpose(inputs, [0, 3, 1, 2]) + + if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT: + inputs = tf.transpose(inputs, [3, 0, 1, 2]) # HWCN to NHWC + + # Normalize the image to zero mean and unit variance. 
+ inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype) + inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype) if mode == tf.estimator.ModeKeys.TRAIN: global_step = tf.train.get_or_create_global_step() learning_rate = tf.train.piecewise_constant( - global_step, config.lr_decay_steps, config.lr_list) - optimizer = tf.train.MomentumOptimizer( - learning_rate, momentum=config.momentum) - + global_step, revnet_config.lr_decay_steps, revnet_config.lr_list) + optimizer = tf.train.MomentumOptimizer(learning_rate, + revnet_config.momentum) if FLAGS.use_tpu: optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) logits, saved_hidden = model(inputs, training=True) grads, loss = model.compute_gradients(saved_hidden, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, model.trainable_variables), global_step=global_step) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + if not FLAGS.skip_host_call: + # To log the loss, current learning rate, and epoch for Tensorboard, the + # summary op needs to be run on the host CPU via host_call. host_call + # expects [batch_size, ...] Tensors, thus reshape to introduce a batch + # dimension. These Tensors are implicitly concatenated to + # [params['batch_size']]. 
+ gs_t = tf.reshape(global_step, [1]) + loss_t = tf.reshape(loss, [1]) + lr_t = tf.reshape(learning_rate, [1]) + host_call = (_host_call_fn, [gs_t, loss_t, lr_t]) return tf.contrib.tpu.TPUEstimatorSpec( - mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op) + mode=mode, loss=loss, train_op=train_op, host_call=host_call) elif mode == tf.estimator.ModeKeys.EVAL: logits, _ = model(inputs, training=False) loss = model.compute_loss(labels=labels, logits=logits) - def metric_fn(labels, logits): - predictions = tf.argmax(logits, axis=1) - accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions) - return { - "accuracy": accuracy, - } - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits])) + mode=mode, loss=loss, eval_metrics=(_metric_fn, [labels, logits])) else: # Predict or export logits, _ = model(inputs, training=False) @@ -97,113 +180,75 @@ def model_fn(features, labels, mode, params): }) -def get_input_fn(config, data_dir, split): - """Get the input function required by the `tf.contrib.tpu.TPUEstimator` API. 
- - Args: - config: Customized hyperparameters - data_dir: Directory where the data is stored - split: One of `train`, `validation`, `train_all`, and `test` - - Returns: - Input function required by the `tf.contrib.tpu.TPUEstimator` API - """ - - data_dir = os.path.join(data_dir, config.dataset) - # Fix split-dependent hyperparameters - if split == "train_all" or split == "train": - data_aug = True - epochs = config.tpu_epochs - shuffle = True - else: - data_aug = False - epochs = 1 - shuffle = False - - def input_fn(params): - """Input function required by the `tf.contrib.tpu.TPUEstimator` API.""" - batch_size = params["batch_size"] - return cifar_input.get_ds_from_tfrecords( - data_dir=data_dir, - split=split, - data_aug=data_aug, - batch_size=batch_size, # per-shard batch size - epochs=epochs, - shuffle=shuffle, - prefetch=batch_size, # per-shard batch size - data_format=config.data_format) - - return input_fn - - def main(_): tf.logging.set_verbosity(tf.logging.INFO) # RevNet specific configuration - config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + revnet_config = { + "revnet-56": config_.get_hparams_imagenet_56(), + "revnet-104": config_.get_hparams_imagenet_104() + }[FLAGS.revnet_config] if FLAGS.use_tpu: - tf.logging.info("Using TPU.") - tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( - FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - else: - tpu_cluster_resolver = None - - # TPU specific configuration - tpu_config = tf.contrib.tpu.TPUConfig( - # Recommended to be set as number of global steps for next checkpoint - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_shards) + revnet_config.data_format = "channels_last" + + tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) # Estimator specific configuration - run_config = tf.contrib.tpu.RunConfig( + config = tf.contrib.tpu.RunConfig( 
cluster=tpu_cluster_resolver, model_dir=FLAGS.model_dir, session_config=tf.ConfigProto( - allow_soft_placement=True, log_device_placement=False), - tpu_config=tpu_config, + allow_soft_placement=True, log_device_placement=True), + tpu_config=tf.contrib.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_shards, + per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig. + PER_HOST_V2), ) - # Construct TPU Estimator - estimator = tf.contrib.tpu.TPUEstimator( + # Input pipelines are slightly different (with regards to shuffling and + # preprocessing) between training and evaluation. + imagenet_train, imagenet_eval = [ + imagenet_input.ImageNetInput( + is_training=is_training, + data_dir=FLAGS.data_dir, + transpose_input=FLAGS.transpose_input, + use_bfloat16=False) for is_training in [True, False] + ] + + revnet_classifier = tf.contrib.tpu.TPUEstimator( model_fn=model_fn, use_tpu=FLAGS.use_tpu, - train_batch_size=config.tpu_batch_size, - eval_batch_size=config.tpu_eval_batch_size, - config=run_config, - params={"config": config}) - - # Construct input functions - train_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="train_all") - eval_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="test") - - # Disabling a range within an else block currently doesn't work - # due to https://github.com/PyCQA/pylint/issues/872 + train_batch_size=revnet_config.tpu_batch_size, + eval_batch_size=revnet_config.tpu_eval_batch_size, + config=config, + export_to_tpu=False, + params={"revnet_config": revnet_config}) + + steps_per_epoch = revnet_config.tpu_iters_per_epoch + eval_steps = revnet_config.tpu_eval_steps + # pylint: disable=protected-access if FLAGS.mode == "eval": - # TPUEstimator.evaluate *requires* a steps argument. - # Note that the number of examples used during evaluation is - # --eval_steps * --batch_size. - # So if you change --batch_size then change --eval_steps too. 
- eval_steps = 10000 // config.tpu_eval_batch_size - # Run evaluation when there's a new checkpoint for ckpt in evaluation.checkpoints_iterator( FLAGS.model_dir, timeout=FLAGS.eval_timeout): tf.logging.info("Starting to evaluate.") try: start_timestamp = time.time() # This time will include compilation time - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt) + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, + steps=eval_steps, + checkpoint_path=ckpt) elapsed_time = int(time.time() - start_timestamp) tf.logging.info("Eval results: %s. Elapsed seconds: %d" % (eval_results, elapsed_time)) # Terminate eval job when final checkpoint is reached current_step = int(os.path.basename(ckpt).split("-")[1]) - if current_step >= config.max_train_iter: + if current_step >= revnet_config.max_train_iter: tf.logging.info( "Evaluation finished after training step %d" % current_step) break @@ -217,37 +262,56 @@ def main(_): "Checkpoint %s no longer exists, skipping checkpoint" % ckpt) else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval' - current_step = estimator_._load_global_step_from_checkpoint_dir( + current_step = estimator._load_global_step_from_checkpoint_dir( FLAGS.model_dir) - tf.logging.info("Training for %d steps . Current" - " step %d." % (config.max_train_iter, current_step)) + + tf.logging.info( + "Training for %d steps (%.2f epochs in total). Current" + " step %d." 
% (revnet_config.max_train_iter, + revnet_config.max_train_iter / steps_per_epoch, + current_step)) start_timestamp = time.time() # This time will include compilation time + if FLAGS.mode == "train": - estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, + max_steps=revnet_config.max_train_iter) + else: - eval_steps = 10000 // config.tpu_eval_batch_size assert FLAGS.mode == "train_and_eval" - while current_step < config.max_train_iter: + while current_step < revnet_config.max_train_iter: # Train for up to steps_per_eval number of steps. # At the end of training, a checkpoint will be written to --model_dir. next_checkpoint = min(current_step + FLAGS.steps_per_eval, - config.max_train_iter) - estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint) + revnet_config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, max_steps=next_checkpoint) current_step = next_checkpoint + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (next_checkpoint, int(time.time() - start_timestamp))) + # Evaluate the model on the most recent model in --model_dir. # Since evaluation happens in batches of --eval_batch_size, some images - # may be consistently excluded modulo the batch size. + # may be excluded modulo the batch size. As long as the batch size is + # consistent, the evaluated images are also consistent. tf.logging.info("Starting to evaluate.") - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps) + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, steps=eval_steps) tf.logging.info("Eval results: %s" % eval_results) - elapsed_time = int(time.time() - start_timestamp) - tf.logging.info("Finished training up to step %d. Elapsed seconds %d." 
% - (config.max_train_iter, elapsed_time)) - # pylint: enable=protected-access + elapsed_time = int(time.time() - start_timestamp) + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (revnet_config.max_train_iter, elapsed_time)) + + if FLAGS.export_dir is not None: + # The guide to serve an exported TensorFlow model is at: + # https://www.tensorflow.org/serving/serving_basic + tf.logging.info("Starting to export model.") + revnet_classifier.export_savedmodel( + export_dir_base=FLAGS.export_dir, + serving_input_receiver_fn=imagenet_input.image_serving_input_fn) if __name__ == "__main__": @@ -279,14 +343,10 @@ if __name__ == "__main__": default=None, help="[Optional] Directory to store the model information") flags.DEFINE_string( - "dataset", - default="cifar-10", - help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") - flags.DEFINE_string( - "config", - default="revnet-38", + "revnet_config", + default="revnet-56", help="[Optional] Architecture of network. " - "Other options include `revnet-110` and `revnet-164`") + "Other options include `revnet-104`") flags.DEFINE_boolean( "use_tpu", default=True, help="[Optional] Whether to use TPU") flags.DEFINE_integer( @@ -300,20 +360,37 @@ if __name__ == "__main__": " train steps, the loop will exit before reaching" " --iterations_per_loop. The larger this value is, the higher the" " utilization on the TPU.")) - flags.DEFINE_string( - "mode", - default="train_and_eval", - help="[Optional] Mode to run: train, eval, train_and_eval") flags.DEFINE_integer( - "eval_timeout", 60 * 60 * 24, - "Maximum seconds between checkpoints before evaluation terminates.") + "eval_timeout", + default=None, + help="Maximum seconds between checkpoints before evaluation terminates.") flags.DEFINE_integer( "steps_per_eval", - default=1000, + default=5000, help=( "Controls how often evaluation is performed. Since evaluation is" " fairly expensive, it is advised to evaluate as infrequently as" " possible (i.e. 
up to --train_steps, which evaluates the model only" " after finishing the entire training regime).")) + flags.DEFINE_bool( + "transpose_input", + default=True, + help="Use TPU double transpose optimization") + flags.DEFINE_string( + "export_dir", + default=None, + help=("The directory where the exported SavedModel will be stored.")) + flags.DEFINE_bool( + "skip_host_call", + default=False, + help=("Skip the host_call which is executed every training step. This is" + " generally used for generating training summaries (train loss," + " learning rate, etc...). When --skip_host_call=false, there could" + " be a performance drop if host_call function is slow and cannot" + " keep up with the TPU-side computation.")) + flags.DEFINE_string( + "mode", + default="train_and_eval", + help='One of {"train_and_eval", "train", "eval"}.') FLAGS = flags.FLAGS tf.app.run() |