aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-15 17:40:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 17:45:08 -0700
commit1608a0c91c2a604e1573ad12551c1c38c826a5c0 (patch)
tree80cd63b0e0814ac393c26a467bcbb1cd7631e42f /tensorflow/contrib/opt
parent1357711e6faf688f863821c35be6c358891616ec (diff)
LARS Optimizer in TensorFlow
Based on CL from Chris Ying and contributions from Y. You and Wang Tao. Introduced by "Large Batch Training of Convolutional Networks" by Y. You, I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888) Implements the LARS learning rate scheme presented in the paper above. This optimizer is useful when scaling the batch size to up to 32K without significant performance degradation. It is recommended to use the optimizer in conjunction with: - Gradual learning rate warm-up - Linear learning rate scaling - Poly rule learning rate decay With this optimizer, ResNet-50 now converges to 76.3% top-1 accuracy at batch 32K on a JF Pod. PiperOrigin-RevId: 208914187
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--tensorflow/contrib/opt/BUILD16
-rw-r--r--tensorflow/contrib/opt/__init__.py2
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer.py164
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer_test.py127
4 files changed, 309 insertions, 0 deletions
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 778b710d78..5319a8b655 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -20,6 +20,7 @@ py_library(
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
"python/training/ggt.py",
+ "python/training/lars_optimizer.py",
"python/training/lazy_adam_optimizer.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
@@ -365,3 +366,18 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
+
+py_test(
+ name = "lars_optimizer_test",
+ srcs = ["python/training/lars_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 9471fb0181..781621dba0 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.opt.python.training.addsign import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import *
+from tensorflow.contrib.opt.python.training.lars_optimizer import *
from tensorflow.contrib.opt.python.training.ggt import *
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
@@ -46,6 +47,7 @@ _allowed_symbols = [
'DelayCompensatedGradientDescentOptimizer',
'DropStaleGradientOptimizer',
'ExternalOptimizerInterface',
+ 'LARSOptimizer',
'LazyAdamOptimizer',
'NadamOptimizer',
'MovingAverageOptimizer',
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py b/tensorflow/contrib/opt/python/training/lars_optimizer.py
new file mode 100644
index 0000000000..a8dafd9a4c
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py
@@ -0,0 +1,164 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class LARSOptimizer(optimizer.Optimizer):
+ """Layer-wise Adaptive Rate Scaling for large batch training.
+
+ Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
+ I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
+
+ Implements the LARS learning rate scheme presented in the paper above. This
+ optimizer is useful when scaling the batch size to up to 32K without
+ significant performance degradation. It is recommended to use the optimizer
+ in conjunction with:
+ - Gradual learning rate warm-up
+ - Linear learning rate scaling
+ - Poly rule learning rate decay
+
+ Note, LARS scaling is currently only enabled for dense tensors. Sparse tensors
+ use the default momentum optimizer.
+ """
+
+ def __init__(
+ self,
+ learning_rate,
+ momentum=0.9,
+ weight_decay=0.0001,
+ # The LARS coefficient is a hyperparameter
+ eeta=0.001,
+ epsilon=0.0,
+ name="LARSOptimizer",
+ # Enable skipping variables from LARS scaling.
+ # TODO(sameerkm): Enable a direct mechanism to pass a
+ # subset of variables to the optimizer.
+ skip_list=None,
+ use_nesterov=False):
+ """Construct a new LARS Optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or floating point value. The base learning rate.
+ momentum: A floating point value. Momentum hyperparameter.
+ weight_decay: A floating point value. Weight decay hyperparameter.
+ eeta: LARS coefficient as used in the paper. Dfault set to LARS
+ coefficient from the paper. (eeta / weight_decay) determines the highest
+ scaling factor in LARS.
+ epsilon: Optional epsilon parameter to be set in models that have very
+ small gradients. Default set to 0.0.
+ name: Optional name prefix for variables and ops created by LARSOptimizer.
+ skip_list: List of strings to enable skipping variables from LARS scaling.
+ If any of the strings in skip_list is a subset of var.name, variable
+ 'var' is skipped from LARS scaling. For a typical classification model
+ with batch normalization, the skip_list is ['batch_normalization',
+ 'bias']
+ use_nesterov: when set to True, nesterov momentum will be enabled
+
+ Raises:
+ ValueError: If a hyperparameter is set to a non-sensical value.
+ """
+ if momentum < 0.0:
+ raise ValueError("momentum should be positive: %s" % momentum)
+ if weight_decay < 0.0:
+ raise ValueError("weight_decay should be positive: %s" % weight_decay)
+ super(LARSOptimizer, self).__init__(use_locking=False, name=name)
+
+ self._learning_rate = learning_rate
+ self._momentum = momentum
+ self._weight_decay = weight_decay
+ self._eeta = eeta
+ self._epsilon = epsilon
+ self._name = name
+ self._skip_list = skip_list
+ self._use_nesterov = use_nesterov
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self._zeros_slot(v, "momentum", self._name)
+
+ def compute_lr(self, grad, var):
+ scaled_lr = self._learning_rate
+ if self._skip_list is None or not any(v in var.name
+ for v in self._skip_list):
+ w_norm = linalg_ops.norm(var, ord=2)
+ g_norm = linalg_ops.norm(grad, ord=2)
+ trust_ratio = array_ops.where(
+ math_ops.greater(w_norm, 0),
+ array_ops.where(
+ math_ops.greater(g_norm, 0),
+ (self._eeta * w_norm /
+ (g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0),
+ 1.0)
+ scaled_lr = self._learning_rate * trust_ratio
+ return scaled_lr
+
+ def _apply_dense(self, grad, var):
+ scaled_lr = self.compute_lr(grad, var)
+ mom = self.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var,
+ mom,
+ scaled_lr,
+ grad,
+ self._momentum,
+ use_locking=False,
+ use_nesterov=self._use_nesterov)
+
+ def _resource_apply_dense(self, grad, var):
+ scaled_lr = self.compute_lr(grad, var)
+ mom = self.get_slot(var, "momentum")
+ return training_ops.resource_apply_momentum(
+ var.handle,
+ mom.handle,
+ scaled_lr,
+ grad,
+ self._momentum,
+ use_locking=False,
+ use_nesterov=self._use_nesterov)
+
+ # Fallback to momentum optimizer for sparse tensors
+ def _apply_sparse(self, grad, var):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var,
+ mom,
+ math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.resource_sparse_apply_momentum(
+ var.handle,
+ mom.handle,
+ math_ops.cast(self._learning_rate_tensor, grad.dtype),
+ grad,
+ indices,
+ math_ops.cast(self._momentum_tensor, grad.dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
new file mode 100644
index 0000000000..d94249b994
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
@@ -0,0 +1,127 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0. Licensed to the Apache
+# Software Foundation. You may not use this file except in compliance with the
+# License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for Layer-wise Adaptive Rate Scaling optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import lars_optimizer as lo
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class LARSOptimizerTest(test.TestCase):
+
+ def testLARSGradientOneStep(self):
+ for _ in range(10):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.test_session() as sess:
+ shape = [3, 3]
+ var_np = np.ones(shape)
+ grad_np = np.ones(shape)
+ lr_np = 0.1
+ m_np = 0.9
+ wd_np = 0.1
+ ep_np = 1e-5
+ eeta = 0.1
+ vel_np = np.zeros(shape)
+
+ var = variables.Variable(var_np, dtype=dtype)
+ grad = variables.Variable(grad_np, dtype=dtype)
+ opt = lo.LARSOptimizer(
+ learning_rate=lr_np,
+ momentum=m_np,
+ weight_decay=wd_np,
+ eeta=eeta,
+ epsilon=ep_np)
+
+ step = opt.apply_gradients([(grad, var)])
+ variables.global_variables_initializer().run()
+
+ pre_var = sess.run(var)
+ pre_vel = sess.run(opt.get_slot(var, 'momentum'))
+ self.assertAllClose(var_np, pre_var)
+ self.assertAllClose(vel_np, pre_vel)
+
+ step.run()
+ post_var = sess.run(var)
+ post_vel = sess.run(opt.get_slot(var, 'momentum'))
+
+ w_norm = np.linalg.norm(var_np.flatten(), ord=2)
+ g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
+ trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np)
+ scaled_lr = lr_np * trust_ratio
+
+ vel_np = m_np * vel_np + grad_np
+ var_np -= scaled_lr * vel_np
+
+ self.assertAllClose(var_np, post_var)
+ self.assertAllClose(vel_np, post_vel)
+
+ def testLARSGradientMultiStep(self):
+ for _ in range(10):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.test_session() as sess:
+ shape = [3, 3]
+ var_np = np.ones(shape)
+ grad_np = np.ones(shape)
+ lr_np = 0.1
+ m_np = 0.9
+ wd_np = 0.1
+ ep_np = 1e-5
+ eeta = 0.1
+ vel_np = np.zeros(shape)
+
+ var = variables.Variable(var_np, dtype=dtype)
+ grad = variables.Variable(grad_np, dtype=dtype)
+ opt = lo.LARSOptimizer(
+ learning_rate=lr_np,
+ momentum=m_np,
+ eeta=eeta,
+ weight_decay=wd_np,
+ epsilon=ep_np)
+
+ step = opt.apply_gradients([(grad, var)])
+ variables.global_variables_initializer().run()
+
+ pre_var = sess.run(var)
+ pre_vel = sess.run(opt.get_slot(var, 'momentum'))
+ self.assertAllClose(var_np, pre_var)
+ self.assertAllClose(vel_np, pre_vel)
+
+ for _ in range(10):
+ step.run()
+
+ post_var = sess.run(var)
+ post_vel = sess.run(opt.get_slot(var, 'momentum'))
+
+ w_norm = np.linalg.norm(var_np.flatten(), ord=2)
+ g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
+ trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np)
+ scaled_lr = lr_np * trust_ratio
+
+ vel_np = m_np * vel_np + grad_np
+ var_np -= scaled_lr * vel_np
+
+ self.assertAllClose(var_np, post_var)
+ self.assertAllClose(vel_np, post_vel)
+
+
+if __name__ == '__main__':
+ test.main()