author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-20 16:39:46 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-07-20 16:46:25 -0700
commit     8741006018326350467fe86785d98963ff9e983e (patch)
tree       b8ff67c88a10fe80cf32e1fa3fcd2086961d7ab5
parent     0cc0166a97f95499f0af673f3004d6bb748dc7e4 (diff)
Automated rollback of commit 265292420de30f24805d28886d403dc42d3685b3
PiperOrigin-RevId: 205472990
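
The bulk of this rollback converts the `tf.keras.Model` subclasses in blocks.py back into plain functions that assemble and return a `tf.keras.Sequential` (see the blocks.py hunks below). A minimal sketch of the restored pattern, with a hypothetical `_toy_inner` standing in for the real `F`/`G` builders:

    import tensorflow as tf

    def _toy_inner(filters, input_shape, dtype=tf.float32):
      """Build an F/G-style inner function as a Sequential, not a Model subclass."""
      model = tf.keras.Sequential()
      model.add(tf.keras.layers.BatchNormalization(input_shape=input_shape, dtype=dtype))
      model.add(tf.keras.layers.Activation("relu"))
      model.add(tf.keras.layers.Conv2D(
          filters=filters, kernel_size=3, padding="SAME", use_bias=False, dtype=dtype))
      return model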
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/BUILD                  |  36
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/blocks.py              | 374
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/cifar_input.py         |   2
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/config.py              |  16
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/main.py                |  82
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/main_estimator.py      | 200
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py  | 328
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet.py              | 110
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet_test.py         |  25
9 files changed, 268 insertions(+), 905 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD
index 3316dc1114..0c0e4c0eb9 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/BUILD
+++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD
@@ -113,39 +113,3 @@ py_binary(
"//tensorflow:tensorflow_py",
],
)
-
-py_binary(
- name = "main_estimator",
- srcs = ["main_estimator.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":cifar_input",
- ":main",
- ":revnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "main_estimator_lib",
- srcs = ["main_estimator.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":cifar_input",
- ":main",
- ":revnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "main_estimator_tpu_lib",
- srcs = ["main_estimator_tpu.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":cifar_input",
- ":main",
- ":revnet",
- "//tensorflow:tensorflow_py",
- ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index 639bb06a34..306096e9f8 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -24,9 +24,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
-import operator
-
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import ops
@@ -48,7 +45,7 @@ class RevBlock(tf.keras.Model):
bottleneck=False,
fused=True,
dtype=tf.float32):
- """Initialization.
+ """Initialize RevBlock.
Args:
n_res: number of residual blocks
@@ -102,6 +99,7 @@ class RevBlock(tf.keras.Model):
if i == 0:
# First block usually contains downsampling that can't be reversed
with tf.GradientTape() as tape:
+ x = tf.identity(x)
tape.watch(x)
y = block(x, training=training)
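
The `tf.identity(x)` inserted before `tape.watch(x)` here (and in the matching hunks below) gives the tape its own copy of the input tensor, so the watched endpoint is distinct from tensors consumed elsewhere. A standalone sketch of the idiom, assuming eager execution and arbitrary values:

    import tensorflow as tf
    tf.enable_eager_execution()

    x = tf.constant([[1.0, 2.0]])
    with tf.GradientTape() as tape:
      x = tf.identity(x)  # fresh tensor for the tape to treat as the endpoint
      tape.watch(x)       # x is not a variable, so it must be watched explicitly
      y = tf.reduce_sum(tf.square(x))
    print(tape.gradient(y, x))  # 2 * x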
@@ -123,6 +121,16 @@ class _Residual(tf.keras.Model):
"""Single residual block contained in a _RevBlock. Each `_Residual` object has
two _ResidualInner objects, corresponding to the `F` and `G` functions in the
paper.
+
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC",
+ bottleneck: use bottleneck residual if True
+ fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
"""
def __init__(self,
@@ -134,18 +142,6 @@ class _Residual(tf.keras.Model):
bottleneck=False,
fused=True,
dtype=tf.float32):
- """Initialization.
-
- Args:
- filters: output filter size
- strides: length 2 list/tuple of integers for height and width strides
- input_shape: length 3 list/tuple of integers
- batch_norm_first: whether to apply activation and batch norm before conv
- data_format: tensor data format, "NCHW"/"NHWC",
- bottleneck: use bottleneck residual if True
- fused: use fused batch normalization if True
- dtype: float16, float32, or float64
- """
super(_Residual, self).__init__()
self.filters = filters
@@ -200,6 +196,7 @@ class _Residual(tf.keras.Model):
dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
with tf.GradientTape(persistent=True) as tape:
+ y = tf.identity(y)
tape.watch(y)
y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis)
z1 = y1
@@ -230,252 +227,131 @@ class _Residual(tf.keras.Model):
return x, dx, grads, vars_
-# Ideally, the following should be wrapped in `tf.keras.Sequential`, however
-# there are subtle issues with its placeholder insertion policy and batch norm
-class _BottleneckResidualInner(tf.keras.Model):
+def _BottleneckResidualInner(filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ fused=True,
+ dtype=tf.float32):
"""Single bottleneck residual inner function contained in _Resdual.
Corresponds to the `F`/`G` functions in the paper.
Suitable for training on ImageNet dataset.
- """
-
- def __init__(self,
- filters,
- strides,
- input_shape,
- batch_norm_first=True,
- data_format="channels_first",
- fused=True,
- dtype=tf.float32):
- """Initialization.
-
- Args:
- filters: output filter size
- strides: length 2 list/tuple of integers for height and width strides
- input_shape: length 3 list/tuple of integers
- batch_norm_first: whether to apply activation and batch norm before conv
- data_format: tensor data format, "NCHW"/"NHWC"
- fused: use fused batch normalization if True
- dtype: float16, float32, or float64
- """
- super(_BottleneckResidualInner, self).__init__()
- axis = 1 if data_format == "channels_first" else 3
- if batch_norm_first:
- self.batch_norm_0 = tf.keras.layers.BatchNormalization(
- axis=axis, input_shape=input_shape, fused=fused, dtype=dtype)
-
- self.conv2d_1 = tf.keras.layers.Conv2D(
- filters=filters // 4,
- kernel_size=1,
- strides=strides,
- input_shape=input_shape,
- data_format=data_format,
- use_bias=False,
- padding="SAME",
- dtype=dtype)
- self.batch_norm_1 = tf.keras.layers.BatchNormalization(
- axis=axis, fused=fused, dtype=dtype)
-
- self.conv2d_2 = tf.keras.layers.Conv2D(
- filters=filters // 4,
- kernel_size=3,
- strides=(1, 1),
- data_format=data_format,
- use_bias=False,
- padding="SAME",
- dtype=dtype)
- self.batch_norm_2 = tf.keras.layers.BatchNormalization(
- axis=axis, fused=fused, dtype=dtype)
- self.conv2d_3 = tf.keras.layers.Conv2D(
- filters=filters,
- kernel_size=1,
- strides=(1, 1),
- data_format=data_format,
- use_bias=False,
- padding="SAME",
- dtype=dtype)
-
- self.batch_norm_first = batch_norm_first
-
- def call(self, x, training=True):
- net = x
- if self.batch_norm_first:
- net = self.batch_norm_0(net, training=training)
- net = tf.nn.relu(net)
-
- net = self.conv2d_1(net)
- net = self.batch_norm_1(net, training=training)
- net = tf.nn.relu(net)
-
- net = self.conv2d_2(net)
- net = self.batch_norm_2(net, training=training)
- net = tf.nn.relu(net)
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
+
+ Returns:
+ A keras model
+ """
- net = self.conv2d_3(net)
+ axis = 1 if data_format == "channels_first" else 3
+ model = tf.keras.Sequential()
+ if batch_norm_first:
+ model.add(
+ tf.keras.layers.BatchNormalization(
+ axis=axis, input_shape=input_shape, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters // 4,
+ kernel_size=1,
+ strides=strides,
+ input_shape=input_shape,
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
+
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters // 4,
+ kernel_size=3,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
+
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
- return net
+ return model
-class _ResidualInner(tf.keras.Model):
+def _ResidualInner(filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ fused=True,
+ dtype=tf.float32):
"""Single residual inner function contained in _ResdualBlock.
Corresponds to the `F`/`G` functions in the paper.
- """
-
- def __init__(self,
- filters,
- strides,
- input_shape,
- batch_norm_first=True,
- data_format="channels_first",
- fused=True,
- dtype=tf.float32):
- """Initialization.
-
- Args:
- filters: output filter size
- strides: length 2 list/tuple of integers for height and width strides
- input_shape: length 3 list/tuple of integers
- batch_norm_first: whether to apply activation and batch norm before conv
- data_format: tensor data format, "NCHW"/"NHWC"
- fused: use fused batch normalization if True
- dtype: float16, float32, or float64
- """
- super(_ResidualInner, self).__init__()
- axis = 1 if data_format == "channels_first" else 3
- if batch_norm_first:
- self.batch_norm_0 = tf.keras.layers.BatchNormalization(
- axis=axis, input_shape=input_shape, fused=fused, dtype=dtype)
- self.conv2d_1 = tf.keras.layers.Conv2D(
- filters=filters,
- kernel_size=3,
- strides=strides,
- input_shape=input_shape,
- data_format=data_format,
- use_bias=False,
- padding="SAME",
- dtype=dtype)
- self.batch_norm_1 = tf.keras.layers.BatchNormalization(
- axis=axis, fused=fused, dtype=dtype)
-
- self.conv2d_2 = tf.keras.layers.Conv2D(
- filters=filters,
- kernel_size=3,
- strides=(1, 1),
- data_format=data_format,
- use_bias=False,
- padding="SAME",
- dtype=dtype)
-
- self.batch_norm_first = batch_norm_first
-
- def call(self, x, training=True):
- net = x
- if self.batch_norm_first:
- net = self.batch_norm_0(net, training=training)
- net = tf.nn.relu(net)
-
- net = self.conv2d_1(net)
- net = self.batch_norm_1(net, training=training)
-
- net = self.conv2d_2(net)
-
- return net
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
+
+ Returns:
+ A keras model
+ """
-class InitBlock(tf.keras.Model):
- """Initial block of RevNet."""
-
- def __init__(self, config):
- """Initialization.
-
- Args:
- config: tf.contrib.training.HParams object; specifies hyperparameters
- """
- super(InitBlock, self).__init__()
- self.config = config
- self.axis = 1 if self.config.data_format == "channels_first" else 3
- self.conv2d = tf.keras.layers.Conv2D(
- filters=self.config.init_filters,
- kernel_size=self.config.init_kernel,
- strides=(self.config.init_stride, self.config.init_stride),
- data_format=self.config.data_format,
- use_bias=False,
- padding="SAME",
- input_shape=self.config.input_shape,
- dtype=self.config.dtype)
- self.batch_norm = tf.keras.layers.BatchNormalization(
- axis=self.axis, fused=self.config.fused, dtype=self.config.dtype)
- self.activation = tf.keras.layers.Activation("relu")
-
- if self.config.init_max_pool:
- self.max_pool = tf.keras.layers.MaxPooling2D(
- pool_size=(3, 3),
- strides=(2, 2),
+ axis = 1 if data_format == "channels_first" else 3
+ model = tf.keras.Sequential()
+ if batch_norm_first:
+ model.add(
+ tf.keras.layers.BatchNormalization(
+ axis=axis, input_shape=input_shape, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=3,
+ strides=strides,
+ input_shape=input_shape,
+ data_format=data_format,
+ use_bias=False,
padding="SAME",
- data_format=self.config.data_format,
- dtype=self.config.dtype)
-
- def call(self, x, training=True):
- net = x
- net = self.conv2d(net)
- net = self.batch_norm(net, training=training)
- net = self.activation(net)
-
- if self.config.init_max_pool:
- net = self.max_pool(net)
-
- return net
-
-
-class FinalBlock(tf.keras.Model):
- """Final block of RevNet."""
-
- def __init__(self, config):
- """Initialization.
-
- Args:
- config: tf.contrib.training.HParams object; specifies hyperparameters
+ dtype=dtype))
+
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=3,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
- Raises:
- ValueError: Unsupported data format
- """
- super(FinalBlock, self).__init__()
- self.config = config
- self.axis = 1 if self.config.data_format == "channels_first" else 3
-
- f = self.config.filters[-1] # Number of filters
- r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio
- r *= self.config.init_stride
- if self.config.init_max_pool:
- r *= 2
-
- if self.config.data_format == "channels_first":
- w, h = self.config.input_shape[1], self.config.input_shape[2]
- input_shape = (f, w // r, h // r)
- elif self.config.data_format == "channels_last":
- w, h = self.config.input_shape[0], self.config.input_shape[1]
- input_shape = (w // r, h // r, f)
- else:
- raise ValueError("Data format should be either `channels_first`"
- " or `channels_last`")
- self.batch_norm = tf.keras.layers.BatchNormalization(
- axis=self.axis,
- input_shape=input_shape,
- fused=self.config.fused,
- dtype=self.config.dtype)
- self.activation = tf.keras.layers.Activation("relu")
- self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D(
- data_format=self.config.data_format, dtype=self.config.dtype)
- self.dense = tf.keras.layers.Dense(
- self.config.n_classes, dtype=self.config.dtype)
-
- def call(self, x, training=True):
- net = x
- net = self.batch_norm(net, training=training)
- net = self.activation(net)
- net = self.global_avg_pool(net)
- net = self.dense(net)
-
- return net
+ return model
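
With this change both inner builders return a Keras model instead of being one, so call sites construct them once and then invoke the result like a layer. A hedged usage sketch (shapes and arguments chosen arbitrarily, eager execution assumed):

    f = _ResidualInner(
        filters=16, strides=(1, 1), input_shape=(8, 8, 3),
        data_format="channels_last", fused=False)
    out = f(tf.zeros([2, 8, 8, 3]), training=True)  # Sequential models are callable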
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
index e9672f13e1..b6d4c35bfd 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
@@ -111,6 +111,6 @@ def get_ds_from_tfrecords(data_dir,
}[split]
dataset = dataset.shuffle(size)
- dataset = dataset.batch(batch_size, drop_remainder=True)
+ dataset = dataset.batch(batch_size)
return dataset
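
Removing `drop_remainder=True` keeps the final partial batch, so evaluation sees every example, at the cost of a non-static batch dimension. A small sketch of the difference, under eager execution:

    ds = tf.data.Dataset.range(10)
    [int(b.shape[0]) for b in ds.batch(3)]                       # 3, 3, 3, 1
    [int(b.shape[0]) for b in ds.batch(3, drop_remainder=True)]  # 3, 3, 3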
diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py
index 1532c7b67b..3d93fa955a 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/config.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/config.py
@@ -27,16 +27,17 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
+tfe = tf.contrib.eager
def get_hparams_cifar_38():
"""RevNet-38 configurations for CIFAR-10/CIFAR-100."""
config = tf.contrib.training.HParams()
- # Hyperparameters from the RevNet paper
config.add_hparam("init_filters", 32)
config.add_hparam("init_kernel", 3)
config.add_hparam("init_stride", 1)
+ config.add_hparam("n_classes", 10)
config.add_hparam("n_rev_blocks", 3)
config.add_hparam("n_res", [3, 3, 3])
config.add_hparam("filters", [32, 64, 112])
@@ -45,7 +46,7 @@ def get_hparams_cifar_38():
config.add_hparam("bottleneck", False)
config.add_hparam("fused", True)
config.add_hparam("init_max_pool", False)
- if tf.test.is_gpu_available() > 0:
+ if tfe.num_gpus() > 0:
config.add_hparam("input_shape", (3, 32, 32))
config.add_hparam("data_format", "channels_first")
else:
@@ -70,16 +71,6 @@ def get_hparams_cifar_38():
config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)
- # Customized TPU hyperparameters due to differing batch size caused by
- # TPU architecture specifics
- # Suggested batch sizes to reduce overhead from excessive tensor padding
- # https://cloud.google.com/tpu/docs/troubleshooting
- config.add_hparam("tpu_batch_size", 128)
- config.add_hparam("tpu_eval_batch_size", 1024)
- config.add_hparam("tpu_iters_per_epoch", 50000 // config.tpu_batch_size)
- config.add_hparam("tpu_epochs",
- config.max_train_iter // config.tpu_iters_per_epoch)
-
return config
@@ -110,6 +101,7 @@ def get_hparams_imagenet_56():
config.add_hparam("init_filters", 128)
config.add_hparam("init_kernel", 7)
config.add_hparam("init_stride", 2)
+ config.add_hparam("n_classes", 1000)
config.add_hparam("n_rev_blocks", 4)
config.add_hparam("n_res", [2, 2, 2, 2])
config.add_hparam("filters", [128, 256, 512, 832])
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index 1a4fd45c8b..e2f43b03f9 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -31,11 +31,8 @@ tfe = tf.contrib.eager
def main(_):
"""Eager execution workflow with RevNet trained on CIFAR-10."""
- tf.enable_eager_execution()
-
- config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
- ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(
- data_dir=FLAGS.data_dir, config=config)
+ config = get_config()
+ ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(config)
model = revnet.RevNet(config=config)
global_step = tf.train.get_or_create_global_step() # Ensure correct summary
global_step.assign(1)
@@ -55,17 +52,23 @@ def main(_):
"with global_step: {}".format(latest_path, global_step.numpy()))
sys.stdout.flush()
+ if FLAGS.manual_grad:
+ print("Using manual gradients.")
+ else:
+ print("Not using manual gradients.")
+ sys.stdout.flush()
+
for x, y in ds_train:
train_one_iter(model, x, y, optimizer, global_step=global_step)
if global_step.numpy() % config.log_every == 0:
+ it_train = ds_train_one_shot.make_one_shot_iterator()
it_test = ds_test.make_one_shot_iterator()
+ acc_train, loss_train = evaluate(model, it_train)
acc_test, loss_test = evaluate(model, it_test)
if FLAGS.validate:
- it_train = ds_train_one_shot.make_one_shot_iterator()
it_validation = ds_validation.make_one_shot_iterator()
- acc_train, loss_train = evaluate(model, it_train)
acc_validation, loss_validation = evaluate(model, it_validation)
print("Iter {}, "
"training set accuracy {:.4f}, loss {:.4f}; "
@@ -74,8 +77,11 @@ def main(_):
global_step.numpy(), acc_train, loss_train, acc_validation,
loss_validation, acc_test, loss_test))
else:
- print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format(
- global_step.numpy(), acc_test, loss_test))
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_test,
+ loss_test))
sys.stdout.flush()
if FLAGS.train_dir:
@@ -97,38 +103,34 @@ def main(_):
sys.stdout.flush()
-def get_config(config_name="revnet-38", dataset="cifar-10"):
+def get_config():
"""Return configuration."""
- print("Config: {}".format(config_name))
+ print("Config: {}".format(FLAGS.config))
sys.stdout.flush()
config = {
"revnet-38": config_.get_hparams_cifar_38(),
"revnet-110": config_.get_hparams_cifar_110(),
"revnet-164": config_.get_hparams_cifar_164(),
- }[config_name]
+ }[FLAGS.config]
- if dataset == "cifar-10":
- config.add_hparam("n_classes", 10)
- config.add_hparam("dataset", "cifar-10")
- else:
- config.add_hparam("n_classes", 100)
- config.add_hparam("dataset", "cifar-100")
+ if FLAGS.dataset == "cifar-100":
+ config.n_classes = 100
return config
-def get_datasets(data_dir, config):
+def get_datasets(config):
"""Return dataset."""
- if data_dir is None:
+ if FLAGS.data_dir is None:
raise ValueError("No supplied data directory")
- if not os.path.exists(data_dir):
- raise ValueError("Data directory {} does not exist".format(data_dir))
- if config.dataset not in ["cifar-10", "cifar-100"]:
- raise ValueError("Unknown dataset {}".format(config.dataset))
+ if not os.path.exists(FLAGS.data_dir):
+ raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))
+ if FLAGS.dataset not in ["cifar-10", "cifar-100"]:
+ raise ValueError("Unknown dataset {}".format(FLAGS.dataset))
- print("Training on {} dataset.".format(config.dataset))
+ print("Training on {} dataset.".format(FLAGS.dataset))
sys.stdout.flush()
- data_dir = os.path.join(data_dir, config.dataset)
+ data_dir = os.path.join(FLAGS.data_dir, FLAGS.dataset)
if FLAGS.validate:
# 40k Training set
ds_train = cifar_input.get_ds_from_tfrecords(
@@ -166,7 +168,7 @@ def get_datasets(data_dir, config):
prefetch=config.batch_size)
ds_validation = None
- # Always compute loss and accuracy on whole test set
+ # Always compute loss and accuracy on whole training and test set
ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
data_dir=data_dir,
split="train_all",
@@ -194,11 +196,19 @@ def get_datasets(data_dir, config):
def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ if FLAGS.manual_grad:
+ grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ else: # For correctness validation
+ with tf.GradientTape() as tape:
+ logits, _ = model(inputs, training=True)
+ loss = model.compute_loss(logits=logits, labels=labels)
+ tf.logging.info("Logits are placed on device: {}".format(logits.device))
+ grads = tape.gradient(loss, model.trainable_variables)
+ optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
- return logits, loss
+ return loss.numpy()
def evaluate(model, iterator):
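
The new `--manual_grad` flag selects between RevNet's memory-saving manual backprop and a plain `tf.GradientTape` pass kept for correctness validation. A self-contained sketch of the tape branch, with a stand-in dense model rather than the RevNet itself:

    import tensorflow as tf
    tf.enable_eager_execution()

    model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(8,))])
    optimizer = tf.train.MomentumOptimizer(1e-2, momentum=0.9)
    inputs = tf.random_normal([4, 8])
    labels = tf.random_uniform([4], maxval=10, dtype=tf.int32)

    with tf.GradientTape() as tape:
      logits = model(inputs)
      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))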
@@ -231,14 +241,16 @@ if __name__ == "__main__":
"validate",
default=False,
help="[Optional] Use the validation set or not for hyperparameter search")
+ flags.DEFINE_boolean(
+ "manual_grad",
+ default=False,
+ help="[Optional] Use manual gradient graph to save memory")
flags.DEFINE_string(
"dataset",
default="cifar-10",
help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
flags.DEFINE_string(
- "config",
- default="revnet-38",
- help="[Optional] Architecture of network. "
- "Other options include `revnet-110` and `revnet-164`")
+ "config", default="revnet-38", help="[Optional] Architecture of network.")
FLAGS = flags.FLAGS
+ tf.enable_eager_execution()
tf.app.run(main)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
deleted file mode 100644
index c875e8da6d..0000000000
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Estimator workflow with RevNet train on CIFAR-10."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from absl import flags
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.revnet import cifar_input
-from tensorflow.contrib.eager.python.examples.revnet import main as main_
-from tensorflow.contrib.eager.python.examples.revnet import revnet
-
-
-def model_fn(features, labels, mode, params):
- """Function specifying the model that is required by the `tf.estimator` API.
-
- Args:
- features: Input images
- labels: Labels of images
- mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT'
- params: A dictionary of extra parameter that might be passed
-
- Returns:
- An instance of `tf.estimator.EstimatorSpec`
- """
-
- inputs = features
- if isinstance(inputs, dict):
- inputs = features["image"]
-
- config = params["config"]
- model = revnet.RevNet(config=config)
-
- if mode == tf.estimator.ModeKeys.TRAIN:
- global_step = tf.train.get_or_create_global_step()
- learning_rate = tf.train.piecewise_constant(
- global_step, config.lr_decay_steps, config.lr_list)
- optimizer = tf.train.MomentumOptimizer(
- learning_rate, momentum=config.momentum)
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- train_op = optimizer.apply_gradients(
- zip(grads, vars_), global_step=global_step)
-
- return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
- else:
- logits, _ = model(inputs, training=False)
- predictions = tf.argmax(logits, axis=1)
- probabilities = tf.nn.softmax(logits)
- loss = model.compute_loss(labels=labels, logits=logits)
-
- if mode == tf.estimator.ModeKeys.EVAL:
- return tf.estimator.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops={
- "accuracy":
- tf.metrics.accuracy(labels=labels, predictions=predictions)
- })
-
- else: # mode == tf.estimator.ModeKeys.PREDICT
- result = {
- "classes": predictions,
- "probabilities": probabilities,
- }
-
- return tf.estimator.EstimatorSpec(
- mode=mode,
- predictions=predictions,
- export_outputs={
- "classify": tf.estimator.export.PredictOutput(result)
- })
-
-
-def get_input_fn(config, data_dir, split):
- """Get the input function that is required by the `tf.estimator` API.
-
- Args:
- config: Customized hyperparameters
- data_dir: Directory where the data is stored
- split: One of `train`, `validation`, `train_all`, and `test`
-
- Returns:
- Input function required by the `tf.estimator` API
- """
-
- data_dir = os.path.join(data_dir, config.dataset)
- # Fix split-dependent hyperparameters
- if split == "train_all" or split == "train":
- data_aug = True
- batch_size = config.batch_size
- epochs = config.epochs
- shuffle = True
- prefetch = config.batch_size
- else:
- data_aug = False
- batch_size = config.eval_batch_size
- epochs = 1
- shuffle = False
- prefetch = config.eval_batch_size
-
- def input_fn():
- """Input function required by the `tf.estimator.Estimator` API."""
- return cifar_input.get_ds_from_tfrecords(
- data_dir=data_dir,
- split=split,
- data_aug=data_aug,
- batch_size=batch_size,
- epochs=epochs,
- shuffle=shuffle,
- prefetch=prefetch,
- data_format=config.data_format)
-
- return input_fn
-
-
-def main(argv):
- FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name
- tf.logging.set_verbosity(tf.logging.INFO)
-
- # RevNet specific configuration
- config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
-
- # Estimator specific configuration
- run_config = tf.estimator.RunConfig(
- model_dir=FLAGS.train_dir, # Directory for storing checkpoints
- tf_random_seed=config.seed,
- save_summary_steps=config.log_every,
- save_checkpoints_steps=config.log_every,
- session_config=None, # Using default
- keep_checkpoint_max=100,
- keep_checkpoint_every_n_hours=10000, # Using default
- log_step_count_steps=config.log_every,
- train_distribute=None # Default not use distribution strategy
- )
-
- # Construct estimator
- revnet_estimator = tf.estimator.Estimator(
- model_fn=model_fn,
- model_dir=FLAGS.train_dir,
- config=run_config,
- params={"config": config})
-
- # Construct input functions
- train_input_fn = get_input_fn(
- config=config, data_dir=FLAGS.data_dir, split="train_all")
- eval_input_fn = get_input_fn(
- config=config, data_dir=FLAGS.data_dir, split="test")
-
- # Train and evaluate estimator
- revnet_estimator.train(input_fn=train_input_fn)
- revnet_estimator.evaluate(input_fn=eval_input_fn)
-
- if FLAGS.export:
- input_shape = (None,) + config.input_shape
- inputs = tf.placeholder(tf.float32, shape=input_shape)
- input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
- "image": inputs
- })
- revnet_estimator.export_savedmodel(FLAGS.train_dir, input_fn)
-
-
-if __name__ == "__main__":
- flags.DEFINE_string(
- "data_dir", default=None, help="Directory to load tfrecords")
- flags.DEFINE_string(
- "train_dir",
- default=None,
- help="[Optional] Directory to store the training information")
- flags.DEFINE_string(
- "dataset",
- default="cifar-10",
- help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
- flags.DEFINE_boolean(
- "export",
- default=False,
- help="[Optional] Export the model for serving if True")
- flags.DEFINE_string(
- "config",
- default="revnet-38",
- help="[Optional] Architecture of network. "
- "Other options include `revnet-110` and `revnet-164`")
- FLAGS = flags.FLAGS
- tf.app.run(main=main, argv=[FLAGS])
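
For reference, the deleted file wired RevNet into the `tf.estimator` API; the heart of that pattern is a `model_fn` that returns a per-mode `EstimatorSpec`. A pared-down sketch with a toy dense model in place of RevNet:

    def model_fn(features, labels, mode, params):
      logits = tf.layers.dense(features, params["n_classes"])
      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"classes": tf.argmax(logits, axis=1)}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer(1e-3).minimize(
            loss, global_step=tf.train.get_or_create_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
      return tf.estimator.EstimatorSpec(mode, loss=loss)  # EVAL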
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
deleted file mode 100644
index f1e1e530df..0000000000
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Cloud TPU Estimator workflow with RevNet train on CIFAR-10."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import time
-
-from absl import flags
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.revnet import cifar_input
-from tensorflow.contrib.eager.python.examples.revnet import main as main_
-from tensorflow.contrib.eager.python.examples.revnet import revnet
-from tensorflow.contrib.training.python.training import evaluation
-from tensorflow.python.estimator import estimator as estimator_
-
-
-def model_fn(features, labels, mode, params):
- """Model function required by the `tf.contrib.tpu.TPUEstimator` API.
-
- Args:
- features: Input images
- labels: Labels of images
- mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT'
- params: A dictionary of extra parameter that might be passed
-
- Returns:
- An instance of `tf.contrib.tpu.TPUEstimatorSpec`
- """
-
- inputs = features
- if isinstance(inputs, dict):
- inputs = features["image"]
-
- FLAGS = params["FLAGS"] # pylint:disable=invalid-name,redefined-outer-name
- config = params["config"]
- model = revnet.RevNet(config=config)
-
- if mode == tf.estimator.ModeKeys.TRAIN:
- global_step = tf.train.get_or_create_global_step()
- learning_rate = tf.train.piecewise_constant(
- global_step, config.lr_decay_steps, config.lr_list)
- optimizer = tf.train.MomentumOptimizer(
- learning_rate, momentum=config.momentum)
-
- if FLAGS.use_tpu:
- optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
-
- # Define gradients
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- train_op = optimizer.apply_gradients(
- zip(grads, vars_), global_step=global_step)
-
- names = [v.name for v in model.variables]
- tf.logging.warn("{}".format(names))
-
- return tf.contrib.tpu.TPUEstimatorSpec(
- mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)
-
- if mode == tf.estimator.ModeKeys.EVAL:
- logits, _ = model(inputs, training=False)
- loss = model.compute_loss(labels=labels, logits=logits)
-
- def metric_fn(labels, logits):
- predictions = tf.argmax(logits, axis=1)
- accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
- return {
- "accuracy": accuracy,
- }
-
- return tf.contrib.tpu.TPUEstimatorSpec(
- mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
-
- if mode == tf.estimator.ModeKeys.PREDICT:
- logits, _ = model(inputs, training=False)
- predictions = {
- "classes": tf.argmax(logits, axis=1),
- "probabilities": tf.nn.softmax(logits),
- }
-
- return tf.contrib.tpu.TPUEstimatorSpec(
- mode=mode,
- predictions=predictions,
- export_outputs={
- "classify": tf.estimator.export.PredictOutput(predictions)
- })
-
-
-def get_input_fn(config, data_dir, split):
- """Get the input function required by the `tf.contrib.tpu.TPUEstimator` API.
-
- Args:
- config: Customized hyperparameters
- data_dir: Directory where the data is stored
- split: One of `train`, `validation`, `train_all`, and `test`
-
- Returns:
- Input function required by the `tf.contrib.tpu.TPUEstimator` API
- """
-
- data_dir = os.path.join(data_dir, config.dataset)
- # Fix split-dependent hyperparameters
- if split == "train_all" or split == "train":
- data_aug = True
- epochs = config.tpu_epochs
- shuffle = True
- else:
- data_aug = False
- epochs = 1
- shuffle = False
-
- def input_fn(params):
- """Input function required by the `tf.contrib.tpu.TPUEstimator` API."""
- batch_size = params["batch_size"]
- return cifar_input.get_ds_from_tfrecords(
- data_dir=data_dir,
- split=split,
- data_aug=data_aug,
- batch_size=batch_size, # per-shard batch size
- epochs=epochs,
- shuffle=shuffle,
- prefetch=batch_size, # per-shard batch size
- data_format=config.data_format)
-
- return input_fn
-
-
-def main(argv):
- FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name
- tf.logging.set_verbosity(tf.logging.INFO)
-
- # RevNet specific configuration
- config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
-
- if FLAGS.use_tpu:
- tf.logging.info("Using TPU.")
- tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
- FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
- else:
- tpu_cluster_resolver = None
-
- # TPU specific configuration
- tpu_config = tf.contrib.tpu.TPUConfig(
- # Recommended to be set as number of global steps for next checkpoint
- iterations_per_loop=FLAGS.iterations_per_loop,
- num_shards=FLAGS.num_shards)
-
- # Estimator specific configuration
- run_config = tf.contrib.tpu.RunConfig(
- cluster=tpu_cluster_resolver,
- model_dir=FLAGS.model_dir,
- session_config=tf.ConfigProto(
- allow_soft_placement=True, log_device_placement=False),
- tpu_config=tpu_config,
- )
-
- # Construct TPU Estimator
- estimator = tf.contrib.tpu.TPUEstimator(
- model_fn=model_fn,
- use_tpu=FLAGS.use_tpu,
- train_batch_size=config.tpu_batch_size,
- eval_batch_size=config.tpu_eval_batch_size,
- config=run_config,
- params={
- "FLAGS": FLAGS,
- "config": config,
- })
-
- # Construct input functions
- train_input_fn = get_input_fn(
- config=config, data_dir=FLAGS.data_dir, split="train_all")
- eval_input_fn = get_input_fn(
- config=config, data_dir=FLAGS.data_dir, split="test")
-
- # Disabling a range within an else block currently doesn't work
- # due to https://github.com/PyCQA/pylint/issues/872
- # pylint: disable=protected-access
- if FLAGS.mode == "eval":
- # TPUEstimator.evaluate *requires* a steps argument.
- # Note that the number of examples used during evaluation is
- # --eval_steps * --batch_size.
- # So if you change --batch_size then change --eval_steps too.
- eval_steps = 10000 // config.tpu_eval_batch_size
-
- # Run evaluation when there's a new checkpoint
- for ckpt in evaluation.checkpoints_iterator(
- FLAGS.model_dir, timeout=FLAGS.eval_timeout):
- tf.logging.info("Starting to evaluate.")
- try:
- start_timestamp = time.time() # This time will include compilation time
- eval_results = estimator.evaluate(
- input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt)
- elapsed_time = int(time.time() - start_timestamp)
- tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
- (eval_results, elapsed_time))
-
- # Terminate eval job when final checkpoint is reached
- current_step = int(os.path.basename(ckpt).split("-")[1])
- if current_step >= config.max_train_iter:
- tf.logging.info(
- "Evaluation finished after training step %d" % current_step)
- break
-
- except tf.errors.NotFoundError:
- # Since the coordinator is on a different job than the TPU worker,
- # sometimes the TPU worker does not finish initializing until long after
- # the CPU job tells it to start evaluating. In this case, the checkpoint
- # file could have been deleted already.
- tf.logging.info(
- "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)
-
- else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
- current_step = estimator_._load_global_step_from_checkpoint_dir(
- FLAGS.model_dir)
- tf.logging.info("Training for %d steps . Current"
- " step %d." % (config.max_train_iter, current_step))
-
- start_timestamp = time.time() # This time will include compilation time
- if FLAGS.mode == "train":
- estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter)
- else:
- eval_steps = 10000 // config.tpu_eval_batch_size
- assert FLAGS.mode == "train_and_eval"
- while current_step < config.max_train_iter:
- # Train for up to steps_per_eval number of steps.
- # At the end of training, a checkpoint will be written to --model_dir.
- next_checkpoint = min(current_step + FLAGS.steps_per_eval,
- config.max_train_iter)
- estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
- current_step = next_checkpoint
-
- # Evaluate the model on the most recent model in --model_dir.
- # Since evaluation happens in batches of --eval_batch_size, some images
- # may be consistently excluded modulo the batch size.
- tf.logging.info("Starting to evaluate.")
- eval_results = estimator.evaluate(
- input_fn=eval_input_fn, steps=eval_steps)
- tf.logging.info("Eval results: %s" % eval_results)
-
- elapsed_time = int(time.time() - start_timestamp)
- tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
- (config.max_train_iter, elapsed_time))
- # pylint: enable=protected-access
-
-
-if __name__ == "__main__":
- # Cloud TPU Cluster Resolver flags
- flags.DEFINE_string(
- "tpu",
- default=None,
- help="The Cloud TPU to use for training. This should be either the name "
- "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
- "url.")
- flags.DEFINE_string(
- "tpu_zone",
- default=None,
- help="[Optional] GCE zone where the Cloud TPU is located in. If not "
- "specified, we will attempt to automatically detect the GCE project from "
- "metadata.")
- flags.DEFINE_string(
- "gcp_project",
- default=None,
- help="[Optional] Project name for the Cloud TPU-enabled project. If not "
- "specified, we will attempt to automatically detect the GCE project from "
- "metadata.")
-
- # Model specific parameters
- flags.DEFINE_string(
- "data_dir", default=None, help="Directory to load tfrecords")
- flags.DEFINE_string(
- "model_dir",
- default=None,
- help="[Optional] Directory to store the model information")
- flags.DEFINE_string(
- "dataset",
- default="cifar-10",
- help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
- flags.DEFINE_string(
- "config",
- default="revnet-38",
- help="[Optional] Architecture of network. "
- "Other options include `revnet-110` and `revnet-164`")
- flags.DEFINE_boolean(
- "use_tpu", default=True, help="[Optional] Whether to use TPU")
- flags.DEFINE_integer(
- "num_shards", default=8, help="Number of shards (TPU chips).")
- flags.DEFINE_integer(
- "iterations_per_loop",
- default=100,
- help=(
- "Number of steps to run on TPU before feeding metrics to the CPU."
- " If the number of iterations in the loop would exceed the number of"
- " train steps, the loop will exit before reaching"
- " --iterations_per_loop. The larger this value is, the higher the"
- " utilization on the TPU."))
- flags.DEFINE_string(
- "mode",
- default="train_and_eval",
- help="[Optional] Mode to run: train, eval, train_and_eval")
- flags.DEFINE_integer(
- "eval_timeout", 60 * 60 * 24,
- "Maximum seconds between checkpoints before evaluation terminates.")
- flags.DEFINE_integer(
- "steps_per_eval",
- default=1000,
- help=(
- "Controls how often evaluation is performed. Since evaluation is"
- " fairly expensive, it is advised to evaluate as infrequently as"
- " possible (i.e. up to --train_steps, which evaluates the model only"
- " after finishing the entire training regime)."))
- FLAGS = flags.FLAGS
- tf.app.run(main=main, argv=[FLAGS])
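
The TPU variant layered `tf.contrib.tpu.TPUEstimator` over the same model function; its distinguishing pieces were the cluster resolver, `TPUConfig`, and per-shard batch sizing. A skeletal sketch of that wiring (`my_model_fn` and the TPU name are placeholders):

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver("my-tpu")
    run_config = tf.contrib.tpu.RunConfig(
        cluster=resolver,
        model_dir="/tmp/revnet_model",
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100, num_shards=8))
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=my_model_fn, use_tpu=True,
        train_batch_size=128, eval_batch_size=1024, config=run_config)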
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index a3c2f7dbec..af0d20fa72 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -24,6 +24,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+import operator
+
import six
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import blocks
@@ -42,9 +45,71 @@ class RevNet(tf.keras.Model):
self.axis = 1 if config.data_format == "channels_first" else 3
self.config = config
- self._init_block = blocks.InitBlock(config=self.config)
- self._final_block = blocks.FinalBlock(config=self.config)
+ self._init_block = self._construct_init_block()
self._block_list = self._construct_intermediate_blocks()
+ self._final_block = self._construct_final_block()
+
+ def _construct_init_block(self):
+ init_block = tf.keras.Sequential(
+ [
+ tf.keras.layers.Conv2D(
+ filters=self.config.init_filters,
+ kernel_size=self.config.init_kernel,
+ strides=(self.config.init_stride, self.config.init_stride),
+ data_format=self.config.data_format,
+ use_bias=False,
+ padding="SAME",
+ input_shape=self.config.input_shape,
+ dtype=self.config.dtype),
+ tf.keras.layers.BatchNormalization(
+ axis=self.axis,
+ fused=self.config.fused,
+ dtype=self.config.dtype),
+ tf.keras.layers.Activation("relu"),
+ ],
+ name="init")
+ if self.config.init_max_pool:
+ init_block.add(
+ tf.keras.layers.MaxPooling2D(
+ pool_size=(3, 3),
+ strides=(2, 2),
+ padding="SAME",
+ data_format=self.config.data_format,
+ dtype=self.config.dtype))
+ return init_block
+
+ def _construct_final_block(self):
+ f = self.config.filters[-1] # Number of filters
+ r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio
+ r *= self.config.init_stride
+ if self.config.init_max_pool:
+ r *= 2
+
+ if self.config.data_format == "channels_first":
+ w, h = self.config.input_shape[1], self.config.input_shape[2]
+ input_shape = (f, w // r, h // r)
+ elif self.config.data_format == "channels_last":
+ w, h = self.config.input_shape[0], self.config.input_shape[1]
+ input_shape = (w // r, h // r, f)
+ else:
+ raise ValueError("Data format should be either `channels_first`"
+ " or `channels_last`")
+
+ final_block = tf.keras.Sequential(
+ [
+ tf.keras.layers.BatchNormalization(
+ axis=self.axis,
+ input_shape=input_shape,
+ fused=self.config.fused,
+ dtype=self.config.dtype),
+ tf.keras.layers.Activation("relu"),
+ tf.keras.layers.GlobalAveragePooling2D(
+ data_format=self.config.data_format, dtype=self.config.dtype),
+ tf.keras.layers.Dense(
+ self.config.n_classes, dtype=self.config.dtype)
+ ],
+ name="final")
+ return final_block
def _construct_intermediate_blocks(self):
# Precompute input shape after initial block
@@ -141,20 +206,13 @@ class RevNet(tf.keras.Model):
l2_reg: Apply l2 regularization
Returns:
- A tuple with the first entry being a list of all gradients, the second
- entry being a list of respective variables, the third being the logits,
- and the forth being the loss
+ list of tuples each being (grad, var) for optimizer to use
"""
- # Run forward pass to record hidden states
+ # Run forward pass to record hidden states; avoid updating running averages
vars_and_vals = self.get_moving_stats()
_, saved_hidden = self.call(inputs, training=training)
- if tf.executing_eagerly():
- # Restore moving averages when executing eagerly to avoid updating twice
- self.restore_moving_stats(vars_and_vals)
- else:
- # Fetch batch norm updates in graph mode
- updates = self.get_updates_for(inputs)
+ self.restore_moving_stats(vars_and_vals)
grads_all = []
vars_all = []
@@ -162,8 +220,9 @@ class RevNet(tf.keras.Model):
# Manually backprop through last block
x = saved_hidden[-1]
with tf.GradientTape() as tape:
+ x = tf.identity(x)
tape.watch(x)
- # Running stats updated here
+ # Running stats updated below
logits = self._final_block(x, training=training)
loss = self.compute_loss(logits, labels)
@@ -177,7 +236,6 @@ class RevNet(tf.keras.Model):
for block in reversed(self._block_list):
y = saved_hidden.pop()
x = saved_hidden[-1]
- # Running stats updated here
dy, grads, vars_ = block.backward_grads_and_vars(
x, y, dy, training=training)
grads_all += grads
@@ -189,7 +247,8 @@ class RevNet(tf.keras.Model):
assert not saved_hidden # Cleared after backprop
with tf.GradientTape() as tape:
- # Running stats updated here
+ x = tf.identity(x)
+ # Running stats updated below
y = self._init_block(x, training=training)
grads_all += tape.gradient(
@@ -200,13 +259,7 @@ class RevNet(tf.keras.Model):
if l2_reg:
grads_all = self._apply_weight_decay(grads_all, vars_all)
- if not tf.executing_eagerly():
- # Force updates to be executed before gradient computation in graph mode
- # This does nothing when the function is wrapped in defun
- with tf.control_dependencies(updates):
- grads_all[0] = tf.identity(grads_all[0])
-
- return grads_all, vars_all, logits, loss
+ return grads_all, vars_all, loss
def _apply_weight_decay(self, grads, vars_):
"""Update gradients to reflect weight decay."""
@@ -231,10 +284,8 @@ class RevNet(tf.keras.Model):
n = v.name
return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
- device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
- with tf.device(device):
- for v in filter(_is_moving_var, self.variables):
- vars_and_vals[v] = v.read_value()
+ for v in filter(_is_moving_var, self.variables):
+ vars_and_vals[v] = v.read_value()
return vars_and_vals
@@ -246,8 +297,5 @@ class RevNet(tf.keras.Model):
Args:
vars_and_vals: The dictionary mapping variables to their previous values.
"""
- device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
- with tf.device(device):
- for var_, val in six.iteritems(vars_and_vals):
- # `assign` causes a copy to GPU (if variable is already on GPU)
- var_.assign(val)
+ for var_, val in six.iteritems(vars_and_vals):
+ var_.assign(val)
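
`compute_gradients` now snapshots and restores the batch-norm moving statistics unconditionally, since its bookkeeping forward pass would otherwise update the running averages a second time. The idiom, sketched against any Keras model with BatchNormalization layers under eager execution:

    def _is_moving_var(v):
      return v.name.endswith("moving_mean:0") or v.name.endswith("moving_variance:0")

    snapshot = {v: v.read_value() for v in filter(_is_moving_var, model.variables)}
    model(x, training=True)            # bookkeeping pass updates the averages
    for var_, val in snapshot.items():
      var_.assign(val)                 # roll the averages back to the snapshot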
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 26b0847523..b0d0a5486d 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -31,11 +31,10 @@ tfe = tf.contrib.eager
def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
+ grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
- return logits, loss
+ return loss
class RevNetTest(tf.test.TestCase):
@@ -43,8 +42,6 @@ class RevNetTest(tf.test.TestCase):
def setUp(self):
super(RevNetTest, self).setUp()
config = config_.get_hparams_cifar_38()
- config.add_hparam("n_classes", 10)
- config.add_hparam("dataset", "cifar-10")
# Reconstruction could cause numerical error, use double precision for tests
config.dtype = tf.float64
config.fused = False # Fused batch norm does not support tf.float64
@@ -97,7 +94,7 @@ class RevNetTest(tf.test.TestCase):
def test_compute_gradients(self):
"""Test `compute_gradients` function."""
self.model(self.x, training=False) # Initialize model
- grads, vars_, logits, loss = self.model.compute_gradients(
+ grads, vars_, loss = self.model.compute_gradients(
inputs=self.x, labels=self.t, training=True, l2_reg=True)
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
@@ -122,7 +119,7 @@ class RevNetTest(tf.test.TestCase):
def test_compute_gradients_defun(self):
"""Test `compute_gradients` function with defun."""
compute_gradients = tfe.defun(self.model.compute_gradients)
- grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True)
+ grads, vars_, _ = compute_gradients(self.x, self.t, training=True)
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
@@ -134,9 +131,6 @@ class RevNetTest(tf.test.TestCase):
"""Test model training in graph mode."""
with tf.Graph().as_default():
config = config_.get_hparams_cifar_38()
- config.add_hparam("n_classes", 10)
- config.add_hparam("dataset", "cifar-10")
-
x = tf.random_normal(
shape=(self.config.batch_size,) + self.config.input_shape)
t = tf.random_uniform(
@@ -146,10 +140,15 @@ class RevNetTest(tf.test.TestCase):
dtype=tf.int32)
global_step = tf.Variable(0., trainable=False)
model = revnet.RevNet(config=config)
- grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True)
+ model(x)
+ updates = model.get_updates_for(x)
+
+ x_ = tf.identity(x)
+ grads_all, vars_all, _ = model.compute_gradients(x_, t, training=True)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
- train_op = optimizer.apply_gradients(
- zip(grads_all, vars_all), global_step=global_step)
+ with tf.control_dependencies(updates):
+ train_op = optimizer.apply_gradients(
+ zip(grads_all, vars_all), global_step=global_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())