diff options
9 files changed, 268 insertions, 905 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index 3316dc1114..0c0e4c0eb9 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -113,39 +113,3 @@ py_binary( "//tensorflow:tensorflow_py", ], ) - -py_binary( - name = "main_estimator", - srcs = ["main_estimator.py"], - srcs_version = "PY2AND3", - deps = [ - ":cifar_input", - ":main", - ":revnet", - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "main_estimator_lib", - srcs = ["main_estimator.py"], - srcs_version = "PY2AND3", - deps = [ - ":cifar_input", - ":main", - ":revnet", - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "main_estimator_tpu_lib", - srcs = ["main_estimator_tpu.py"], - srcs_version = "PY2AND3", - deps = [ - ":cifar_input", - ":main", - ":revnet", - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py index 639bb06a34..306096e9f8 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -24,9 +24,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import operator - import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import ops @@ -48,7 +45,7 @@ class RevBlock(tf.keras.Model): bottleneck=False, fused=True, dtype=tf.float32): - """Initialization. + """Initialize RevBlock. Args: n_res: number of residual blocks @@ -102,6 +99,7 @@ class RevBlock(tf.keras.Model): if i == 0: # First block usually contains downsampling that can't be reversed with tf.GradientTape() as tape: + x = tf.identity(x) tape.watch(x) y = block(x, training=training) @@ -123,6 +121,16 @@ class _Residual(tf.keras.Model): """Single residual block contained in a _RevBlock. Each `_Residual` object has two _ResidualInner objects, corresponding to the `F` and `G` functions in the paper. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC", + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + dtype: float16, float32, or float64 """ def __init__(self, @@ -134,18 +142,6 @@ class _Residual(tf.keras.Model): bottleneck=False, fused=True, dtype=tf.float32): - """Initialization. - - Args: - filters: output filter size - strides: length 2 list/tuple of integers for height and width strides - input_shape: length 3 list/tuple of integers - batch_norm_first: whether to apply activation and batch norm before conv - data_format: tensor data format, "NCHW"/"NHWC", - bottleneck: use bottleneck residual if True - fused: use fused batch normalization if True - dtype: float16, float32, or float64 - """ super(_Residual, self).__init__() self.filters = filters @@ -200,6 +196,7 @@ class _Residual(tf.keras.Model): dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis) with tf.GradientTape(persistent=True) as tape: + y = tf.identity(y) tape.watch(y) y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) z1 = y1 @@ -230,252 +227,131 @@ class _Residual(tf.keras.Model): return x, dx, grads, vars_ -# Ideally, the following should be wrapped in `tf.keras.Sequential`, however -# there are subtle issues with its placeholder insertion policy and batch norm -class _BottleneckResidualInner(tf.keras.Model): +def _BottleneckResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): """Single bottleneck residual inner function contained in _Resdual. Corresponds to the `F`/`G` functions in the paper. Suitable for training on ImageNet dataset. - """ - - def __init__(self, - filters, - strides, - input_shape, - batch_norm_first=True, - data_format="channels_first", - fused=True, - dtype=tf.float32): - """Initialization. - - Args: - filters: output filter size - strides: length 2 list/tuple of integers for height and width strides - input_shape: length 3 list/tuple of integers - batch_norm_first: whether to apply activation and batch norm before conv - data_format: tensor data format, "NCHW"/"NHWC" - fused: use fused batch normalization if True - dtype: float16, float32, or float64 - """ - super(_BottleneckResidualInner, self).__init__() - axis = 1 if data_format == "channels_first" else 3 - if batch_norm_first: - self.batch_norm_0 = tf.keras.layers.BatchNormalization( - axis=axis, input_shape=input_shape, fused=fused, dtype=dtype) - - self.conv2d_1 = tf.keras.layers.Conv2D( - filters=filters // 4, - kernel_size=1, - strides=strides, - input_shape=input_shape, - data_format=data_format, - use_bias=False, - padding="SAME", - dtype=dtype) - self.batch_norm_1 = tf.keras.layers.BatchNormalization( - axis=axis, fused=fused, dtype=dtype) - - self.conv2d_2 = tf.keras.layers.Conv2D( - filters=filters // 4, - kernel_size=3, - strides=(1, 1), - data_format=data_format, - use_bias=False, - padding="SAME", - dtype=dtype) - self.batch_norm_2 = tf.keras.layers.BatchNormalization( - axis=axis, fused=fused, dtype=dtype) - self.conv2d_3 = tf.keras.layers.Conv2D( - filters=filters, - kernel_size=1, - strides=(1, 1), - data_format=data_format, - use_bias=False, - padding="SAME", - dtype=dtype) - - self.batch_norm_first = batch_norm_first - - def call(self, x, training=True): - net = x - if self.batch_norm_first: - net = self.batch_norm_0(net, training=training) - net = tf.nn.relu(net) - - net = self.conv2d_1(net) - net = self.batch_norm_1(net, training=training) - net = tf.nn.relu(net) - - net = self.conv2d_2(net) - net = self.batch_norm_2(net, training=training) - net = tf.nn.relu(net) + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + + Returns: + A keras model + """ - net = self.conv2d_3(net) + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=1, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) + + model.add( + tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) + + model.add( + tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) - return net + return model -class _ResidualInner(tf.keras.Model): +def _ResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): """Single residual inner function contained in _ResdualBlock. Corresponds to the `F`/`G` functions in the paper. - """ - - def __init__(self, - filters, - strides, - input_shape, - batch_norm_first=True, - data_format="channels_first", - fused=True, - dtype=tf.float32): - """Initialization. - - Args: - filters: output filter size - strides: length 2 list/tuple of integers for height and width strides - input_shape: length 3 list/tuple of integers - batch_norm_first: whether to apply activation and batch norm before conv - data_format: tensor data format, "NCHW"/"NHWC" - fused: use fused batch normalization if True - dtype: float16, float32, or float64 - """ - super(_ResidualInner, self).__init__() - axis = 1 if data_format == "channels_first" else 3 - if batch_norm_first: - self.batch_norm_0 = tf.keras.layers.BatchNormalization( - axis=axis, input_shape=input_shape, fused=fused, dtype=dtype) - self.conv2d_1 = tf.keras.layers.Conv2D( - filters=filters, - kernel_size=3, - strides=strides, - input_shape=input_shape, - data_format=data_format, - use_bias=False, - padding="SAME", - dtype=dtype) - self.batch_norm_1 = tf.keras.layers.BatchNormalization( - axis=axis, fused=fused, dtype=dtype) - - self.conv2d_2 = tf.keras.layers.Conv2D( - filters=filters, - kernel_size=3, - strides=(1, 1), - data_format=data_format, - use_bias=False, - padding="SAME", - dtype=dtype) - - self.batch_norm_first = batch_norm_first - - def call(self, x, training=True): - net = x - if self.batch_norm_first: - net = self.batch_norm_0(net, training=training) - net = tf.nn.relu(net) - - net = self.conv2d_1(net) - net = self.batch_norm_1(net, training=training) - - net = self.conv2d_2(net) - - return net + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + + Returns: + A keras model + """ -class InitBlock(tf.keras.Model): - """Initial block of RevNet.""" - - def __init__(self, config): - """Initialization. - - Args: - config: tf.contrib.training.HParams object; specifies hyperparameters - """ - super(InitBlock, self).__init__() - self.config = config - self.axis = 1 if self.config.data_format == "channels_first" else 3 - self.conv2d = tf.keras.layers.Conv2D( - filters=self.config.init_filters, - kernel_size=self.config.init_kernel, - strides=(self.config.init_stride, self.config.init_stride), - data_format=self.config.data_format, - use_bias=False, - padding="SAME", - input_shape=self.config.input_shape, - dtype=self.config.dtype) - self.batch_norm = tf.keras.layers.BatchNormalization( - axis=self.axis, fused=self.config.fused, dtype=self.config.dtype) - self.activation = tf.keras.layers.Activation("relu") - - if self.config.init_max_pool: - self.max_pool = tf.keras.layers.MaxPooling2D( - pool_size=(3, 3), - strides=(2, 2), + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, padding="SAME", - data_format=self.config.data_format, - dtype=self.config.dtype) - - def call(self, x, training=True): - net = x - net = self.conv2d(net) - net = self.batch_norm(net, training=training) - net = self.activation(net) - - if self.config.init_max_pool: - net = self.max_pool(net) - - return net - - -class FinalBlock(tf.keras.Model): - """Final block of RevNet.""" - - def __init__(self, config): - """Initialization. - - Args: - config: tf.contrib.training.HParams object; specifies hyperparameters + dtype=dtype)) + + model.add( + tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) - Raises: - ValueError: Unsupported data format - """ - super(FinalBlock, self).__init__() - self.config = config - self.axis = 1 if self.config.data_format == "channels_first" else 3 - - f = self.config.filters[-1] # Number of filters - r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio - r *= self.config.init_stride - if self.config.init_max_pool: - r *= 2 - - if self.config.data_format == "channels_first": - w, h = self.config.input_shape[1], self.config.input_shape[2] - input_shape = (f, w // r, h // r) - elif self.config.data_format == "channels_last": - w, h = self.config.input_shape[0], self.config.input_shape[1] - input_shape = (w // r, h // r, f) - else: - raise ValueError("Data format should be either `channels_first`" - " or `channels_last`") - self.batch_norm = tf.keras.layers.BatchNormalization( - axis=self.axis, - input_shape=input_shape, - fused=self.config.fused, - dtype=self.config.dtype) - self.activation = tf.keras.layers.Activation("relu") - self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D( - data_format=self.config.data_format, dtype=self.config.dtype) - self.dense = tf.keras.layers.Dense( - self.config.n_classes, dtype=self.config.dtype) - - def call(self, x, training=True): - net = x - net = self.batch_norm(net, training=training) - net = self.activation(net) - net = self.global_avg_pool(net) - net = self.dense(net) - - return net + return model diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py index e9672f13e1..b6d4c35bfd 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -111,6 +111,6 @@ def get_ds_from_tfrecords(data_dir, }[split] dataset = dataset.shuffle(size) - dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.batch(batch_size) return dataset diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py index 1532c7b67b..3d93fa955a 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/config.py +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -27,16 +27,17 @@ from __future__ import division from __future__ import print_function import tensorflow as tf +tfe = tf.contrib.eager def get_hparams_cifar_38(): """RevNet-38 configurations for CIFAR-10/CIFAR-100.""" config = tf.contrib.training.HParams() - # Hyperparameters from the RevNet paper config.add_hparam("init_filters", 32) config.add_hparam("init_kernel", 3) config.add_hparam("init_stride", 1) + config.add_hparam("n_classes", 10) config.add_hparam("n_rev_blocks", 3) config.add_hparam("n_res", [3, 3, 3]) config.add_hparam("filters", [32, 64, 112]) @@ -45,7 +46,7 @@ def get_hparams_cifar_38(): config.add_hparam("bottleneck", False) config.add_hparam("fused", True) config.add_hparam("init_max_pool", False) - if tf.test.is_gpu_available() > 0: + if tfe.num_gpus() > 0: config.add_hparam("input_shape", (3, 32, 32)) config.add_hparam("data_format", "channels_first") else: @@ -70,16 +71,6 @@ def get_hparams_cifar_38(): config.add_hparam("iters_per_epoch", 50000 // config.batch_size) config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) - # Customized TPU hyperparameters due to differing batch size caused by - # TPU architecture specifics - # Suggested batch sizes to reduce overhead from excessive tensor padding - # https://cloud.google.com/tpu/docs/troubleshooting - config.add_hparam("tpu_batch_size", 128) - config.add_hparam("tpu_eval_batch_size", 1024) - config.add_hparam("tpu_iters_per_epoch", 50000 // config.tpu_batch_size) - config.add_hparam("tpu_epochs", - config.max_train_iter // config.tpu_iters_per_epoch) - return config @@ -110,6 +101,7 @@ def get_hparams_imagenet_56(): config.add_hparam("init_filters", 128) config.add_hparam("init_kernel", 7) config.add_hparam("init_stride", 2) + config.add_hparam("n_classes", 1000) config.add_hparam("n_rev_blocks", 4) config.add_hparam("n_res", [2, 2, 2, 2]) config.add_hparam("filters", [128, 256, 512, 832]) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index 1a4fd45c8b..e2f43b03f9 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -31,11 +31,8 @@ tfe = tf.contrib.eager def main(_): """Eager execution workflow with RevNet trained on CIFAR-10.""" - tf.enable_eager_execution() - - config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) - ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets( - data_dir=FLAGS.data_dir, config=config) + config = get_config() + ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(config) model = revnet.RevNet(config=config) global_step = tf.train.get_or_create_global_step() # Ensure correct summary global_step.assign(1) @@ -55,17 +52,23 @@ def main(_): "with global_step: {}".format(latest_path, global_step.numpy())) sys.stdout.flush() + if FLAGS.manual_grad: + print("Using manual gradients.") + else: + print("Not using manual gradients.") + sys.stdout.flush() + for x, y in ds_train: train_one_iter(model, x, y, optimizer, global_step=global_step) if global_step.numpy() % config.log_every == 0: + it_train = ds_train_one_shot.make_one_shot_iterator() it_test = ds_test.make_one_shot_iterator() + acc_train, loss_train = evaluate(model, it_train) acc_test, loss_test = evaluate(model, it_test) if FLAGS.validate: - it_train = ds_train_one_shot.make_one_shot_iterator() it_validation = ds_validation.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) acc_validation, loss_validation = evaluate(model, it_validation) print("Iter {}, " "training set accuracy {:.4f}, loss {:.4f}; " @@ -74,8 +77,11 @@ def main(_): global_step.numpy(), acc_train, loss_train, acc_validation, loss_validation, acc_test, loss_test)) else: - print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format( - global_step.numpy(), acc_test, loss_test)) + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_test, + loss_test)) sys.stdout.flush() if FLAGS.train_dir: @@ -97,38 +103,34 @@ def main(_): sys.stdout.flush() -def get_config(config_name="revnet-38", dataset="cifar-10"): +def get_config(): """Return configuration.""" - print("Config: {}".format(config_name)) + print("Config: {}".format(FLAGS.config)) sys.stdout.flush() config = { "revnet-38": config_.get_hparams_cifar_38(), "revnet-110": config_.get_hparams_cifar_110(), "revnet-164": config_.get_hparams_cifar_164(), - }[config_name] + }[FLAGS.config] - if dataset == "cifar-10": - config.add_hparam("n_classes", 10) - config.add_hparam("dataset", "cifar-10") - else: - config.add_hparam("n_classes", 100) - config.add_hparam("dataset", "cifar-100") + if FLAGS.dataset == "cifar-100": + config.n_classes = 100 return config -def get_datasets(data_dir, config): +def get_datasets(config): """Return dataset.""" - if data_dir is None: + if FLAGS.data_dir is None: raise ValueError("No supplied data directory") - if not os.path.exists(data_dir): - raise ValueError("Data directory {} does not exist".format(data_dir)) - if config.dataset not in ["cifar-10", "cifar-100"]: - raise ValueError("Unknown dataset {}".format(config.dataset)) + if not os.path.exists(FLAGS.data_dir): + raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir)) + if FLAGS.dataset not in ["cifar-10", "cifar-100"]: + raise ValueError("Unknown dataset {}".format(FLAGS.dataset)) - print("Training on {} dataset.".format(config.dataset)) + print("Training on {} dataset.".format(FLAGS.dataset)) sys.stdout.flush() - data_dir = os.path.join(data_dir, config.dataset) + data_dir = os.path.join(FLAGS.data_dir, FLAGS.dataset) if FLAGS.validate: # 40k Training set ds_train = cifar_input.get_ds_from_tfrecords( @@ -166,7 +168,7 @@ def get_datasets(data_dir, config): prefetch=config.batch_size) ds_validation = None - # Always compute loss and accuracy on whole test set + # Always compute loss and accuracy on whole training and test set ds_train_one_shot = cifar_input.get_ds_from_tfrecords( data_dir=data_dir, split="train_all", @@ -194,11 +196,19 @@ def get_datasets(data_dir, config): def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + if FLAGS.manual_grad: + grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + else: # For correctness validation + with tf.GradientTape() as tape: + logits, _ = model(inputs, training=True) + loss = model.compute_loss(logits=logits, labels=labels) + tf.logging.info("Logits are placed on device: {}".format(logits.device)) + grads = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) - return logits, loss + return loss.numpy() def evaluate(model, iterator): @@ -231,14 +241,16 @@ if __name__ == "__main__": "validate", default=False, help="[Optional] Use the validation set or not for hyperparameter search") + flags.DEFINE_boolean( + "manual_grad", + default=False, + help="[Optional] Use manual gradient graph to save memory") flags.DEFINE_string( "dataset", default="cifar-10", help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") flags.DEFINE_string( - "config", - default="revnet-38", - help="[Optional] Architecture of network. " - "Other options include `revnet-110` and `revnet-164`") + "config", default="revnet-38", help="[Optional] Architecture of network.") FLAGS = flags.FLAGS + tf.enable_eager_execution() tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py deleted file mode 100644 index c875e8da6d..0000000000 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Estimator workflow with RevNet train on CIFAR-10.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from absl import flags -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.revnet import cifar_input -from tensorflow.contrib.eager.python.examples.revnet import main as main_ -from tensorflow.contrib.eager.python.examples.revnet import revnet - - -def model_fn(features, labels, mode, params): - """Function specifying the model that is required by the `tf.estimator` API. - - Args: - features: Input images - labels: Labels of images - mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT' - params: A dictionary of extra parameter that might be passed - - Returns: - An instance of `tf.estimator.EstimatorSpec` - """ - - inputs = features - if isinstance(inputs, dict): - inputs = features["image"] - - config = params["config"] - model = revnet.RevNet(config=config) - - if mode == tf.estimator.ModeKeys.TRAIN: - global_step = tf.train.get_or_create_global_step() - learning_rate = tf.train.piecewise_constant( - global_step, config.lr_decay_steps, config.lr_list) - optimizer = tf.train.MomentumOptimizer( - learning_rate, momentum=config.momentum) - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, vars_), global_step=global_step) - - return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) - else: - logits, _ = model(inputs, training=False) - predictions = tf.argmax(logits, axis=1) - probabilities = tf.nn.softmax(logits) - loss = model.compute_loss(labels=labels, logits=logits) - - if mode == tf.estimator.ModeKeys.EVAL: - return tf.estimator.EstimatorSpec( - mode=mode, - loss=loss, - eval_metric_ops={ - "accuracy": - tf.metrics.accuracy(labels=labels, predictions=predictions) - }) - - else: # mode == tf.estimator.ModeKeys.PREDICT - result = { - "classes": predictions, - "probabilities": probabilities, - } - - return tf.estimator.EstimatorSpec( - mode=mode, - predictions=predictions, - export_outputs={ - "classify": tf.estimator.export.PredictOutput(result) - }) - - -def get_input_fn(config, data_dir, split): - """Get the input function that is required by the `tf.estimator` API. - - Args: - config: Customized hyperparameters - data_dir: Directory where the data is stored - split: One of `train`, `validation`, `train_all`, and `test` - - Returns: - Input function required by the `tf.estimator` API - """ - - data_dir = os.path.join(data_dir, config.dataset) - # Fix split-dependent hyperparameters - if split == "train_all" or split == "train": - data_aug = True - batch_size = config.batch_size - epochs = config.epochs - shuffle = True - prefetch = config.batch_size - else: - data_aug = False - batch_size = config.eval_batch_size - epochs = 1 - shuffle = False - prefetch = config.eval_batch_size - - def input_fn(): - """Input function required by the `tf.estimator.Estimator` API.""" - return cifar_input.get_ds_from_tfrecords( - data_dir=data_dir, - split=split, - data_aug=data_aug, - batch_size=batch_size, - epochs=epochs, - shuffle=shuffle, - prefetch=prefetch, - data_format=config.data_format) - - return input_fn - - -def main(argv): - FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name - tf.logging.set_verbosity(tf.logging.INFO) - - # RevNet specific configuration - config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) - - # Estimator specific configuration - run_config = tf.estimator.RunConfig( - model_dir=FLAGS.train_dir, # Directory for storing checkpoints - tf_random_seed=config.seed, - save_summary_steps=config.log_every, - save_checkpoints_steps=config.log_every, - session_config=None, # Using default - keep_checkpoint_max=100, - keep_checkpoint_every_n_hours=10000, # Using default - log_step_count_steps=config.log_every, - train_distribute=None # Default not use distribution strategy - ) - - # Construct estimator - revnet_estimator = tf.estimator.Estimator( - model_fn=model_fn, - model_dir=FLAGS.train_dir, - config=run_config, - params={"config": config}) - - # Construct input functions - train_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="train_all") - eval_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="test") - - # Train and evaluate estimator - revnet_estimator.train(input_fn=train_input_fn) - revnet_estimator.evaluate(input_fn=eval_input_fn) - - if FLAGS.export: - input_shape = (None,) + config.input_shape - inputs = tf.placeholder(tf.float32, shape=input_shape) - input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ - "image": inputs - }) - revnet_estimator.export_savedmodel(FLAGS.train_dir, input_fn) - - -if __name__ == "__main__": - flags.DEFINE_string( - "data_dir", default=None, help="Directory to load tfrecords") - flags.DEFINE_string( - "train_dir", - default=None, - help="[Optional] Directory to store the training information") - flags.DEFINE_string( - "dataset", - default="cifar-10", - help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") - flags.DEFINE_boolean( - "export", - default=False, - help="[Optional] Export the model for serving if True") - flags.DEFINE_string( - "config", - default="revnet-38", - help="[Optional] Architecture of network. " - "Other options include `revnet-110` and `revnet-164`") - FLAGS = flags.FLAGS - tf.app.run(main=main, argv=[FLAGS]) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py deleted file mode 100644 index f1e1e530df..0000000000 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Cloud TPU Estimator workflow with RevNet train on CIFAR-10.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import time - -from absl import flags -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.revnet import cifar_input -from tensorflow.contrib.eager.python.examples.revnet import main as main_ -from tensorflow.contrib.eager.python.examples.revnet import revnet -from tensorflow.contrib.training.python.training import evaluation -from tensorflow.python.estimator import estimator as estimator_ - - -def model_fn(features, labels, mode, params): - """Model function required by the `tf.contrib.tpu.TPUEstimator` API. - - Args: - features: Input images - labels: Labels of images - mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT' - params: A dictionary of extra parameter that might be passed - - Returns: - An instance of `tf.contrib.tpu.TPUEstimatorSpec` - """ - - inputs = features - if isinstance(inputs, dict): - inputs = features["image"] - - FLAGS = params["FLAGS"] # pylint:disable=invalid-name,redefined-outer-name - config = params["config"] - model = revnet.RevNet(config=config) - - if mode == tf.estimator.ModeKeys.TRAIN: - global_step = tf.train.get_or_create_global_step() - learning_rate = tf.train.piecewise_constant( - global_step, config.lr_decay_steps, config.lr_list) - optimizer = tf.train.MomentumOptimizer( - learning_rate, momentum=config.momentum) - - if FLAGS.use_tpu: - optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) - - # Define gradients - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, vars_), global_step=global_step) - - names = [v.name for v in model.variables] - tf.logging.warn("{}".format(names)) - - return tf.contrib.tpu.TPUEstimatorSpec( - mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op) - - if mode == tf.estimator.ModeKeys.EVAL: - logits, _ = model(inputs, training=False) - loss = model.compute_loss(labels=labels, logits=logits) - - def metric_fn(labels, logits): - predictions = tf.argmax(logits, axis=1) - accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions) - return { - "accuracy": accuracy, - } - - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits])) - - if mode == tf.estimator.ModeKeys.PREDICT: - logits, _ = model(inputs, training=False) - predictions = { - "classes": tf.argmax(logits, axis=1), - "probabilities": tf.nn.softmax(logits), - } - - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, - predictions=predictions, - export_outputs={ - "classify": tf.estimator.export.PredictOutput(predictions) - }) - - -def get_input_fn(config, data_dir, split): - """Get the input function required by the `tf.contrib.tpu.TPUEstimator` API. - - Args: - config: Customized hyperparameters - data_dir: Directory where the data is stored - split: One of `train`, `validation`, `train_all`, and `test` - - Returns: - Input function required by the `tf.contrib.tpu.TPUEstimator` API - """ - - data_dir = os.path.join(data_dir, config.dataset) - # Fix split-dependent hyperparameters - if split == "train_all" or split == "train": - data_aug = True - epochs = config.tpu_epochs - shuffle = True - else: - data_aug = False - epochs = 1 - shuffle = False - - def input_fn(params): - """Input function required by the `tf.contrib.tpu.TPUEstimator` API.""" - batch_size = params["batch_size"] - return cifar_input.get_ds_from_tfrecords( - data_dir=data_dir, - split=split, - data_aug=data_aug, - batch_size=batch_size, # per-shard batch size - epochs=epochs, - shuffle=shuffle, - prefetch=batch_size, # per-shard batch size - data_format=config.data_format) - - return input_fn - - -def main(argv): - FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name - tf.logging.set_verbosity(tf.logging.INFO) - - # RevNet specific configuration - config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) - - if FLAGS.use_tpu: - tf.logging.info("Using TPU.") - tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( - FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - else: - tpu_cluster_resolver = None - - # TPU specific configuration - tpu_config = tf.contrib.tpu.TPUConfig( - # Recommended to be set as number of global steps for next checkpoint - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_shards) - - # Estimator specific configuration - run_config = tf.contrib.tpu.RunConfig( - cluster=tpu_cluster_resolver, - model_dir=FLAGS.model_dir, - session_config=tf.ConfigProto( - allow_soft_placement=True, log_device_placement=False), - tpu_config=tpu_config, - ) - - # Construct TPU Estimator - estimator = tf.contrib.tpu.TPUEstimator( - model_fn=model_fn, - use_tpu=FLAGS.use_tpu, - train_batch_size=config.tpu_batch_size, - eval_batch_size=config.tpu_eval_batch_size, - config=run_config, - params={ - "FLAGS": FLAGS, - "config": config, - }) - - # Construct input functions - train_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="train_all") - eval_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="test") - - # Disabling a range within an else block currently doesn't work - # due to https://github.com/PyCQA/pylint/issues/872 - # pylint: disable=protected-access - if FLAGS.mode == "eval": - # TPUEstimator.evaluate *requires* a steps argument. - # Note that the number of examples used during evaluation is - # --eval_steps * --batch_size. - # So if you change --batch_size then change --eval_steps too. - eval_steps = 10000 // config.tpu_eval_batch_size - - # Run evaluation when there's a new checkpoint - for ckpt in evaluation.checkpoints_iterator( - FLAGS.model_dir, timeout=FLAGS.eval_timeout): - tf.logging.info("Starting to evaluate.") - try: - start_timestamp = time.time() # This time will include compilation time - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt) - elapsed_time = int(time.time() - start_timestamp) - tf.logging.info("Eval results: %s. Elapsed seconds: %d" % - (eval_results, elapsed_time)) - - # Terminate eval job when final checkpoint is reached - current_step = int(os.path.basename(ckpt).split("-")[1]) - if current_step >= config.max_train_iter: - tf.logging.info( - "Evaluation finished after training step %d" % current_step) - break - - except tf.errors.NotFoundError: - # Since the coordinator is on a different job than the TPU worker, - # sometimes the TPU worker does not finish initializing until long after - # the CPU job tells it to start evaluating. In this case, the checkpoint - # file could have been deleted already. - tf.logging.info( - "Checkpoint %s no longer exists, skipping checkpoint" % ckpt) - - else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval' - current_step = estimator_._load_global_step_from_checkpoint_dir( - FLAGS.model_dir) - tf.logging.info("Training for %d steps . Current" - " step %d." % (config.max_train_iter, current_step)) - - start_timestamp = time.time() # This time will include compilation time - if FLAGS.mode == "train": - estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter) - else: - eval_steps = 10000 // config.tpu_eval_batch_size - assert FLAGS.mode == "train_and_eval" - while current_step < config.max_train_iter: - # Train for up to steps_per_eval number of steps. - # At the end of training, a checkpoint will be written to --model_dir. - next_checkpoint = min(current_step + FLAGS.steps_per_eval, - config.max_train_iter) - estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint) - current_step = next_checkpoint - - # Evaluate the model on the most recent model in --model_dir. - # Since evaluation happens in batches of --eval_batch_size, some images - # may be consistently excluded modulo the batch size. - tf.logging.info("Starting to evaluate.") - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps) - tf.logging.info("Eval results: %s" % eval_results) - - elapsed_time = int(time.time() - start_timestamp) - tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % - (config.max_train_iter, elapsed_time)) - # pylint: enable=protected-access - - -if __name__ == "__main__": - # Cloud TPU Cluster Resolver flags - flags.DEFINE_string( - "tpu", - default=None, - help="The Cloud TPU to use for training. This should be either the name " - "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " - "url.") - flags.DEFINE_string( - "tpu_zone", - default=None, - help="[Optional] GCE zone where the Cloud TPU is located in. If not " - "specified, we will attempt to automatically detect the GCE project from " - "metadata.") - flags.DEFINE_string( - "gcp_project", - default=None, - help="[Optional] Project name for the Cloud TPU-enabled project. If not " - "specified, we will attempt to automatically detect the GCE project from " - "metadata.") - - # Model specific parameters - flags.DEFINE_string( - "data_dir", default=None, help="Directory to load tfrecords") - flags.DEFINE_string( - "model_dir", - default=None, - help="[Optional] Directory to store the model information") - flags.DEFINE_string( - "dataset", - default="cifar-10", - help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") - flags.DEFINE_string( - "config", - default="revnet-38", - help="[Optional] Architecture of network. " - "Other options include `revnet-110` and `revnet-164`") - flags.DEFINE_boolean( - "use_tpu", default=True, help="[Optional] Whether to use TPU") - flags.DEFINE_integer( - "num_shards", default=8, help="Number of shards (TPU chips).") - flags.DEFINE_integer( - "iterations_per_loop", - default=100, - help=( - "Number of steps to run on TPU before feeding metrics to the CPU." - " If the number of iterations in the loop would exceed the number of" - " train steps, the loop will exit before reaching" - " --iterations_per_loop. The larger this value is, the higher the" - " utilization on the TPU.")) - flags.DEFINE_string( - "mode", - default="train_and_eval", - help="[Optional] Mode to run: train, eval, train_and_eval") - flags.DEFINE_integer( - "eval_timeout", 60 * 60 * 24, - "Maximum seconds between checkpoints before evaluation terminates.") - flags.DEFINE_integer( - "steps_per_eval", - default=1000, - help=( - "Controls how often evaluation is performed. Since evaluation is" - " fairly expensive, it is advised to evaluate as infrequently as" - " possible (i.e. up to --train_steps, which evaluates the model only" - " after finishing the entire training regime).")) - FLAGS = flags.FLAGS - tf.app.run(main=main, argv=[FLAGS]) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index a3c2f7dbec..af0d20fa72 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -24,6 +24,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import operator + import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import blocks @@ -42,9 +45,71 @@ class RevNet(tf.keras.Model): self.axis = 1 if config.data_format == "channels_first" else 3 self.config = config - self._init_block = blocks.InitBlock(config=self.config) - self._final_block = blocks.FinalBlock(config=self.config) + self._init_block = self._construct_init_block() self._block_list = self._construct_intermediate_blocks() + self._final_block = self._construct_final_block() + + def _construct_init_block(self): + init_block = tf.keras.Sequential( + [ + tf.keras.layers.Conv2D( + filters=self.config.init_filters, + kernel_size=self.config.init_kernel, + strides=(self.config.init_stride, self.config.init_stride), + data_format=self.config.data_format, + use_bias=False, + padding="SAME", + input_shape=self.config.input_shape, + dtype=self.config.dtype), + tf.keras.layers.BatchNormalization( + axis=self.axis, + fused=self.config.fused, + dtype=self.config.dtype), + tf.keras.layers.Activation("relu"), + ], + name="init") + if self.config.init_max_pool: + init_block.add( + tf.keras.layers.MaxPooling2D( + pool_size=(3, 3), + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format, + dtype=self.config.dtype)) + return init_block + + def _construct_final_block(self): + f = self.config.filters[-1] # Number of filters + r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio + r *= self.config.init_stride + if self.config.init_max_pool: + r *= 2 + + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (f, w // r, h // r) + elif self.config.data_format == "channels_last": + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // r, h // r, f) + else: + raise ValueError("Data format should be either `channels_first`" + " or `channels_last`") + + final_block = tf.keras.Sequential( + [ + tf.keras.layers.BatchNormalization( + axis=self.axis, + input_shape=input_shape, + fused=self.config.fused, + dtype=self.config.dtype), + tf.keras.layers.Activation("relu"), + tf.keras.layers.GlobalAveragePooling2D( + data_format=self.config.data_format, dtype=self.config.dtype), + tf.keras.layers.Dense( + self.config.n_classes, dtype=self.config.dtype) + ], + name="final") + return final_block def _construct_intermediate_blocks(self): # Precompute input shape after initial block @@ -141,20 +206,13 @@ class RevNet(tf.keras.Model): l2_reg: Apply l2 regularization Returns: - A tuple with the first entry being a list of all gradients, the second - entry being a list of respective variables, the third being the logits, - and the forth being the loss + list of tuples each being (grad, var) for optimizer to use """ - # Run forward pass to record hidden states + # Run forward pass to record hidden states; avoid updating running averages vars_and_vals = self.get_moving_stats() _, saved_hidden = self.call(inputs, training=training) - if tf.executing_eagerly(): - # Restore moving averages when executing eagerly to avoid updating twice - self.restore_moving_stats(vars_and_vals) - else: - # Fetch batch norm updates in graph mode - updates = self.get_updates_for(inputs) + self.restore_moving_stats(vars_and_vals) grads_all = [] vars_all = [] @@ -162,8 +220,9 @@ class RevNet(tf.keras.Model): # Manually backprop through last block x = saved_hidden[-1] with tf.GradientTape() as tape: + x = tf.identity(x) tape.watch(x) - # Running stats updated here + # Running stats updated below logits = self._final_block(x, training=training) loss = self.compute_loss(logits, labels) @@ -177,7 +236,6 @@ class RevNet(tf.keras.Model): for block in reversed(self._block_list): y = saved_hidden.pop() x = saved_hidden[-1] - # Running stats updated here dy, grads, vars_ = block.backward_grads_and_vars( x, y, dy, training=training) grads_all += grads @@ -189,7 +247,8 @@ class RevNet(tf.keras.Model): assert not saved_hidden # Cleared after backprop with tf.GradientTape() as tape: - # Running stats updated here + x = tf.identity(x) + # Running stats updated below y = self._init_block(x, training=training) grads_all += tape.gradient( @@ -200,13 +259,7 @@ class RevNet(tf.keras.Model): if l2_reg: grads_all = self._apply_weight_decay(grads_all, vars_all) - if not tf.executing_eagerly(): - # Force updates to be executed before gradient computation in graph mode - # This does nothing when the function is wrapped in defun - with tf.control_dependencies(updates): - grads_all[0] = tf.identity(grads_all[0]) - - return grads_all, vars_all, logits, loss + return grads_all, vars_all, loss def _apply_weight_decay(self, grads, vars_): """Update gradients to reflect weight decay.""" @@ -231,10 +284,8 @@ class RevNet(tf.keras.Model): n = v.name return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") - device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" - with tf.device(device): - for v in filter(_is_moving_var, self.variables): - vars_and_vals[v] = v.read_value() + for v in filter(_is_moving_var, self.variables): + vars_and_vals[v] = v.read_value() return vars_and_vals @@ -246,8 +297,5 @@ class RevNet(tf.keras.Model): Args: vars_and_vals: The dictionary mapping variables to their previous values. """ - device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" - with tf.device(device): - for var_, val in six.iteritems(vars_and_vals): - # `assign` causes a copy to GPU (if variable is already on GPU) - var_.assign(val) + for var_, val in six.iteritems(vars_and_vals): + var_.assign(val) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 26b0847523..b0d0a5486d 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -31,11 +31,10 @@ tfe = tf.contrib.eager def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) + grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) - return logits, loss + return loss class RevNetTest(tf.test.TestCase): @@ -43,8 +42,6 @@ class RevNetTest(tf.test.TestCase): def setUp(self): super(RevNetTest, self).setUp() config = config_.get_hparams_cifar_38() - config.add_hparam("n_classes", 10) - config.add_hparam("dataset", "cifar-10") # Reconstruction could cause numerical error, use double precision for tests config.dtype = tf.float64 config.fused = False # Fused batch norm does not support tf.float64 @@ -97,7 +94,7 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients(self): """Test `compute_gradients` function.""" self.model(self.x, training=False) # Initialize model - grads, vars_, logits, loss = self.model.compute_gradients( + grads, vars_, loss = self.model.compute_gradients( inputs=self.x, labels=self.t, training=True, l2_reg=True) self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) @@ -122,7 +119,7 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" compute_gradients = tfe.defun(self.model.compute_gradients) - grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True) + grads, vars_, _ = compute_gradients(self.x, self.t, training=True) self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -134,9 +131,6 @@ class RevNetTest(tf.test.TestCase): """Test model training in graph mode.""" with tf.Graph().as_default(): config = config_.get_hparams_cifar_38() - config.add_hparam("n_classes", 10) - config.add_hparam("dataset", "cifar-10") - x = tf.random_normal( shape=(self.config.batch_size,) + self.config.input_shape) t = tf.random_uniform( @@ -146,10 +140,15 @@ class RevNetTest(tf.test.TestCase): dtype=tf.int32) global_step = tf.Variable(0., trainable=False) model = revnet.RevNet(config=config) - grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True) + model(x) + updates = model.get_updates_for(x) + + x_ = tf.identity(x) + grads_all, vars_all, _ = model.compute_gradients(x_, t, training=True) optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) - train_op = optimizer.apply_gradients( - zip(grads_all, vars_all), global_step=global_step) + with tf.control_dependencies(updates): + train_op = optimizer.apply_gradients( + zip(grads_all, vars_all), global_step=global_step) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) |