# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reversible residual network compatible with eager execution.

Configurations are given in the format of tf.contrib.training.HParams.
Supports the CIFAR-10, CIFAR-100, and ImageNet datasets.

Reference: [The Reversible Residual Network: Backpropagation
Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def get_hparams_cifar_38():
  """RevNet-38 configurations for CIFAR-10/CIFAR-100."""

  config = tf.contrib.training.HParams()
  config.add_hparam("init_filters", 32)
  config.add_hparam("init_kernel", 3)
  config.add_hparam("init_stride", 1)
  config.add_hparam("n_classes", 10)
  config.add_hparam("n_rev_blocks", 3)
  config.add_hparam("n_res", [3, 3, 3])
  config.add_hparam("filters", [32, 64, 112])
  config.add_hparam("strides", [1, 2, 2])
  config.add_hparam("batch_size", 100)
  config.add_hparam("bottleneck", False)
  config.add_hparam("fused", True)
  config.add_hparam("init_max_pool", False)
  if tf.test.is_gpu_available():
    config.add_hparam("input_shape", (3, 32, 32))
    config.add_hparam("data_format", "channels_first")
  else:
    config.add_hparam("input_shape", (32, 32, 3))
    config.add_hparam("data_format", "channels_last")

  # Training details
  config.add_hparam("weight_decay", 2e-4)
  config.add_hparam("momentum", .9)
  config.add_hparam("lr_decay_steps", [40000, 60000])
  config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3])
  config.add_hparam("max_train_iter", 80000)
  config.add_hparam("seed", 1234)
  config.add_hparam("shuffle", True)
  config.add_hparam("log_every", 500)
  config.add_hparam("save_every", 500)
  config.add_hparam("dtype", tf.float32)
  config.add_hparam("eval_batch_size", 1000)
  config.add_hparam("div255", True)
  # TODO(lxuechen): This is imprecise; when training with a held-out
  # validation set, only 40k images remain in the training data.
  config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
  config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)

  return config


def get_hparams_imagenet_56():
  """RevNet-56 configurations for ImageNet."""

  config = tf.contrib.training.HParams()
  config.add_hparam("init_filters", 128)
  config.add_hparam("init_kernel", 7)
  config.add_hparam("init_stride", 2)
  config.add_hparam("n_classes", 1000)
  config.add_hparam("n_rev_blocks", 4)
  config.add_hparam("n_res", [2, 2, 2, 2])
  config.add_hparam("filters", [128, 256, 512, 832])
  config.add_hparam("strides", [1, 2, 2, 2])
  config.add_hparam("batch_size", 16)
  config.add_hparam("bottleneck", True)
  config.add_hparam("fused", True)
  config.add_hparam("init_max_pool", True)
  if tf.test.is_gpu_available():
    config.add_hparam("input_shape", (3, 224, 224))
    config.add_hparam("data_format", "channels_first")
  else:
    config.add_hparam("input_shape", (224, 224, 3))
    config.add_hparam("data_format", "channels_last")

  # Training details
  config.add_hparam("weight_decay", 1e-4)
  config.add_hparam("momentum", .9)
  config.add_hparam("lr_decay_steps", [160000, 320000, 480000])
  config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3, 1e-4])
  config.add_hparam("max_train_iter", 600000)
  config.add_hparam("seed", 1234)
  config.add_hparam("shuffle", True)
  config.add_hparam("log_every", 50)
  config.add_hparam("save_every", 50)
  config.add_hparam("dtype", tf.float32)
  config.add_hparam("eval_batch_size", 1000)
  config.add_hparam("div255", True)
  # TODO(lxuechen): Update this to the ImageNet training set size
  # (~1.28M images); 50000 is a CIFAR-sized placeholder.
  config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
  config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)

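  # Bottleneck residual units output 4x their base channel count (as in
  # ResNet bottlenecks), so scale the configured widths to match.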
  if config.bottleneck:
    filters = [f * 4 for f in config.filters]
    config.filters = filters

  return config
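

# A minimal smoke test (an added sketch, not part of the original example):
# build both configurations and print a few derived hyperparameters. Assumes
# TF 1.x, where tf.contrib and tf.enable_eager_execution() are available.
if __name__ == "__main__":
  tf.enable_eager_execution()
  for name, make_config in [("cifar_38", get_hparams_cifar_38),
                            ("imagenet_56", get_hparams_imagenet_56)]:
    config = make_config()
    print("%s: filters=%s input_shape=%s epochs=%d" %
          (name, config.filters, config.input_shape, config.epochs))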