author    James Qin <jamesqin@google.com>                 2018-05-14 17:51:11 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-05-14 17:53:46 -0700
commit    4c2cb712c7d7be93533da240d4c8e55e69d79625 (patch)
tree      2ab6a3445c6df34d60106f9ee2a4f2c5184f1a4b /tensorflow/contrib/mixed_precision
parent    e1a49f30435096c7e0817dde2e472c85db143a81 (diff)
Introduce LossScalingOptimizer for mixed precision training.
PiperOrigin-RevId: 196597196
Diffstat (limited to 'tensorflow/contrib/mixed_precision')
-rw-r--r-- | tensorflow/contrib/mixed_precision/BUILD                               |  32
-rw-r--r-- | tensorflow/contrib/mixed_precision/__init__.py                         |  34
-rw-r--r-- | tensorflow/contrib/mixed_precision/python/BUILD                        |  74
-rw-r--r-- | tensorflow/contrib/mixed_precision/python/loss_scale_manager.py        | 200
-rw-r--r-- | tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py   | 182
-rw-r--r-- | tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py      | 166
-rw-r--r-- | tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py | 216
7 files changed, 904 insertions, 0 deletions
diff --git a/tensorflow/contrib/mixed_precision/BUILD b/tensorflow/contrib/mixed_precision/BUILD
new file mode 100644
index 0000000000..3dfb95e0a0
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/BUILD
@@ -0,0 +1,32 @@
+# Mixed precision training optimizers
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "mixed_precision",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/mixed_precision/python:loss_scale_manager",
+ "//tensorflow/contrib/mixed_precision/python:loss_scale_optimizer",
+ ],
+)
diff --git a/tensorflow/contrib/mixed_precision/__init__.py b/tensorflow/contrib/mixed_precision/__init__.py
new file mode 100644
index 0000000000..43e98cdda0
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2018 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library for mixed precision training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.mixed_precision.python.loss_scale_manager import *
+from tensorflow.contrib.mixed_precision.python.loss_scale_optimizer import *
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "LossScaleManager",
+ "FixedLossScaleManager",
+ "ExponentialUpdateLossScaleManager",
+ "LossScaleOptimizer",
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/mixed_precision/python/BUILD b/tensorflow/contrib/mixed_precision/python/BUILD
new file mode 100644
index 0000000000..1d769e1614
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/BUILD
@@ -0,0 +1,74 @@
+# Mixed precision training optimizers
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "loss_scale_manager",
+ srcs = ["loss_scale_manager.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variable_scope",
+ ],
+)
+
+py_test(
+ name = "loss_scale_manager_test",
+ size = "small",
+ srcs = ["loss_scale_manager_test.py"],
+ deps = [
+ ":loss_scale_manager",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "loss_scale_optimizer",
+ srcs = ["loss_scale_optimizer.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":loss_scale_manager",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_test(
+ name = "loss_scale_optimizer_test",
+ size = "small",
+ srcs = ["loss_scale_optimizer_test.py"],
+ deps = [
+ ":loss_scale_optimizer",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py
new file mode 100644
index 0000000000..be7377b151
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py
@@ -0,0 +1,200 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""LossScaleManager classes for mixed precision training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import six
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+
+
+@six.add_metaclass(abc.ABCMeta)
+class LossScaleManager(object):
+ """Abstract loss scale manager class.
+
+ Loss scale managers with different strategies should subclass this class.
+ Loss scaling is a process that:
+
+ 1) Applies a multiplier on the loss before computing gradients, and
+ 2) Applies the reciprocal of the multiplier on the gradients before they are
+ applied on variables.
+
+ This class is used together with
+ @{tf.contrib.mixed_precision.LossScaleOptimizer} for mixed precision training
+ (float32 variables and float16 ops) on Nvidia GPUs in order to achieve the
+ same model quality as single precision training, with the benefit of
+ potentially higher throughput.
+
+ See @{tf.contrib.mixed_precision.LossScaleOptimizer} for more details.
+ """
+
+ @abc.abstractmethod
+ def get_loss_scale(self):
+ """Returns the loss scale as a scalar `float32` tensor."""
+ pass
+
+ @abc.abstractmethod
+ def update_loss_scale(self, finite_grads):
+ """Updates loss scale based on if gradients are finite in current step.
+
+ Args:
+ finite_grads: bool scalar tensor indicating whether all gradients are
+ finite (i.e., not inf or nan).
+
+ Returns:
+ An op that, when executed, updates the loss scale. If eager execution is
+ enabled, returns nothing.
+ """
+ del finite_grads
+ return
+
+
+class FixedLossScaleManager(LossScaleManager):
+ """Loss scale manager with a fixed loss scale.
+
+ The loss scale is not updated for the lifetime of the instance.
+ """
+
+ def __init__(self, loss_scale):
+ """Creates the fixed loss scale manager.
+
+ Args:
+ loss_scale: A Python float. Its ideal value varies depending on the model.
+ Choosing a loss_scale that is too small might degrade model quality; one
+ that is too big might cause inf or nan gradients. There is no single
+ right loss_scale to apply, and there is no harm in choosing a relatively
+ big number as long as no nan or inf is encountered in training.
+
+ Raises:
+ ValueError: If loss_scale is less than 1.
+ """
+ if loss_scale < 1:
+ raise ValueError("loss scale must be at least 1.")
+ self._loss_scale = ops.convert_to_tensor(loss_scale, dtype=dtypes.float32)
+
+ def get_loss_scale(self):
+ return self._loss_scale
+
+ def update_loss_scale(self, finite_grads):
+ del finite_grads
+ return gen_control_flow_ops.no_op()
+
+
+class ExponentialUpdateLossScaleManager(LossScaleManager):
+ """Loss scale manager uses an exponential update strategy.
+
+ In general, the strategy increases the loss scale by a greater-than-one factor
+ after a consecutive series of steps with finite gradients; similarly, it
+ decreases the loss scale by a less-than-one factor after the accumulated number
+ of steps with non-finite (nan or inf) gradients reaches a threshold. An update
+ is not applied if its result is less than 1 or overflows the float32 range.
+
+ The counts of finite and non-finite steps are cleared every time the loss
+ scale is changed. The condition for decreasing the loss scale is looser than
+ for increasing it, since the former does not require consecutive steps.
+ """
+
+ def __init__(self,
+ init_loss_scale,
+ incr_every_n_steps,
+ decr_every_n_nan_or_inf=2,
+ incr_ratio=2,
+ decr_ratio=0.8):
+ """Constructor of exponential-update loss scale manager.
+
+ Args:
+ init_loss_scale: A Python float. The loss scale to use at the beginning.
+ incr_every_n_steps: Increases loss scale every n consecutive steps with
+ finite gradients.
+ decr_every_n_nan_or_inf: Decreases loss scale every n accumulated steps
+ with nan or inf gradients.
+ incr_ratio: The multiplier to use when increasing the loss scale.
+ decr_ratio: The less-than-one multiplier to use when decreasing the loss
+ scale.
+ """
+ self._incr_every_n_steps = incr_every_n_steps
+ self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
+ self._incr_ratio = incr_ratio
+ self._decr_ratio = decr_ratio
+ self._loss_scale = variable_scope.variable(
+ name="loss_scale",
+ initial_value=ops.convert_to_tensor(init_loss_scale, dtypes.float32),
+ dtype=dtypes.float32,
+ trainable=False)
+ self._num_good_steps = variable_scope.variable(
+ name="good_steps", initial_value=0, dtype=dtypes.int32, trainable=False)
+ self._num_bad_steps = variable_scope.variable(
+ name="bad_steps", initial_value=0, dtype=dtypes.int32, trainable=False)
+
+ def _reset_stats(self):
+ return control_flow_ops.group(
+ state_ops.assign(self._num_good_steps, 0),
+ state_ops.assign(self._num_bad_steps, 0))
+
+ def get_loss_scale(self):
+ """Returns the loss scale."""
+ return self._loss_scale
+
+ def update_loss_scale(self, finite_grads):
+ """Updates loss scale based on if gradients are finite in current step."""
+
+ def update_if_finite_grads():
+ """Branch function when grads are all finite."""
+
+ def incr_loss_scale():
+ new_loss_scale = control_flow_ops.cond(
+ gen_math_ops.is_finite(self._loss_scale * self._incr_ratio),
+ lambda: self._loss_scale * self._incr_ratio,
+ lambda: self._loss_scale)
+ update_op = state_ops.assign(self._loss_scale, new_loss_scale)
+ # When loss_scale is updated, both good and bad steps are reset.
+ return control_flow_ops.group(update_op, self._reset_stats())
+
+ return control_flow_ops.cond(
+ self._num_good_steps + 1 >= self._incr_every_n_steps,
+ incr_loss_scale,
+ lambda: state_ops.assign_add(self._num_good_steps, 1).op)
+
+ def update_if_not_finite_grads():
+ """Branch function when any grad is not finite."""
+
+ def decr_loss_scale():
+ update_op = state_ops.assign(
+ self._loss_scale,
+ gen_math_ops.maximum(1., self._loss_scale * self._decr_ratio))
+ # When loss_scale is updated, both good and bad steps are reset.
+ return control_flow_ops.group(update_op, self._reset_stats())
+
+ def just_update_steps():
+ # When bad_steps is incremented, good_step is reset.
+ return control_flow_ops.group(
+ state_ops.assign_add(self._num_bad_steps, 1),
+ state_ops.assign(self._num_good_steps, 0))
+
+ return control_flow_ops.cond(
+ self._num_bad_steps + 1 >= self._decr_every_n_nan_or_inf,
+ decr_loss_scale, just_update_steps)
+
+ return control_flow_ops.cond(finite_grads, update_if_finite_grads,
+ update_if_not_finite_grads)
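The exponential-update behavior above can be driven directly. The following is a minimal graph-mode sketch (assuming only the module path added in this change; the chosen ratios and the step sequence are illustrative, mirroring the behavior exercised in the tests below):

```python
import tensorflow as tf
from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib

# Start at a scale of 1, double after every 2 consecutive finite steps, and
# halve after every single non-finite step.
manager = lsm_lib.ExponentialUpdateLossScaleManager(
    init_loss_scale=1,
    incr_every_n_steps=2,
    decr_every_n_nan_or_inf=1,
    incr_ratio=2,
    decr_ratio=0.5)

grads_finite = tf.placeholder(tf.bool, shape=[])
update_op = manager.update_loss_scale(grads_finite)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for finite in [True, True, True, True, False]:
    sess.run(update_op, feed_dict={grads_finite: finite})
    print(sess.run(manager.get_loss_scale()))  # 1.0, 2.0, 2.0, 4.0, 2.0
```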
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py
new file mode 100644
index 0000000000..480f5f6eaf
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py
@@ -0,0 +1,182 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LossScaleManager classes.."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def _GetExampleIter(inputs):
+ dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+ return dataset.make_one_shot_iterator()
+
+
+class FixedLossScaleManagerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_basic(self):
+ itr = _GetExampleIter([True] * 10 + [False] * 10)
+
+ loss_scale = 1000
+ lsm = lsm_lib.FixedLossScaleManager(loss_scale)
+ update_fn = lambda: lsm.update_loss_scale(itr.get_next())
+
+ self.evaluate(variables.global_variables_initializer())
+ if not context.executing_eagerly():
+ update_op = update_fn()
+ for _ in range(10):
+ if context.executing_eagerly():
+ update_fn()
+ else:
+ self.evaluate(update_op)
+ self.assertEqual(loss_scale, self.evaluate(lsm.get_loss_scale()))
+
+
+class ExponentialUpdateLossScaleManagerTest(test.TestCase):
+
+ def _test_helper(self,
+ inputs,
+ expected_outputs,
+ init_loss_scale=1,
+ incr_every_n_step=2,
+ decr_every_n_nan_or_inf=2):
+ ratio = 2
+ lsm = lsm_lib.ExponentialUpdateLossScaleManager(
+ init_loss_scale=init_loss_scale,
+ incr_every_n_steps=incr_every_n_step,
+ decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
+ incr_ratio=ratio,
+ decr_ratio=1. / ratio)
+ itr = _GetExampleIter(inputs)
+ update_fn = lambda: lsm.update_loss_scale(itr.get_next())
+
+ self.evaluate(variables.global_variables_initializer())
+ actual_outputs = []
+
+ if not context.executing_eagerly():
+ update_op = update_fn()
+ for _ in range(len(inputs)):
+ if context.executing_eagerly():
+ update_fn()
+ else:
+ self.evaluate(update_op)
+ actual_outputs.append(self.evaluate(lsm.get_loss_scale()))
+ self.assertEqual(actual_outputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_increase_every_n_steps(self):
+ inputs = [True] * 6
+ expected_outputs = [1, 2, 2, 4, 4, 8]
+ self._test_helper(inputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_keep_increasing_until_capped(self):
+ init_loss_scale = np.finfo(np.float32).max / 4 + 10
+ max_float = np.finfo(np.float32).max
+
+ inputs = [True] * 6
+ # Output is capped the 2nd time it doubles.
+ expected_outputs = [
+ init_loss_scale, init_loss_scale * 2, init_loss_scale * 2, max_float,
+ max_float, max_float
+ ]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_decrease_every_n_steps(self):
+ inputs = [False] * 6
+ init_loss_scale = 1024
+ expected_outputs = [1024, 512, 512, 256, 256, 128]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_keep_decreasing_until_one(self):
+ inputs = [False] * 10
+ init_loss_scale = 16
+ expected_outputs = [16, 8, 8, 4, 4, 2, 2, 1, 1, 1]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_incr_bad_step_clear_good_step(self):
+ inputs = [True, True, True, False, True]
+ expected_outputs = [1, 2, 2, 2, 2]
+ self._test_helper(inputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_incr_good_step_does_not_clear_bad_step(self):
+ inputs = [True, True, True, False, True, False]
+ expected_outputs = [1, 2, 2, 2, 2, 1]
+ self._test_helper(inputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_trigger_loss_scale_update_each_step(self):
+ """Test when incr_every_n_step and decr_every_n_nan_or_inf is 1."""
+ init_loss_scale = 1
+ incr_every_n_step = 1
+ decr_every_n_nan_or_inf = 1
+
+ inputs = [True] * 3 + [False, True, True]
+ expected_outputs = [2, 4, 8, 4, 8, 16]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale,
+ incr_every_n_step, decr_every_n_nan_or_inf)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_alternating_good_and_bad_gradients_trigger_each_step(self):
+ init_loss_scale = 1
+ incr_every_n_step = 1
+ decr_every_n_nan_or_inf = 1
+
+ inputs = [True, False] * 4 + [True]
+ expected_outputs = [2, 1, 2, 1, 2, 1, 2, 1, 2]
+ self._test_helper(inputs, expected_outputs, init_loss_scale,
+ incr_every_n_step, decr_every_n_nan_or_inf)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self):
+ init_loss_scale = 32
+ incr_every_n_step = 2
+ decr_every_n_nan_or_inf = 1
+
+ inputs = [True, False] * 3 + [True]
+ expected_outputs = [32, 16, 16, 8, 8, 4, 4]
+ self._test_helper(inputs, expected_outputs, init_loss_scale,
+ incr_every_n_step, decr_every_n_nan_or_inf)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_random_mix_good_and_bad_gradients(self):
+ init_loss_scale = 4
+ inputs = [
+ False, False, True, True, True, False, True, False, True, True, True,
+ False
+ ]
+ expected_outputs = [4, 2, 2, 4, 4, 4, 4, 2, 2, 4, 4, 4]
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
new file mode 100644
index 0000000000..e4e5ccc334
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -0,0 +1,166 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loss scaling optimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import optimizer
+
+
+class LossScaleOptimizer(optimizer.Optimizer):
+ """An optimizer that applies loss scaling in backprop.
+
+ This class is useful for mixed precision training on GPUs (or other
+ accelerators), an approach that improves compute throughput without
+ compromising model quality.
+
+ The common configuration of mixed precision models is the following:
+ * Variables are kept in high precision (e.g. float32).
+ * Computations are done in lower precision (e.g. float16); variables are
+ cast to lower precision before they're used.
+ * In training, final gradients are cast back to variable precision and
+ applied.
+
+ Because computations happen in lower precision, gradients in the backprop pass
+ might underflow in the smaller dynamic range, causing a model to converge at a
+ suboptimal level. This optimizer multiplies the loss by a factor before
+ backprop starts to prevent underflow. Before gradients are applied, they are
+ cast to higher precision and down-scaled by the same factor, so
+ mathematically the variable updates are no different from regular
+ same-precision training.
+
+ See [Nvidia's manual on mixed precision training](
+ https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
+ for more details.
+
+ To use the loss scale optimizer, one only needs to choose a loss scale
+ strategy and wrap a regular optimizer. See the examples below.
+
+ ```
+ loss = loss_fn()
+ opt = tf.train.AdamOptimizer(learning_rate=...)
+
+ # Choose a loss scale manager which decides how to pick the right loss scale
+ # throughout the training process.
+ loss_scale_manager = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
+
+ # Wraps the original optimizer in a LossScaleOptimizer.
+ loss_scale_optimizer = LossScaleOptimizer(opt, loss_scale_manager)
+
+ # Call minimize() on the loss scale optimizer.
+ train_op = loss_scale_optimizer.minimize(loss)
+ ```
+
+ If gradient clipping is applied, one can call
+ `optimizer.compute_gradients()` and `optimizer.apply_gradients()`
+ separately.
+
+ Note that the following way of using LossScaleOptimizer is not intended. Always
+ use `loss_scale_optimizer.compute_gradients()` to compute gradients instead of
+ `tf.gradients()` when doing mixed precision training.
+
+ ```
+ # The following is a wrong way to use LossScaleOptimizer along with
+ # tf.gradients().
+
+ # Always use loss_scale_optimizer.compute_gradients() to compute grads;
+ # otherwise the loss scale is not correctly applied.
+ grads = tf.gradients(loss, ...)
+
+ # Do some custom grad clipping.
+ grads = clip_grads(grads, ...)
+
+ loss_scale_optimizer.apply_gradients(grads_and_vars)
+ ```
+ """
+
+ def __init__(self, opt, loss_scale_manager):
+ """Construct a loss scaling optimizer.
+
+ Args:
+ opt: The actual optimizer that will be used to compute and apply the
+ gradients. Must be an implementation of the @{tf.train.Optimizer}
+ interface.
+ loss_scale_manager: A LossScaleManager object.
+ """
+ self._opt = opt
+ self._loss_scale_manager = loss_scale_manager
+
+ def compute_gradients(self,
+ loss,
+ var_list=None,
+ gate_gradients=optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ grad_loss=None):
+ """Compute gradients. See base class @{tf.train.Optimizer}."""
+ loss_scale = self._loss_scale_manager.get_loss_scale()
+ if context.executing_eagerly():
+
+ def scaled_loss():
+ loss_val = loss()
+ return loss_val * math_ops.cast(loss_scale, loss_val.dtype.base_dtype)
+ else:
+ if callable(loss):
+ loss_val = loss()
+ else:
+ loss_val = loss
+ scaled_loss = loss_val * math_ops.cast(loss_scale,
+ loss_val.dtype.base_dtype)
+ grads_and_vars = self._opt.compute_gradients(
+ scaled_loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ return self._down_scale(grads_and_vars, loss_scale)
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients. See base class @{tf.train.Optimizer}."""
+ grads = [g for (g, _) in grads_and_vars]
+
+ is_finite_grad = []
+ for g in grads:
+ is_finite_grad.append(math_ops.reduce_all(gen_math_ops.is_finite(g)))
+ is_overall_finite = math_ops.reduce_all(is_finite_grad)
+
+ # Only apply the gradient update when all grads are finite.
+ def true_apply_gradients_fn():
+ return self._opt.apply_gradients(grads_and_vars, global_step, name)
+
+ update_vars = control_flow_ops.cond(
+ is_overall_finite, true_apply_gradients_fn, gen_control_flow_ops.no_op)
+ # Let the loss scale manager update the loss scale based on grad finiteness.
+ return control_flow_ops.group(
+ update_vars,
+ self._loss_scale_manager.update_loss_scale(is_overall_finite))
+
+ def _down_scale(self, grads_vars, loss_scale):
+ # Down scale grads by the loss_scale.
+ gv = []
+ inv_loss_scale = gen_math_ops.reciprocal(loss_scale)
+ for g, v in grads_vars:
+ if g is not None:
+ gv.append((g * math_ops.cast(inv_loss_scale, g.dtype.base_dtype), v))
+ else:
+ gv.append((g, v))
+ return gv
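For the gradient-clipping case mentioned in the docstring (calling `compute_gradients()` and `apply_gradients()` separately), the following is a hedged sketch; the variable, loss, and clipping norm are hypothetical placeholders, not part of this change:

```python
import tensorflow as tf
from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib
from tensorflow.contrib.mixed_precision.python import loss_scale_optimizer as lso

# Hypothetical float32 variable with float16 compute, as in mixed precision.
x = tf.get_variable("x", initializer=1.0, dtype=tf.float32)
loss = tf.cast(tf.cast(x, tf.float16) * tf.constant(1e-4, dtype=tf.float16),
               tf.float32)

opt = lso.LossScaleOptimizer(
    tf.train.GradientDescentOptimizer(1.0),
    lsm_lib.FixedLossScaleManager(1e4))

# compute_gradients() scales the loss internally and returns already
# down-scaled gradients, so custom processing such as clipping operates on
# gradients in their usual range.
grads_and_vars = opt.compute_gradients(loss, var_list=[x])
clipped = [(tf.clip_by_norm(g, 1.0), v)
           for g, v in grads_and_vars if g is not None]

# apply_gradients() skips the variable update when any gradient is nan/inf
# and asks the manager to update the loss scale accordingly.
train_op = opt.apply_gradients(clipped)
```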
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py
new file mode 100644
index 0000000000..dded61ccd5
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py
@@ -0,0 +1,216 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LossScaleOptimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib
+from tensorflow.contrib.mixed_precision.python import loss_scale_optimizer as lso
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent as gd
+
+
+class LossScaleOptimizerTest(test.TestCase):
+
+ def _build_graph(self, lr, init_val, loss_scale_opt_fn=None):
+ x = variable_scope.get_variable(
+ "x", initializer=init_val, dtype=dtypes.float32)
+ c1 = constant_op.constant(1e4, dtype=dtypes.float16)
+ c2 = constant_op.constant(1e-4, dtype=dtypes.float16)
+ c3 = constant_op.constant(1e-4, dtype=dtypes.float16)
+ if context.executing_eagerly():
+ loss = lambda: math_ops.cast(x, dtypes.float16) * c1 * c2 * c3
+ else:
+ loss = math_ops.cast(x, dtypes.float16) * c1 * c2 * c3
+
+ opt = gd.GradientDescentOptimizer(lr)
+ if loss_scale_opt_fn:
+ opt = loss_scale_opt_fn(opt)
+ return x, loss, opt
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_float16_underflow_without_loss_scale(self):
+ lr = 1
+ init_val = 1.
+ x, loss, opt = self._build_graph(lr, init_val)
+
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(opt.minimize(loss, var_list=[x]))
+
+ # Symbolic grad is c1 * c2 * c3 = 1e-4 but the actual grad is 0, since
+ # c2 * c3 underflows the fp16 range in backprop, so the variable isn't updated.
+ expected_update = 0
+ symbolic_update = 1e-4 * lr
+ self.assertAllClose(
+ init_val - expected_update,
+ self.evaluate(x),
+ rtol=0,
+ atol=min(symbolic_update, 1e-6))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_float16_with_loss_scale(self):
+ lr = 1.
+ init_val = 1.
+
+ def loss_scale_opt_fn(opt):
+ return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4))
+
+ x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn)
+
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(opt.minimize(loss, var_list=[x]))
+
+ # Symbolic grad is c1 * c2 * c3 = 1e-4 and the actual grad is the same,
+ # because the loss is scaled up before backprop starts.
+ expected_update = 1.e-4 * lr
+ self.assertAllClose(
+ init_val - expected_update,
+ self.evaluate(x),
+ rtol=0,
+ atol=min(expected_update, 1e-6))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_compute_gradients_with_loss_scale(self):
+ lr = 1
+ init_val = 1.
+
+ def loss_scale_opt_fn(opt):
+ return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4))
+
+ x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn)
+ grads_and_vars = opt.compute_gradients(loss, var_list=[x])
+
+ self.assertEqual(len(grads_and_vars), 1)
+
+ self.evaluate(variables.global_variables_initializer())
+ g_v = self.evaluate(grads_and_vars[0][0])
+ self.assertAllClose(g_v, 1e-4)
+ self.assertIs(grads_and_vars[0][1], x)
+ # Gradients aren't applied.
+ self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_compute_gradients_without_loss_scale(self):
+ lr = 1
+ init_val = 1.
+ x, loss, opt = self._build_graph(lr, init_val)
+ grads_and_vars = opt.compute_gradients(loss, var_list=[x])
+
+ self.assertEqual(len(grads_and_vars), 1)
+ self.evaluate(variables.global_variables_initializer())
+ g_v = self.evaluate(grads_and_vars[0][0])
+ self.assertAllClose(g_v, 0)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_apply_gradients(self):
+
+ x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1])
+ itr = dataset.make_one_shot_iterator()
+
+ lr = 1
+ opt = gd.GradientDescentOptimizer(lr)
+ lsm = lsm_lib.FixedLossScaleManager(1.e4)
+ opt = lso.LossScaleOptimizer(opt, lsm)
+ train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)])
+ if not context.executing_eagerly():
+ train_op = train_fn()
+
+ expected_output = [1, 1, 1 - 0.1]
+ actual_output = []
+
+ self.evaluate(variables.global_variables_initializer())
+ for _ in range(3):
+ # nan or inf is not applied.
+ if context.executing_eagerly():
+ train_fn()
+ else:
+ self.evaluate(train_op)
+ actual_output.append(self.evaluate(x))
+ self.assertAllClose(expected_output, actual_output)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_apply_gradients_loss_scale_is_updated(self):
+
+ class SimpleLossScaleManager(lsm_lib.LossScaleManager):
+ """A simple loss scale manager for easier testing.
+
+ It increments the loss scale by 1 if grads are finite, and decrements it
+ by 1 otherwise.
+ """
+
+ def __init__(self, loss_scale):
+ self._loss_scale = variable_scope.variable(
+ name="loss_scale",
+ initial_value=loss_scale,
+ dtype=dtypes.float32,
+ trainable=False)
+
+ def get_loss_scale(self):
+ return self._loss_scale
+
+ def update_loss_scale(self, if_finite_grads):
+ return control_flow_ops.cond(
+ if_finite_grads, lambda: state_ops.assign_add(self._loss_scale, 1),
+ lambda: state_ops.assign_sub(self._loss_scale, 1))
+
+ x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1])
+ itr = dataset.make_one_shot_iterator()
+
+ lr = 1
+ init_loss_scale = 8
+ opt = gd.GradientDescentOptimizer(lr)
+ lsm = SimpleLossScaleManager(init_loss_scale)
+ opt = lso.LossScaleOptimizer(opt, lsm)
+ train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)])
+ if not context.executing_eagerly():
+ train_op = train_fn()
+
+ self.evaluate(variables.global_variables_initializer())
+
+ expected_loss_scale = [
+ init_loss_scale - 1, init_loss_scale - 2, init_loss_scale - 2 + 1
+ ]
+ expected_output = [1, 1, 1 - 0.1]
+ actual_output = []
+ for i in range(3):
+ # nan or inf is not applied.
+ if context.executing_eagerly():
+ train_fn()
+ else:
+ self.evaluate(train_op)
+ actual_output.append(self.evaluate(x))
+ self.assertAllClose(expected_loss_scale[i],
+ self.evaluate(lsm._loss_scale))
+ self.assertAllClose(expected_output, actual_output)
+
+
+if __name__ == "__main__":
+ test.main()