author:    2018-05-14 17:51:11 -0700
committer: 2018-05-14 17:53:46 -0700
commit:    4c2cb712c7d7be93533da240d4c8e55e69d79625 (patch)
tree:      2ab6a3445c6df34d60106f9ee2a4f2c5184f1a4b /tensorflow/contrib/mixed_precision
parent:    e1a49f30435096c7e0817dde2e472c85db143a81 (diff)
Introduce LossScalingOptimizer for mixed precision training.
PiperOrigin-RevId: 196597196
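The motivation for loss scaling (spelled out in the docstrings added below) is float16's narrow dynamic range: gradient terms that are representable in float32 can underflow to zero in float16. A minimal NumPy illustration, not part of the commit, that mirrors the `c1 * c2 * c3` construction used in `loss_scale_optimizer_test.py`:

```python
# Illustration only (not from the commit): small products underflow in float16,
# but pre-multiplying by a large scale keeps intermediate values representable.
import numpy as np

c1, c2, c3 = np.float16(1e4), np.float16(1e-4), np.float16(1e-4)

print(c2 * c3)       # 0.0   -> 1e-8 is below float16's smallest subnormal (~6e-8)
print(c1 * c2 * c3)  # ~1e-4 -> applying the large factor first avoids underflow
```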
Diffstat (limited to 'tensorflow/contrib/mixed_precision')
 tensorflow/contrib/mixed_precision/BUILD                               |  32
 tensorflow/contrib/mixed_precision/__init__.py                         |  34
 tensorflow/contrib/mixed_precision/python/BUILD                        |  74
 tensorflow/contrib/mixed_precision/python/loss_scale_manager.py        | 200
 tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py   | 182
 tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py      | 166
 tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py | 216
7 files changed, 904 insertions, 0 deletions
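For orientation before the diff itself: the docstring of `LossScaleOptimizer` in `loss_scale_optimizer.py` below describes the intended usage. A condensed sketch of that flow, assuming TF 1.x graph mode; `build_loss()` and the numeric hyperparameters are placeholders, not part of the commit:

```python
# Usage sketch based on the LossScaleOptimizer docstring below; build_loss()
# and the numbers are placeholders, not prescribed by the commit.
import tensorflow as tf

loss = build_loss()  # any scalar loss from a float16-compute / float32-variable model
opt = tf.train.AdamOptimizer(learning_rate=1e-3)

# Choose a loss-scale strategy: fixed, or exponentially adjusted based on
# whether recent gradients stayed finite.
loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
    init_loss_scale=2**15, incr_every_n_steps=2000)

# Wrap the regular optimizer. minimize() multiplies the loss by the scale,
# divides the gradients by it, skips the update when any gradient is NaN/Inf,
# and asks the manager to adjust the scale.
loss_scale_optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(
    opt, loss_scale_manager)
train_op = loss_scale_optimizer.minimize(loss)
```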
diff --git a/tensorflow/contrib/mixed_precision/BUILD b/tensorflow/contrib/mixed_precision/BUILD new file mode 100644 index 0000000000..3dfb95e0a0 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/BUILD @@ -0,0 +1,32 @@ +# Mixed precision training optimizers + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "mixed_precision", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/mixed_precision/python:loss_scale_manager", + "//tensorflow/contrib/mixed_precision/python:loss_scale_optimizer", + ], +) diff --git a/tensorflow/contrib/mixed_precision/__init__.py b/tensorflow/contrib/mixed_precision/__init__.py new file mode 100644 index 0000000000..43e98cdda0 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2018 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# mixed_precisiond under the License is mixed_precisiond on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for mixed precision training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.mixed_precision.python.loss_scale_manager import * +from tensorflow.contrib.mixed_precision.python.loss_scale_optimizer import * + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "LossScaleManager", + "FixedLossScaleManager", + "ExponentialUpdateLossScaleManager", + "LossScaleOptimizer", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/mixed_precision/python/BUILD b/tensorflow/contrib/mixed_precision/python/BUILD new file mode 100644 index 0000000000..1d769e1614 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/BUILD @@ -0,0 +1,74 @@ +# Mixed precision training optimizers + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "loss_scale_manager", + srcs = ["loss_scale_manager.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "loss_scale_manager_test", + size = "small", + srcs = ["loss_scale_manager_test.py"], + deps = [ + ":loss_scale_manager", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + 
"//tensorflow/python:platform_test", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "loss_scale_optimizer", + srcs = ["loss_scale_optimizer.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":loss_scale_manager", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + +py_test( + name = "loss_scale_optimizer_test", + size = "small", + srcs = ["loss_scale_optimizer_test.py"], + deps = [ + ":loss_scale_optimizer", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py new file mode 100644 index 0000000000..be7377b151 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py @@ -0,0 +1,200 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""LossScaleManager classes for mixed precision training.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import six + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_control_flow_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope + + +@six.add_metaclass(abc.ABCMeta) +class LossScaleManager(object): + """Abstract loss scale manager class. + + Loss scale managers with a different strategy should subclass this class. + Loss scaling is a process that: + + 1) Applies a multiplier on the loss before computing gradients, and + 2) Applies the reciprocal of the multiplier on the gradients before they are + applied on variables. + + This class is used together with + @{tf.contrib.mixed_precision.LossScaleOptimizer} for mixed precision training + (float32 variables and float16 ops) on Nvidia GPUs in order to achieve the + same model quality as single precision training, with the benefits of + potential higher throughput. + + See @{tf.contrib.mixed_precision.LossScaleOptimizer} for more details. 
+ """ + + @abc.abstractmethod + def get_loss_scale(self): + """Returns the loss scale as a scalar `float32` tensor.""" + pass + + @abc.abstractmethod + def update_loss_scale(self, finite_grads): + """Updates loss scale based on if gradients are finite in current step. + + Args: + finite_grads: bool scalar tensor indicating if all gradients are + finite (i.e., not inf or nan). + + Returns: + An op, when executed updates the loss scale. If eager execution is + enabled, does not return anything. + """ + del finite_grads + return + + +class FixedLossScaleManager(LossScaleManager): + """Loss scale manager with a fixed loss scale. + + The loss scale is not updated for the lifetime of the class. + """ + + def __init__(self, loss_scale): + """Creates the fixed loss scale manager. + + Args: + loss_scale: A Python float. Its ideal value varies depending on models to + run. Choosing a too small loss_scale might affect model quality; a too + big loss_scale might cause inf or nan. There is no single right + loss_scale to apply. There is no harm choosing a relatively big number + as long as no nan or inf is encountered in training. + + Raises: + ValueError: If loss_scale is less than 1. + """ + if loss_scale < 1: + raise ValueError("loss scale must be at least 1.") + self._loss_scale = ops.convert_to_tensor(loss_scale, dtype=dtypes.float32) + + def get_loss_scale(self): + return self._loss_scale + + def update_loss_scale(self, finite_grads): + del finite_grads + return gen_control_flow_ops.no_op() + + +class ExponentialUpdateLossScaleManager(LossScaleManager): + """Loss scale manager uses an exponential update strategy. + + In general, the strategy increases loss scale by a greater-than-one factor + after encountering a consecutive series of steps with finite gradients; + Similarly, it decreases the loss scale by a factor when the accumulated number + of steps with non-finite (nan or inf) gradients are met. An update is not + applied if its result is less than 1 or overflows the float32 dynamic range. + + The number of finite and non-finite steps are cleared every time the loss + scale is changed. The condition to decrease the loss scale is looser than to + increase it since the former does not require the steps to be consecutive. + """ + + def __init__(self, + init_loss_scale, + incr_every_n_steps, + decr_every_n_nan_or_inf=2, + incr_ratio=2, + decr_ratio=0.8): + """Constructor of exponential-update loss scale manager. + + Args: + init_loss_scale: A Python float. The loss scale to use at the beginning. + incr_every_n_steps: Increases loss scale every n consecutive steps with + finite gradients. + decr_every_n_nan_or_inf: Decreases loss scale every n accumulated steps + with nan or inf gradients. + incr_ratio: The multiplier to use when increasing the loss scale. + decr_ratio: The less-than-one-multiplier to use when decreasing the loss + scale. 
+ """ + self._incr_every_n_steps = incr_every_n_steps + self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf + self._incr_ratio = incr_ratio + self._decr_ratio = decr_ratio + self._loss_scale = variable_scope.variable( + name="loss_scale", + initial_value=ops.convert_to_tensor(init_loss_scale, dtypes.float32), + dtype=dtypes.float32, + trainable=False) + self._num_good_steps = variable_scope.variable( + name="good_steps", initial_value=0, dtype=dtypes.int32, trainable=False) + self._num_bad_steps = variable_scope.variable( + name="bad_steps", initial_value=0, dtype=dtypes.int32, trainable=False) + + def _reset_stats(self): + return control_flow_ops.group( + state_ops.assign(self._num_good_steps, 0), + state_ops.assign(self._num_bad_steps, 0)) + + def get_loss_scale(self): + """Returns the loss scale.""" + return self._loss_scale + + def update_loss_scale(self, finite_grads): + """Updates loss scale based on if gradients are finite in current step.""" + + def update_if_finite_grads(): + """Branch function when grads are all finite.""" + + def incr_loss_scale(): + new_loss_scale = control_flow_ops.cond( + gen_math_ops.is_finite(self._loss_scale * self._incr_ratio), + lambda: self._loss_scale * self._incr_ratio, + lambda: self._loss_scale) + update_op = state_ops.assign(self._loss_scale, new_loss_scale) + # When loss_scale is updated, both good and bad steps are reset. + return control_flow_ops.group(update_op, self._reset_stats()) + + return control_flow_ops.cond( + self._num_good_steps + 1 >= self._incr_every_n_steps, + incr_loss_scale, + lambda: state_ops.assign_add(self._num_good_steps, 1).op) + + def update_if_not_finite_grads(): + """Branch function when any grad is not finite.""" + + def decr_loss_scale(): + update_op = state_ops.assign( + self._loss_scale, + gen_math_ops.maximum(1., self._loss_scale * self._decr_ratio)) + # When loss_scale is updated, both good and bad steps are reset. + return control_flow_ops.group(update_op, self._reset_stats()) + + def just_update_steps(): + # When bad_steps is incremented, good_step is reset. + return control_flow_ops.group( + state_ops.assign_add(self._num_bad_steps, 1), + state_ops.assign(self._num_good_steps, 0)) + + return control_flow_ops.cond( + self._num_bad_steps + 1 >= self._decr_every_n_nan_or_inf, + decr_loss_scale, just_update_steps) + + return control_flow_ops.cond(finite_grads, update_if_finite_grads, + update_if_not_finite_grads) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py new file mode 100644 index 0000000000..480f5f6eaf --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -0,0 +1,182 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for LossScaleManager classes..""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def _GetExampleIter(inputs): + dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + return dataset.make_one_shot_iterator() + + +class FixedLossScaleManagerTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_basic(self): + itr = _GetExampleIter([True] * 10 + [False] * 10) + + loss_scale = 1000 + lsm = lsm_lib.FixedLossScaleManager(loss_scale) + update_fn = lambda: lsm.update_loss_scale(itr.get_next()) + + self.evaluate(variables.global_variables_initializer()) + if not context.executing_eagerly(): + update_op = update_fn() + for _ in range(10): + if context.executing_eagerly(): + update_fn() + else: + self.evaluate(update_op) + self.assertEqual(loss_scale, self.evaluate(lsm.get_loss_scale())) + + +class ExponentialUpdateLossScaleManagerTest(test.TestCase): + + def _test_helper(self, + inputs, + expected_outputs, + init_loss_scale=1, + incr_every_n_step=2, + decr_every_n_nan_or_inf=2): + ratio = 2 + lsm = lsm_lib.ExponentialUpdateLossScaleManager( + init_loss_scale=init_loss_scale, + incr_every_n_steps=incr_every_n_step, + decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, + incr_ratio=ratio, + decr_ratio=1. / ratio) + itr = _GetExampleIter(inputs) + update_fn = lambda: lsm.update_loss_scale(itr.get_next()) + + self.evaluate(variables.global_variables_initializer()) + actual_outputs = [] + + if not context.executing_eagerly(): + update_op = update_fn() + for _ in range(len(inputs)): + if context.executing_eagerly(): + update_fn() + else: + self.evaluate(update_op) + actual_outputs.append(self.evaluate(lsm.get_loss_scale())) + self.assertEqual(actual_outputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_increase_every_n_steps(self): + inputs = [True] * 6 + expected_outputs = [1, 2, 2, 4, 4, 8] + self._test_helper(inputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_keep_increasing_until_capped(self): + init_loss_scale = np.finfo(np.float32).max / 4 + 10 + max_float = np.finfo(np.float32).max + + inputs = [True] * 6 + # Output is capped the 2nd time it doubles. 
+ expected_outputs = [ + init_loss_scale, init_loss_scale * 2, init_loss_scale * 2, max_float, + max_float, max_float + ] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @test_util.run_in_graph_and_eager_modes() + def test_decrease_every_n_steps(self): + inputs = [False] * 6 + init_loss_scale = 1024 + expected_outputs = [1024, 512, 512, 256, 256, 128] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @test_util.run_in_graph_and_eager_modes() + def test_keep_decreasing_until_one(self): + inputs = [False] * 10 + init_loss_scale = 16 + expected_outputs = [16, 8, 8, 4, 4, 2, 2, 1, 1, 1] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @test_util.run_in_graph_and_eager_modes() + def test_incr_bad_step_clear_good_step(self): + inputs = [True, True, True, False, True] + expected_outputs = [1, 2, 2, 2, 2] + self._test_helper(inputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_incr_good_step_does_not_clear_bad_step(self): + inputs = [True, True, True, False, True, False] + expected_outputs = [1, 2, 2, 2, 2, 1] + self._test_helper(inputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_trigger_loss_scale_update_each_step(self): + """Test when incr_every_n_step and decr_every_n_nan_or_inf is 1.""" + init_loss_scale = 1 + incr_every_n_step = 1 + decr_every_n_nan_or_inf = 1 + + inputs = [True] * 3 + [False, True, True] + expected_outputs = [2, 4, 8, 4, 8, 16] + + self._test_helper(inputs, expected_outputs, init_loss_scale, + incr_every_n_step, decr_every_n_nan_or_inf) + + @test_util.run_in_graph_and_eager_modes() + def test_alternating_good_and_bad_gradients_trigger_each_step(self): + init_loss_scale = 1 + incr_every_n_step = 1 + decr_every_n_nan_or_inf = 1 + + inputs = [True, False] * 4 + [True] + expected_outputs = [2, 1, 2, 1, 2, 1, 2, 1, 2] + self._test_helper(inputs, expected_outputs, init_loss_scale, + incr_every_n_step, decr_every_n_nan_or_inf) + + @test_util.run_in_graph_and_eager_modes() + def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self): + init_loss_scale = 32 + incr_every_n_step = 2 + decr_every_n_nan_or_inf = 1 + + inputs = [True, False] * 3 + [True] + expected_outputs = [32, 16, 16, 8, 8, 4, 4] + self._test_helper(inputs, expected_outputs, init_loss_scale, + incr_every_n_step, decr_every_n_nan_or_inf) + + @test_util.run_in_graph_and_eager_modes() + def test_random_mix_good_and_bad_gradients(self): + init_loss_scale = 4 + inputs = [ + False, False, True, True, True, False, True, False, True, True, True, + False + ] + expected_outputs = [4, 2, 2, 4, 4, 4, 4, 2, 2, 4, 4, 4] + self._test_helper(inputs, expected_outputs, init_loss_scale) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py new file mode 100644 index 0000000000..e4e5ccc334 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -0,0 +1,166 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss scaling optimizer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_control_flow_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import optimizer + + +class LossScaleOptimizer(optimizer.Optimizer): + """An optimizer that applies loss scaling in backprop. + + This class is useful for mixed precision training on GPUs (or other potential + accelerators), which is an approach to improve compute throughput without loss + of model quality. + + The commmon configuration of mixed precision models is the following: + * variables are kept in high precision (e.g. float32). + * computations are done in lower precision (e.g. float16). variables are + casted to lower precision before they're used. + * (in training), final gradients are casted back to variable precision and get + applied. + + Because computations happen in lower precision, gradients in the backprop pass + might underflow in the smaller dynamic range, causing a model to converge at a + suboptimal level. This optimizer multiplies the loss by a factor before + backprop starts to prevent underflow. Before gradients are applied, they are + casted to higher precision and down-scaled by the same factor, so + mathematically the variable updates are no different from regular + same-precision training. + + See [Nvidia's manual on mixed precision training]( + https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) + for more details. + + To use loss scale optimizer, one only needs choose a loss scale strategy and + wrap a regular optimizer. See examples below. + + ``` + loss = loss_fn() + opt = tf.AdamOptimizer(learning_rate=...) + + # Choose a loss scale manager which decides how to pick the right loss scale + # throughout the training process. + loss_scale_manger = tf.contrib.mixed_precision.FixedLossScaleManager(5000) + + # Wraps the original optimizer in a LossScaleOptimizer. + loss_scale_optimizer = LossScaleOptimizer(opt, loss_scale_manager) + + # Call minimize() on the loss scale optimizer. + train_op = loss_scale_optimizer.minimize(loss) + ``` + + If gradients clipping is applied, one can call + `optimizer.compute_gradients()` and `optimizer.apply_gradients()` + seperately. + + Notice the following way of using LossScaleOptimizer is not intended. Always + use `loss_scale_optimizer.compute_gradients()` to compute gradients instead of + `tf.gradients()` if doing mixed precision training. + + ``` + # The following is a wrong way to use LossScaleOptimizer along with + # tf.gradients(). + + # Always use loss_scale_optimizer.compute_gradients() to compute grads, or + # loss scale is not correctly applied. + grads = tf.gradients(loss, ...) + + # Do some custom grad clipping. + grads = clip_grads(grads, ...) 
+ + loss_scale_optimizer.apply(grads_and_vars) + ``` + """ + + def __init__(self, opt, loss_scale_manager): + """Construct a loss scaling optimizer. + + Args: + opt: The actual optimizer that will be used to compute and apply the + gradients. Must be an implementation of the @{tf.train.Optimizer} + interface. + loss_scale_manager: A LossScaleManager object. + """ + self._opt = opt + self._loss_scale_manager = loss_scale_manager + + def compute_gradients(self, + loss, + var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + """Compute gradients. See base class @{tf.train.Optimizer}.""" + loss_scale = self._loss_scale_manager.get_loss_scale() + if context.executing_eagerly(): + + def scaled_loss(): + loss_val = loss() + return loss_val * math_ops.cast(loss_scale, loss_val.dtype.base_dtype) + else: + if callable(loss): + loss_val = loss() + else: + loss_val = loss + scaled_loss = loss_val * math_ops.cast(loss_scale, + loss_val.dtype.base_dtype) + grads_and_vars = self._opt.compute_gradients( + scaled_loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + return self._down_scale(grads_and_vars, loss_scale) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """Apply gradients. See base class @{tf.train.Optimizer}.""" + grads = [g for (g, _) in grads_and_vars] + + is_finite_grad = [] + for g in grads: + is_finite_grad.append(math_ops.reduce_all(gen_math_ops.is_finite(g))) + is_overall_finite = math_ops.reduce_all(is_finite_grad) + + # Only update gradients when all grads are finite. + def true_apply_gradients_fn(): + return self._opt.apply_gradients(grads_and_vars, global_step, name) + + update_vars = control_flow_ops.cond( + is_overall_finite, true_apply_gradients_fn, gen_control_flow_ops.no_op) + # Potentially adjust gradient scale in case of finite gradients. + return control_flow_ops.group( + update_vars, + self._loss_scale_manager.update_loss_scale(is_overall_finite)) + + def _down_scale(self, grads_vars, loss_scale): + # Down scale grads by the loss_scale. + gv = [] + inv_loss_scale = gen_math_ops.reciprocal(loss_scale) + for g, v in grads_vars: + if g is not None: + gv.append((g * math_ops.cast(inv_loss_scale, g.dtype.base_dtype), v)) + else: + gv.append((g, v)) + return gv diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py new file mode 100644 index 0000000000..dded61ccd5 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -0,0 +1,216 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for LossScaleOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib +from tensorflow.contrib.mixed_precision.python import loss_scale_optimizer as lso +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent as gd + + +class LossScaleOptimizerTest(test.TestCase): + + def _build_graph(self, lr, init_val, loss_scale_opt_fn=None): + x = variable_scope.get_variable( + "x", initializer=init_val, dtype=dtypes.float32) + c1 = constant_op.constant(1e4, dtype=dtypes.float16) + c2 = constant_op.constant(1e-4, dtype=dtypes.float16) + c3 = constant_op.constant(1e-4, dtype=dtypes.float16) + if context.executing_eagerly(): + loss = lambda: math_ops.cast(x, dtypes.float16) * c1 * c2 * c3 + else: + loss = math_ops.cast(x, dtypes.float16) * c1 * c2 * c3 + + opt = gd.GradientDescentOptimizer(lr) + if loss_scale_opt_fn: + opt = loss_scale_opt_fn(opt) + return x, loss, opt + + @test_util.run_in_graph_and_eager_modes() + def test_float16_underflow_without_loss_scale(self): + lr = 1 + init_val = 1. + x, loss, opt = self._build_graph(lr, init_val) + + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt.minimize(loss, var_list=[x])) + + # Symbolic grad is c1 * c2 * c3 = 1e-4 and actual grad is 0, since in + # backprop, c2 * c3 underflows in fp16 range. So variable isn't updated. + expected_update = 0 + symbolic_update = 1e-4 * lr + self.assertAllClose( + init_val - expected_update, + self.evaluate(x), + rtol=0, + atol=min(symbolic_update, 1e-6)) + + @test_util.run_in_graph_and_eager_modes() + def test_float16_with_loss_scale(self): + lr = 1. + init_val = 1. + + def loss_scale_opt_fn(opt): + return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4)) + + x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn) + + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt.minimize(loss, var_list=[x])) + + # Symbolic grad is c1 * c2 * c3 = 1e-4 and actual grad is the same, due to + # up-scaled loss before backprop starts. + expected_update = 1.e-4 * lr + self.assertAllClose( + init_val - expected_update, + self.evaluate(x), + rtol=0, + atol=min(expected_update, 1e-6)) + + @test_util.run_in_graph_and_eager_modes() + def test_compute_gradients_with_loss_scale(self): + lr = 1 + init_val = 1. 
+ + def loss_scale_opt_fn(opt): + return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4)) + + x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn) + grads_and_vars = opt.compute_gradients(loss, var_list=[x]) + + self.assertEqual(len(grads_and_vars), 1) + + self.evaluate(variables.global_variables_initializer()) + g_v = self.evaluate(grads_and_vars[0][0]) + self.assertAllClose(g_v, 1e-4) + self.assertIs(grads_and_vars[0][1], x) + # Gradients aren't applied. + self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6) + + @test_util.run_in_graph_and_eager_modes() + def test_compute_gradients_without_loss_scale(self): + lr = 1 + init_val = 1. + x, loss, opt = self._build_graph(lr, init_val) + grads_and_vars = opt.compute_gradients(loss, var_list=[x]) + + self.assertEqual(len(grads_and_vars), 1) + self.evaluate(variables.global_variables_initializer()) + g_v = self.evaluate(grads_and_vars[0][0]) + self.assertAllClose(g_v, 0) + + @test_util.run_in_graph_and_eager_modes() + def test_apply_gradients(self): + + x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) + dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) + itr = dataset.make_one_shot_iterator() + + lr = 1 + opt = gd.GradientDescentOptimizer(lr) + lsm = lsm_lib.FixedLossScaleManager(1.e4) + opt = lso.LossScaleOptimizer(opt, lsm) + train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)]) + if not context.executing_eagerly(): + train_op = train_fn() + + expected_output = [1, 1, 1 - 0.1] + actual_output = [] + + self.evaluate(variables.global_variables_initializer()) + for _ in range(3): + # nan or inf is not applied. + if context.executing_eagerly(): + train_fn() + else: + self.evaluate(train_op) + actual_output.append(self.evaluate(x)) + self.assertAllClose(expected_output, actual_output) + + @test_util.run_in_graph_and_eager_modes() + def test_apply_gradients_loss_scale_is_updated(self): + + class SimpleLossScaleManager(lsm_lib.LossScaleManager): + """A simple loss scale manager for easier testing. + + It increments loss scale by 1 if grads are finite, and decreases loss + scale by 1 if otherwise. + """ + + def __init__(self, loss_scale): + self._loss_scale = variable_scope.variable( + name="loss_scale", + initial_value=loss_scale, + dtype=dtypes.float32, + trainable=False) + + def get_loss_scale(self): + return self._loss_scale + + def update_loss_scale(self, if_finite_grads): + return control_flow_ops.cond( + if_finite_grads, lambda: state_ops.assign_add(self._loss_scale, 1), + lambda: state_ops.assign_sub(self._loss_scale, 1)) + + x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) + dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) + itr = dataset.make_one_shot_iterator() + + lr = 1 + init_loss_scale = 8 + opt = gd.GradientDescentOptimizer(lr) + lsm = SimpleLossScaleManager(init_loss_scale) + opt = lso.LossScaleOptimizer(opt, lsm) + train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)]) + if not context.executing_eagerly(): + train_op = train_fn() + + self.evaluate(variables.global_variables_initializer()) + + expected_loss_scale = [ + init_loss_scale - 1, init_loss_scale - 2, init_loss_scale - 2 + 1 + ] + expected_output = [1, 1, 1 - 0.1] + actual_output = [] + for i in range(3): + # nan or inf is not applied. 
+ if context.executing_eagerly(): + train_fn() + else: + self.evaluate(train_op) + actual_output.append(self.evaluate(x)) + self.assertAllClose(expected_loss_scale[i], + self.evaluate(lsm._loss_scale)) + self.assertAllClose(expected_output, actual_output) + + +if __name__ == "__main__": + test.main() |
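The `loss_scale_optimizer.py` docstring above also notes that `compute_gradients()` and `apply_gradients()` can be called separately when custom gradient processing such as clipping is needed, as long as `tf.gradients()` is not used directly. A sketch of that pattern; the loss, learning rate, loss scale, and clip norm are placeholders:

```python
# Split compute/apply pattern described in the docstring; build_loss() and the
# numeric values are placeholders, not prescribed by the commit.
import tensorflow as tf

loss = build_loss()
opt = tf.contrib.mixed_precision.LossScaleOptimizer(
    tf.train.GradientDescentOptimizer(0.1),
    tf.contrib.mixed_precision.FixedLossScaleManager(loss_scale=128))

# compute_gradients() scales the loss before backprop and unscales the
# resulting gradients, so clipping operates on gradients at their true scale.
grads_and_vars = opt.compute_gradients(loss)
clipped = [(tf.clip_by_norm(g, 1.0) if g is not None else g, v)
           for g, v in grads_and_vars]

# apply_gradients() applies the update only if every gradient is finite and
# then lets the manager adjust the loss scale.
train_op = opt.apply_gradients(clipped)
```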