aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt
diff options
context:
space:
mode:
authorGravatar Xin Jin <jinxin900924@gmail.com>2018-09-04 20:44:43 +0800
committerGravatar GitHub <noreply@github.com>2018-09-04 20:44:43 +0800
commitce035c2493c060b38e53ca7a63c66b26e265b210 (patch)
tree85af4fc680847ffbe06037b429a013ade32c4ce4 /tensorflow/contrib/opt
parent16c42f0d4826b12a5359281997ee3f8e27fd5a87 (diff)
parent1c3d02eb3594e9d92cd26562e797142ee34505b2 (diff)
Merge branch 'master' into ma_easgd
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--tensorflow/contrib/opt/BUILD95
-rw-r--r--tensorflow/contrib/opt/__init__.py19
-rw-r--r--tensorflow/contrib/opt/python/training/adamax_test.py18
-rw-r--r--tensorflow/contrib/opt/python/training/addsign_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py159
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py129
-rw-r--r--tensorflow/contrib/opt/python/training/external_optimizer_test.py18
-rw-r--r--tensorflow/contrib/opt/python/training/ggt.py312
-rw-r--r--tensorflow/contrib/opt/python/training/ggt_test.py183
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer.py164
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer_test.py127
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions.py155
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions_test.py63
-rw-r--r--tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/nadam_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/powersign.py2
-rw-r--r--tensorflow/contrib/opt/python/training/powersign_test.py2
-rw-r--r--tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py22
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo.py420
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py772
-rw-r--r--tensorflow/contrib/opt/python/training/sign_decay_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers.py435
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py188
26 files changed, 3215 insertions, 102 deletions
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 13aa1d7e7a..93e589907e 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -19,24 +19,32 @@ py_library(
"python/training/drop_stale_gradient_optimizer.py",
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
+ "python/training/ggt.py",
+ "python/training/lars_optimizer.py",
"python/training/lazy_adam_optimizer.py",
+ "python/training/matrix_functions.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
"python/training/multitask_optimizer_wrapper.py",
"python/training/nadam_optimizer.py",
"python/training/powersign.py",
"python/training/reg_adagrad_optimizer.py",
+ "python/training/shampoo.py",
"python/training/sign_decay.py",
"python/training/variable_clipping_optimizer.py",
+ "python/training/weight_decay_optimizers.py",
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
@@ -194,6 +202,25 @@ py_test(
],
)
+py_test(
+ name = "weight_decay_optimizers_test",
+ srcs = ["python/training/weight_decay_optimizers_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
tf_py_test(
name = "drop_stale_gradient_optimizer_test",
srcs = ["python/training/drop_stale_gradient_optimizer_test.py"],
@@ -302,3 +329,71 @@ py_test(
"//third_party/py/numpy",
],
)
+
+py_test(
+ name = "ggt_test",
+ srcs = ["python/training/ggt_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "shampoo_test",
+ size = "large",
+ srcs = ["python/training/shampoo_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "lars_optimizer_test",
+ srcs = ["python/training/lars_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "matrix_functions_test",
+ srcs = ["python/training/matrix_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 4c13c8e247..ad7d7cfa6e 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -22,15 +22,20 @@ from __future__ import print_function
from tensorflow.contrib.opt.python.training.adamax import *
from tensorflow.contrib.opt.python.training.addsign import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
+from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import *
+from tensorflow.contrib.opt.python.training.lars_optimizer import *
+from tensorflow.contrib.opt.python.training.ggt import *
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
+from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
+from tensorflow.contrib.opt.python.training.reg_adagrad_optimizer import *
+from tensorflow.contrib.opt.python.training.shampoo import *
+from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
-from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
-from tensorflow.contrib.opt.python.training.model_average_optimizer import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -43,9 +48,14 @@ _allowed_symbols = [
'DelayCompensatedGradientDescentOptimizer',
'DropStaleGradientOptimizer',
'ExternalOptimizerInterface',
+ 'LARSOptimizer',
'LazyAdamOptimizer',
'NadamOptimizer',
'MovingAverageOptimizer',
+ 'MomentumWOptimizer',
+ 'AdamWOptimizer',
+ 'DecoupledWeightDecayExtension',
+ 'extend_with_decoupled_weight_decay',
'ScipyOptimizerInterface',
'VariableClippingOptimizer',
'MultitaskOptimizerWrapper',
@@ -53,7 +63,10 @@ _allowed_symbols = [
'ElasticAverageOptimizer',
'ElasticAverageCustomGetter',
'ModelAverageOptimizer',
- 'ModelAverageCustomGetter'
+ 'ModelAverageCustomGetter',
+ 'GGTOptimizer',
+ 'ShampooOptimizer',
+ 'RegAdagradOptimizer',
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py
index 21bf3f5313..61d8b94eca 100644
--- a/tensorflow/contrib/opt/python/training/adamax_test.py
+++ b/tensorflow/contrib/opt/python/training/adamax_test.py
@@ -74,7 +74,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype)
m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
@@ -142,7 +142,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -172,7 +172,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -224,14 +224,16 @@ class AdaMaxOptimizerTest(test.TestCase):
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params
- self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
- self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0),
+ rtol=1e-2)
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1),
+ rtol=1e-2)
if use_resource:
self.assertEqual("var0_%d/AdaMax:0" % (i,),
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -240,7 +242,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -276,7 +278,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py
index 08d45ed73f..628a735e72 100644
--- a/tensorflow/contrib/opt/python/training/addsign_test.py
+++ b/tensorflow/contrib/opt/python/training/addsign_test.py
@@ -214,7 +214,7 @@ class AddSignTest(test.TestCase):
# Run 7 steps of AddSign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
- for t in range(1, 4):
+ for t in range(1, 8):
if t < 5:
update.run()
else:
@@ -222,7 +222,7 @@ class AddSignTest(test.TestCase):
var0_np, m0 = addsign_update_numpy(
var0_np,
- grads0_np,
+ grads0_np if t < 5 else -grads0_np,
m0,
learning_rate,
alpha=alpha,
@@ -232,7 +232,7 @@ class AddSignTest(test.TestCase):
)
var1_np, m1 = addsign_update_numpy(
var1_np,
- grads1_np,
+ grads1_np if t < 5 else -grads1_np,
m1,
learning_rate,
alpha=alpha,
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index 209c4611f3..6c203e5519 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -17,22 +17,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-
-from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
+from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import data_flow_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import constant_op
LOCAL_VARIABLE_NAME = 'local_center_variable'
GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GLOBAL_STEP = 'global_step'
class ElasticAverageCustomGetter(object):
@@ -52,16 +53,32 @@ class ElasticAverageCustomGetter(object):
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
- ps_device="/job:ps/cpu:0",
+ ps_device="/job:ps",
cluster=cluster)),
tf.variable_scope('',custom_getter=ea_custom_getter):
- hid_w = tf.get_variable(
- initializer=tf.truncated_normal(
- [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
- stddev=1.0 / IMAGE_PIXELS),
- name="hid_w")
- hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
- name="hid_b")
+ ...
+ create your model here
+ ...
+ with tf.device(worker_device):
+ opt = tf.train.MomentumOptimizer(...)
+ optimizer = ElasticAverageOptimizer(
+ opt,
+ num_worker=2,
+ moving_rate=0.01, # or use default value
+ communication_period=20,
+ ea_custom_getter=ea_custom_getter)
+ ...
+ train_op = optimizer.apply_gradients(
+ grads_vars,
+ global_step=global_step)
+ ...
+ hooks = [optimizer.make_session_run_hook(is_chief, task_index)]
+ ...
+ with tf.train.MonitoredTrainingSession(master=server.target,
+ is_chief=is_chief,
+ checkpoint_dir=("...),
+ save_checkpoint_secs=600,
+ hooks=hooks) as mon_sess:
"""
def __init__(self, worker_device):
@@ -83,21 +100,32 @@ class ElasticAverageCustomGetter(object):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
*args,
**kwargs)
- global_center_variable = variable_scope.variable(
+ if kwargs['reuse'] == True:
+ return local_var
+ global_center_variable = getter(
name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ *args,
+ **kwargs)
with ops.device(self._worker_device):
- local_center_variable = variable_scope.variable(
+ local_center_variable = getter(
name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
-
- self._local_map[local_var] = local_center_variable
- self._global_map[local_var] = global_center_variable
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['partitioner'] is None:
+ self._local_map[local_var] = local_center_variable
+ self._global_map[local_var] = global_center_variable
+ else:
+ v_list = list(local_var)
+ for i in range(len(v_list)):
+ self._local_map[v_list[i]] \
+ = list(local_center_variable)[i]
+ self._global_map[v_list[i]] \
+ = list(global_center_variable)[i]
return local_var
else:
kwargs['trainable'] = trainable
@@ -132,6 +160,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
moving_rate=None,
rho=None,
use_locking=True,
+ synchronous=False,
name='ElasticAverageOptimizer'):
"""Construct a new gradient descent optimizer.
@@ -143,9 +172,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
communication_period: An int point value to controls the frequency
of the communication between every worker and the ps.
moving_rate: A floating point value to control the elastic difference.
- rho: the amount of exploration we allow ine the model. The default
+ rho: the amount of exploration we allow in the model. The default
value is moving_rate/learning_rate
+ rho=0.0 is suggested in async mode.
use_locking: If True use locks for update operations.
+ synchronous: Add_sync_queues_and_barrier or not.
+ True: all workers will wait for each other before start training
+ False: worker can start training when its initilization is done,
+ no need to wait for everyone is ready.
+ in case one worker is restarted, it can join and continue
+ training without being blocked.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "ElasticAverageOptimizer".
"""
@@ -155,6 +191,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
self._period = communication_period
self._local_map = ea_custom_getter._local_map
self._global_map = ea_custom_getter._global_map
+ self._synchronous = synchronous
if moving_rate is None:
self._moving_rate = self.BETA / communication_period / num_worker
@@ -248,11 +285,29 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
+ global_old = set(n.op.name for n in variables.global_variables())
apply_updates = self._opt.apply_gradients(grads_and_vars)
+ global_new = set(n.op.name for n in variables.global_variables())
with ops.control_dependencies([apply_updates]):
local_update = state_ops.assign_add(
self._local_step, 1, name='local_step_update').op
+ # this is for place the variables created by optimizer to local collection
+ # e.g., AdamOptimizer will create beta as global variables
+ def _adjust_optimizer_variable_collection(opt_vars):
+ g = ops.get_default_graph()
+ idx = 0
+ for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])):
+ var = g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx]
+ name = var.op.name
+ if name in opt_vars:
+ ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
+ del g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx]
+ else:
+ idx += 1
+
+ _adjust_optimizer_variable_collection(global_new - global_old)
+
# update global variables.
def _Update_global_variables():
local_vars = [v for g, v in grads_and_vars if g is not None]
@@ -297,7 +352,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
variables equal to the global center variables before the training begins"""
def _Add_sync_queues_and_barrier(enqueue_after_list):
- """Adds ops to enqueu on all worker queues"""
+ """Adds ops to enqueue on all worker queues"""
sync_queues = [
data_flow_ops.FIFOQueue(
self._num_worker, [dtypes.bool],
@@ -331,6 +386,9 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
init_ops.append(state_ops.assign(lc_var, gc_var))
init_op = control_flow_ops.group(*(init_ops))
+ if self._synchronous == False:
+ return init_op
+
sync_queue_op = _Add_sync_queues_and_barrier([init_op])
return sync_queue_op
@@ -338,6 +396,51 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
"""Creates a hook to handle ElasticAverageOptimizerHook ops such as initialization."""
return _ElasticAverageOptimizerHook(self, is_chief, task_index)
+ def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
+ """Create a saver copy global_center_variable to trainable variables
+ Please call this function after all your variables created with
+ ElasticAverageCustomGetter. For evaluations or inference, use this saver
+ during training. It will save the global_center_variable of the trained
+ parameters under the original parameter names.
+ Args:
+ var_list: List of variables to save, as per `Saver()`.
+ If set to None, save all the trainable_variables that have
+ been created before this call.
+ name: The name of the saver.
+ **kwargs: Keyword arguments of `Saver()`.
+ Returns:
+ A `tf.train.Saver` object.
+ Raises:
+ RuntimeError: global_center_variable is empty, please make sure
+ this is called after model created and
+ ElasticAverageCustomGetter is used when declaring you model
+ """
+ if not self._global_map:
+ raise RuntimeError('global_center_variable is empty, please make sure '
+ 'this is called after model created and '
+ 'ElasticAverageCustomGetter is used when declaring '
+ 'you model')
+
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ if not isinstance(var_list, dict):
+ var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
+
+ swapped_var_list = {}
+ for key, var in var_list.items():
+ tensor = var
+
+ if not isinstance(var, list):
+ for tvar in variables.trainable_variables():
+ if tvar.op.name == var.op.name:
+ tensor = self._global_map.get(tvar, var)
+ break
+ else: #partitioned variable
+ tensor = [self._global_map.get(lvar, lvar) for lvar in var]
+
+ swapped_var_list[key] = tensor
+
+ return saver.Saver(swapped_var_list, name=name, **kwargs)
class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
@@ -358,3 +461,7 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
if self._is_chief:
self._global_init_op = variables.global_variables_initializer()
self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index)
+
+ def after_create_session(self, session, coord):
+ """Run initialization ops"""
+ session.run(self._variable_init_op) \ No newline at end of file
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 9d57dc08f6..5bf6a08de1 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -17,17 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import portpicker
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import device_setter
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import \
ElasticAverageOptimizer, ElasticAverageCustomGetter, GLOBAL_VARIABLE_NAME
@@ -59,42 +64,72 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
# Creates the workers and return their sessions, graphs, train_ops.
# Chief worker will update at last
-def _get_workers(num_workers, period, workers, moving_rate):
+def _get_workers(num_workers, period, workers, moving_rate, num_ps=1):
sessions = []
graphs = []
train_ops = []
+ savers = []
for worker_id in range(num_workers):
graph = ops.Graph()
is_chief = (worker_id == 0)
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
- ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device)
+ ea_custom = ElasticAverageCustomGetter(worker_device=worker_device)
with variable_scope.variable_scope(
- "", custom_getter=ea_coustom), ops.device(
+ "", custom_getter=ea_custom), ops.device(
device_setter.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/task:0/cpu:0",
ps_tasks=1)):
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ global_step = training_util.get_or_create_global_step()
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
-
- with ops.device("/job:worker/task:" + str(worker_id)):
- grads_0 = constant_op.constant(-1.0)
- grads_1 = constant_op.constant(-1.0)
-
- sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
- opt = ElasticAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- moving_rate=moving_rate,
- communication_period=period,
- ea_custom_getter=ea_coustom)
+ if num_ps > 1:
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0),
+ custom_getter=ea_custom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=num_ps)):
+
+ partition_var = variable_scope.get_variable(
+ 'partition_var',
+ shape=[2, 4],
+ initializer=init_ops.ones_initializer)
+ part_0 = list(partition_var)[0]
+ part_1 = list(partition_var)[1]
+
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
+ grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])
+
+ sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ opt = ElasticAverageOptimizer(
+ opt=sgd_opt,
+ num_worker=num_workers,
+ moving_rate=moving_rate,
+ communication_period=period,
+ ea_custom_getter=ea_custom)
+ if num_ps == 1:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ global_step)
+ ]
+ else:
train_op = [
- opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ opt.apply_gradients(([grads_0, var_0],
+ [grads_1, var_1],
+ [grads_part_0, part_0],
+ [grads_part_1, part_1]),
global_step)
]
easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
+ saver = opt.swapping_saver()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
workers[worker_id].target, hooks=[easgd_hook])
@@ -102,8 +137,9 @@ def _get_workers(num_workers, period, workers, moving_rate):
sessions.append(sess)
graphs.append(graph)
train_ops.append(train_op)
+ savers.append(saver)
- return sessions, graphs, train_ops
+ return sessions, graphs, train_ops, savers
class ElasticAverageOptimizerTest(test.TestCase):
@@ -118,7 +154,7 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
+ sessions, graphs, train_ops, savers = _get_workers(
num_workers, communication_period, workers, 1.0)
var_0 = graphs[0].get_tensor_by_name("v0:0")
@@ -158,6 +194,21 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(2.0, sessions[0].run(var_0_g))
self.assertAllEqual(3.0, sessions[0].run(var_1_g))
self.assertAllEqual(1, sessions[0].run(global_step))
+ sessions[0].run(train_ops[0])
+
+ # save, data will be global value
+ outfile = os.path.join(test.get_temp_dir(), "model")
+ savers[0].save(sessions[0]._sess._sess._sess._sess,
+ save_path=outfile)
+ ops.reset_default_graph() # restore on a new graph
+ with session.Session() as sess:
+ v0 = variable_scope.get_variable(initializer=0.0, name="v0")
+ v1 = variable_scope.get_variable(initializer=1.0, name="v1")
+ sess.run(variables.local_variables_initializer())
+ saver_opt = saver.Saver(var_list=[v1, v0])
+ saver_opt.restore(sess, outfile)
+ self.assertAllEqual(2.0, sess.run(v0))
+ self.assertAllEqual(3.0, sess.run(v1))
def test2Worker1Period(self):
num_workers = 2
@@ -166,8 +217,8 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
- num_workers, communication_period, workers, 0.5)
+ sessions, graphs, train_ops, savers = _get_workers(
+ num_workers, communication_period, workers, 0.5, num_ps=2)
var_0 = graphs[0].get_tensor_by_name("v0:0")
var_1 = graphs[0].get_tensor_by_name("v1:0")
@@ -177,6 +228,9 @@ class ElasticAverageOptimizerTest(test.TestCase):
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
+ part_0_g = graphs[0].get_tensor_by_name(
+ GLOBAL_VARIABLE_NAME + "/partition_var/part_0:0")
+
# Verify the initialized value.
self.assertAllEqual(0.0, sessions[0].run(var_0))
self.assertAllEqual(1.0, sessions[0].run(var_1))
@@ -194,22 +248,45 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(1.75, sessions[0].run(var_1_g))
self.assertAllEqual(0.75, sessions[1].run(var_0_1))
self.assertAllEqual(1.75, sessions[1].run(var_1_1))
+ # part_0 of global_center copy
+ part_0_g = sessions[0].run(part_0_g)
+
+ outfile = os.path.join(test.get_temp_dir(), "model")
+ savers[0].save(sessions[0]._sess._sess._sess._sess,
+ save_path=outfile)
+
+ # verify restore of partitioned_variables
+ ops.reset_default_graph() # restore on a new graph
+ g = ops.get_default_graph()
+ with session.Session() as sess, g.as_default():
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0)):
+ partition_var = variable_scope.get_variable(
+ 'partition_var',
+ shape=[2, 4],
+ initializer=init_ops.ones_initializer)
+ s = saver.Saver(var_list=[partition_var])
+ s.restore(sess, outfile)
+ part_0 = g.get_tensor_by_name('partition_var/part_0:0')
+ self.assertAllEqual(part_0_g, sess.run(part_0))
def testPS2TasksWithClusterSpecClass(self):
cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
- ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
+ ea_custom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
from tensorflow.python.training import device_setter
with ops.device(
device_setter.replica_device_setter(cluster=cluster_spec,
worker_device="/job:worker/task:0",
ps_device="/job:ps")), \
- variable_scope.variable_scope("", custom_getter=ea_coustom):
+ variable_scope.variable_scope("", custom_getter=ea_custom):
v = variable_scope.get_variable(initializer=[1, 2], name="v")
w = variable_scope.get_variable(initializer=[2, 1], name="w")
- v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w]
+ v_g, w_g = ea_custom._global_map[v], ea_custom._global_map[w]
self.assertDeviceEqual("/job:worker/task:0", v.device)
self.assertDeviceEqual("job:ps/task:0", v_g.device)
self.assertDeviceEqual("/job:worker/task:0", w.device)
diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
index 953586ee70..9997103016 100644
--- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
@@ -85,7 +85,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
optimizer = MockOptimizerInterface(loss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -107,7 +107,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
optimizer = MockOptimizerInterface(loss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
initial_vector_val = sess.run(vector)
@@ -164,7 +164,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
self._objective(x), method=method, options=options)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -176,7 +176,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
x = variables.Variable(array_ops.zeros(dimension))
optimizer = external_optimizer.ScipyOptimizerInterface(self._objective(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -242,7 +242,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, equalities=equalities, inequalities=inequalities, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose(np.ones(2), sess.run(vector))
@@ -260,7 +260,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, var_to_bounds=var_to_bounds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose(np.ones(2), sess.run(vector))
@@ -277,7 +277,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, var_to_bounds=var_to_bounds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose([0., 2.], sess.run(vector))
@@ -293,7 +293,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
method = optimizer.optimizer_kwargs.get('method')
@@ -312,7 +312,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(loss, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
initial_vector_val = sess.run(vector)
diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py
new file mode 100644
index 0000000000..cae952d8f5
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/ggt.py
@@ -0,0 +1,312 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GGT for Tensorflow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+from tensorflow.contrib.optimizer_v2 import optimizer_v2
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+
+
+class GGTOptimizer(optimizer_v2.OptimizerV2):
+ """Optimizer that implements the GGT algorithm.
+
+ GGT has an advantage over sgd and adam on large models with poor conditioning,
+ for example language models and CNNs,
+ see [[ABCHSZZ 2018]](https://arxiv.org/pdf/1806.02958.pdf).
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta1=0.9,
+ use_locking=False,
+ name="GGT",
+ window=10,
+ eps=1e-4,
+ svd_eps=1e-6,
+ sigma_eps=1e-2):
+ """Construct a new GGT optimizer.
+
+ Initialization:
+
+ ```
+ t <- 0 (Initialize timestep)
+ grad_buffer <- 0 (Initialize buffer for keeping past gradients)
+ flat_grad <- 0 (Initialize flattened gradient that contains gradients of all
+ variables)
+ m_0 <- 0 (Initialize 1st moment vector)
+ ```
+
+ Suppose all variables and their gradients are concatenated into vectors
+ `flat_vars` and `flat_grad`. The update rule for `flat_vars`
+ uses an optimization described at the beginning of section 2 of the paper:
+
+ ```
+ t <- t + 1
+
+ m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad
+ grad_buffer[(t-1) % window, :] <- m_t
+
+ M <- grad_buffer^T / sqrt(min(t, window))
+ U, sigma, _ <- SVD(M^TM + I * svd_eps)
+
+ sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3)
+ sigma_sqrt_min <- min(sqrt(sigma))
+
+ if sigma_sqrt_min > eps:
+ new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t +
+ (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min
+ else:
+ new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t
+
+ flat_vars <- flat_vars - learning_rate * new_step
+ ```
+
+ GGT provides the power of full-matrix adaptive regularization at a cost not
+ much larger than SGD. As a result it is suited for large models where the
+ gradient covariance matrix has a poor condition number that slows down first
+ order methods.
+ GGT uses the preconditioner from full-matrix AdaGrad, with gradient history
+ attenuated exponentially as in Adam, and truncated to a window parameter.
+ It has provable guarantees even for non-convex optimization that is never
+ significantly worse than SGD and in some cases better.
+
+ Args:
+ learning_rate: A float hyperparameter. The learning rate.
+ beta1: A float hyperparameter. The exponential decay rate for the 1st
+ moment estimates.
+ use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "GGT".
+ window: An integer hyperparameter. The number of first moments to keep in
+ computing the adaptive preconditioner.
+ eps: A float hyperparameter. Used to truncate small eigenvalues of the
+ gradient covariance matrix.
+ svd_eps: A float hyperparameter. Used to stabilize SVD.
+ sigma_eps: A float hyperparameter. Used to regularize matrix inversion.
+ """
+ super(GGTOptimizer, self).__init__(use_locking, name)
+ self._set_hyper("lr", learning_rate)
+ self._set_hyper("beta1", beta1)
+ self._set_hyper("window", window)
+ self._set_hyper("eps", eps)
+ self._set_hyper("svd_eps", svd_eps)
+ self._set_hyper("sigma_eps", sigma_eps)
+
+ self.index_dict = {}
+ self.shape_dict = {}
+
+ def _create_vars(self, var_list, state):
+ # Construct ordered dictionary for variable dimensions, sorted by name.
+ shape_dict = {}
+ for v in var_list:
+ shape_dict[v.name] = np.prod(v.get_shape()).value
+ self.shape_dict = collections.OrderedDict(
+ sorted(shape_dict.items(), key=lambda t: t[0]))
+
+ # Assign each variable its location in flat_grad. The locations are based on
+ # the order of sorted names.
+ idx = 0
+ for v_name, v_dim in self.shape_dict.items():
+ self.index_dict[v_name] = idx
+ idx += v_dim
+
+ state.create_non_slot(
+ initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype),
+ name="global_step")
+
+ # Buffer for keeping past gradients.
+ window = state.get_hyper("window")
+ grad_buffer_init = array_ops.zeros(
+ [window, idx], dtype=var_list[0].dtype.base_dtype)
+ state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer")
+
+ state.create_non_slot(
+ initial_value=array_ops.zeros(
+ (idx,), dtype=var_list[0].dtype.base_dtype),
+ name="moment1")
+
+ # Flattened gradient that contains gradients for all variables in the model.
+ state.create_non_slot(
+ initial_value=array_ops.zeros(
+ (idx,), dtype=var_list[0].dtype.base_dtype),
+ name="flat_grad")
+
+ def _get_global_step(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("global_step")
+
+ def _get_moment1(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("moment1")
+
+ def _get_grad_buffer(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("grad_buffer")
+
+ def _get_flat_grad(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("flat_grad")
+
+ def _apply_sparse(self, grad, var):
+ raise NotImplementedError("Sparse gradient updates are not supported.")
+
+ def _prepare(self, state):
+ self._variables = []
+
+ def _apply_dense(self, grad, var, state):
+ self._variables.append(var)
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+
+ # Update flat_gradient at the index associated with the variable.
+ flat_grad = self._get_flat_grad(state)
+ new_flat_grad = array_ops.reshape(grad, [-1])
+ flat_grad_updated = state_ops.scatter_update(
+ flat_grad, math_ops.range(start_index, end_index), new_flat_grad)
+
+ return flat_grad_updated
+
+ def _resource_apply_dense(self, grad, var, state):
+ self._variables.append(var)
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+
+ # Update flat_gradient at the index associated with the variable.
+ flat_grad = self._get_flat_grad(state)
+ new_flat_grad = array_ops.reshape(grad, [-1])
+ flat_grad_updated = state_ops.scatter_update(
+ flat_grad, math_ops.range(start_index, end_index), new_flat_grad)
+
+ return flat_grad_updated
+
+ def _finish(self, state):
+ var_dtype = self._variables[0].dtype.base_dtype
+ # Update global step.
+ global_step = self._get_global_step(state)
+ update_global_step = state_ops.assign_add(global_step, 1.)
+
+ # Update the first moment estimate.
+ beta1 = state.get_hyper("beta1", dtype=var_dtype)
+ moment1 = self._get_moment1(state)
+ flat_grad = self._get_flat_grad(state)
+ # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t
+ update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad)
+
+ # Update the gradient buffer.
+ window = state.get_hyper("window")
+ grad_buffer = self._get_grad_buffer(state)
+ next_grad_index = math_ops.floormod(
+ math_ops.to_int32(update_global_step - 1.), window)
+ # grad_buffer[(t-1) % window] := moment1_t
+ update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index,
+ update_moment1)
+
+ # Compute the update step.
+ eps = state.get_hyper("eps", dtype=var_dtype)
+ svd_eps = state.get_hyper("svd_eps", dtype=var_dtype)
+ sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype)
+ lr = state.get_hyper("lr", dtype=var_dtype)
+ denom = math_ops.sqrt(
+ math_ops.minimum(
+ ops.convert_to_tensor(update_global_step),
+ ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype))))
+ moment1_2d = array_ops.expand_dims(update_moment1, -1)
+
+ # m = grad_buffer^T / sqrt(min(t, window))
+ # m has shape [model dimension, window], where model dimension is the sum
+ # of the dimensions of the flattened variables.
+ m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom))
+
+ # sigma, u, _ = SVD(m^Tm + I * svd_eps)
+ mm = math_ops.matmul(m, m, transpose_a=True)
+ damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps
+ sigma, u, _ = linalg_ops.svd(mm + damping)
+ sigma_sqrt = math_ops.sqrt(sigma)
+ sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt)
+
+ # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3
+ # We add sigma_eps to alleviate numerical instability.
+ # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T.
+ sigma_sqrt_inv = math_ops.divide(
+ math_ops.cast(1.0, dtype=var_dtype),
+ math_ops.pow(sigma_sqrt + sigma_eps, 3))
+
+ # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the
+ # inversion of a model dimension by model dimension matrix is needed. To
+ # speed up this computation we calculate the following instead:
+ # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1.
+ new_step = array_ops.expand_dims(
+ array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1)
+ head = math_ops.matmul(
+ m,
+ math_ops.matmul(
+ u,
+ math_ops.matmul(
+ array_ops.diag(sigma_sqrt_inv),
+ math_ops.matmul(
+ u,
+ math_ops.matmul(m, moment1_2d, transpose_a=True),
+ transpose_a=True))))
+
+ # When inverting (mm^t)^(1/2), we also add epsilon * I regularization for
+ # degenerate cases. We expand ((mm^t)^(1/2) + epsilon * I)^(-1) using
+ # Woodbury's identity.
+ # For full derivation please see paper at
+ # https://arxiv.org/pdf/1806.02958.pdf
+ tail = moment1_2d - math_ops.matmul(
+ m,
+ math_ops.matmul(
+ u,
+ math_ops.matmul(
+ array_ops.diag(
+ math_ops.divide(math_ops.cast(1.0, dtype=var_dtype),
+ sigma)),
+ math_ops.matmul(
+ u,
+ math_ops.matmul(m, moment1_2d, transpose_a=True),
+ transpose_a=True))))
+ scaled_tail = math_ops.divide(tail, sigma_sqrt_min)
+
+ update_new_step = control_flow_ops.cond(
+ sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail),
+ lambda: math_ops.add(new_step, head))
+
+ # Update each variable.
+ update_step = []
+ for var in self._variables:
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+ var_update_correct_shape = array_ops.reshape(
+ update_new_step[start_index:end_index], var.get_shape())
+ var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape)
+ update_step.append(var_updated)
+
+ return control_flow_ops.group(update_step)
diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py
new file mode 100644
index 0000000000..1775edabb3
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/ggt_test.py
@@ -0,0 +1,183 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GGTOptimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def ggt_update_numpy(param,
+ g_t,
+ lr,
+ grad_buffer,
+ m,
+ window,
+ t,
+ beta1=0.9,
+ eps=1e-4,
+ svd_eps=1e-6,
+ sigma_eps=1e-2):
+ """Tests the correctness of one step of GGT."""
+ m_t = m * beta1 + (1 - beta1) * g_t
+ grad_buffer[((t - 1) % window), :] = m_t
+ m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window)))
+ mm = np.dot(np.transpose(m_matrix), m_matrix)
+ damping = np.eye(window) * svd_eps
+ u, sigma, _ = np.linalg.svd(mm + damping)
+
+ sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3)
+ new_step = np.linalg.multi_dot([
+ m_matrix, u,
+ np.diag(sigma_sqrt_inv),
+ np.transpose(u),
+ np.transpose(m_matrix), m_t
+ ])
+
+ sigma_sqrt_min = np.sqrt(sigma).min()
+
+ if sigma_sqrt_min > eps:
+ new_step += (m_t - np.linalg.multi_dot([
+ m_matrix, u,
+ np.diag(1.0 / sigma),
+ np.transpose(u),
+ np.transpose(m_matrix), m_t
+ ])) * (1.0 / sigma_sqrt_min)
+
+ param_t = param - lr * new_step
+ return param_t, m_t, grad_buffer
+
+
+class GGTOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ # SVD does not support float16
+ for i, dtype in enumerate([dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0 = 0.0
+ window = 3
+ grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype)
+ lr = 0.001
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np, name="var0")
+ var1 = variables.Variable(var1_np, name="var1")
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = GGTOptimizer(learning_rate=lr, window=window)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+
+ m_t = opt._get_moment1()
+ grad_buffer_t = opt._get_grad_buffer()
+ g_t = opt._get_flat_grad()
+ self.assertTrue(m_t is not None)
+ self.assertTrue(grad_buffer_t is not None)
+ self.assertTrue(g_t is not None)
+ self.assertIn(m_t, opt_variables)
+ self.assertIn(grad_buffer_t, opt_variables)
+ self.assertIn(g_t, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ m_t = opt._get_moment1()
+ grad_buffer_t = opt._get_grad_buffer()
+ g_t = opt._get_flat_grad()
+
+ # Run 3 steps of GGT
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ if t == 1:
+ self.assertAllCloseAccordingToType(
+ np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.],
+ [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t))
+ elif t == 2:
+ self.assertAllCloseAccordingToType(
+ np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001, 0.001],
+ [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]),
+ self.evaluate(grad_buffer_t))
+ else:
+ self.assertAllCloseAccordingToType(
+ np.array([0.0271, 0.0271, 0.00271, 0.00271]),
+ self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001,
+ 0.001], [0.019, 0.019, 0.0019, 0.0019],
+ [0.0271, 0.0271, 0.00271, 0.00271]]),
+ self.evaluate(grad_buffer_t))
+
+ self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01],
+ self.evaluate(g_t))
+
+ var_np = np.append(var0_np, var1_np)
+ grads_np = np.append(grads0_np, grads1_np)
+ var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr,
+ grad_buffer, m0, window, t)
+
+ var0_np = var_np[:2]
+ var1_np = var_np[2:]
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py b/tensorflow/contrib/opt/python/training/lars_optimizer.py
new file mode 100644
index 0000000000..a8dafd9a4c
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py
@@ -0,0 +1,164 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class LARSOptimizer(optimizer.Optimizer):
+ """Layer-wise Adaptive Rate Scaling for large batch training.
+
+ Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
+ I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
+
+ Implements the LARS learning rate scheme presented in the paper above. This
+ optimizer is useful when scaling the batch size to up to 32K without
+ significant performance degradation. It is recommended to use the optimizer
+ in conjunction with:
+ - Gradual learning rate warm-up
+ - Linear learning rate scaling
+ - Poly rule learning rate decay
+
+ Note, LARS scaling is currently only enabled for dense tensors. Sparse tensors
+ use the default momentum optimizer.
+ """
+
+ def __init__(
+ self,
+ learning_rate,
+ momentum=0.9,
+ weight_decay=0.0001,
+ # The LARS coefficient is a hyperparameter
+ eeta=0.001,
+ epsilon=0.0,
+ name="LARSOptimizer",
+ # Enable skipping variables from LARS scaling.
+ # TODO(sameerkm): Enable a direct mechanism to pass a
+ # subset of variables to the optimizer.
+ skip_list=None,
+ use_nesterov=False):
+ """Construct a new LARS Optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or floating point value. The base learning rate.
+ momentum: A floating point value. Momentum hyperparameter.
+ weight_decay: A floating point value. Weight decay hyperparameter.
+ eeta: LARS coefficient as used in the paper. Dfault set to LARS
+ coefficient from the paper. (eeta / weight_decay) determines the highest
+ scaling factor in LARS.
+ epsilon: Optional epsilon parameter to be set in models that have very
+ small gradients. Default set to 0.0.
+ name: Optional name prefix for variables and ops created by LARSOptimizer.
+ skip_list: List of strings to enable skipping variables from LARS scaling.
+ If any of the strings in skip_list is a subset of var.name, variable
+ 'var' is skipped from LARS scaling. For a typical classification model
+ with batch normalization, the skip_list is ['batch_normalization',
+ 'bias']
+ use_nesterov: when set to True, nesterov momentum will be enabled
+
+ Raises:
+ ValueError: If a hyperparameter is set to a non-sensical value.
+ """
+ if momentum < 0.0:
+ raise ValueError("momentum should be positive: %s" % momentum)
+ if weight_decay < 0.0:
+ raise ValueError("weight_decay should be positive: %s" % weight_decay)
+ super(LARSOptimizer, self).__init__(use_locking=False, name=name)
+
+ self._learning_rate = learning_rate
+ self._momentum = momentum
+ self._weight_decay = weight_decay
+ self._eeta = eeta
+ self._epsilon = epsilon
+ self._name = name
+ self._skip_list = skip_list
+ self._use_nesterov = use_nesterov
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self._zeros_slot(v, "momentum", self._name)
+
+ def compute_lr(self, grad, var):
+ scaled_lr = self._learning_rate
+ if self._skip_list is None or not any(v in var.name
+ for v in self._skip_list):
+ w_norm = linalg_ops.norm(var, ord=2)
+ g_norm = linalg_ops.norm(grad, ord=2)
+ trust_ratio = array_ops.where(
+ math_ops.greater(w_norm, 0),
+ array_ops.where(
+ math_ops.greater(g_norm, 0),
+ (self._eeta * w_norm /
+ (g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0),
+ 1.0)
+ scaled_lr = self._learning_rate * trust_ratio
+ return scaled_lr
+
+ def _apply_dense(self, grad, var):
+ scaled_lr = self.compute_lr(grad, var)
+ mom = self.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var,
+ mom,
+ scaled_lr,
+ grad,
+ self._momentum,
+ use_locking=False,
+ use_nesterov=self._use_nesterov)
+
+ def _resource_apply_dense(self, grad, var):
+ scaled_lr = self.compute_lr(grad, var)
+ mom = self.get_slot(var, "momentum")
+ return training_ops.resource_apply_momentum(
+ var.handle,
+ mom.handle,
+ scaled_lr,
+ grad,
+ self._momentum,
+ use_locking=False,
+ use_nesterov=self._use_nesterov)
+
+ # Fallback to momentum optimizer for sparse tensors
+ def _apply_sparse(self, grad, var):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var,
+ mom,
+ math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.resource_sparse_apply_momentum(
+ var.handle,
+ mom.handle,
+ math_ops.cast(self._learning_rate_tensor, grad.dtype),
+ grad,
+ indices,
+ math_ops.cast(self._momentum_tensor, grad.dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
new file mode 100644
index 0000000000..b76db763da
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
@@ -0,0 +1,127 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0. Licensed to the Apache
+# Software Foundation. You may not use this file except in compliance with the
+# License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for Layer-wise Adaptive Rate Scaling optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import lars_optimizer as lo
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class LARSOptimizerTest(test.TestCase):
+
+ def testLARSGradientOneStep(self):
+ for _ in range(10):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session() as sess:
+ shape = [3, 3]
+ var_np = np.ones(shape)
+ grad_np = np.ones(shape)
+ lr_np = 0.1
+ m_np = 0.9
+ wd_np = 0.1
+ ep_np = 1e-5
+ eeta = 0.1
+ vel_np = np.zeros(shape)
+
+ var = variables.Variable(var_np, dtype=dtype)
+ grad = variables.Variable(grad_np, dtype=dtype)
+ opt = lo.LARSOptimizer(
+ learning_rate=lr_np,
+ momentum=m_np,
+ weight_decay=wd_np,
+ eeta=eeta,
+ epsilon=ep_np)
+
+ step = opt.apply_gradients([(grad, var)])
+ variables.global_variables_initializer().run()
+
+ pre_var = sess.run(var)
+ pre_vel = sess.run(opt.get_slot(var, 'momentum'))
+ self.assertAllClose(var_np, pre_var)
+ self.assertAllClose(vel_np, pre_vel)
+
+ step.run()
+ post_var = sess.run(var)
+ post_vel = sess.run(opt.get_slot(var, 'momentum'))
+
+ w_norm = np.linalg.norm(var_np.flatten(), ord=2)
+ g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
+ trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np)
+ scaled_lr = lr_np * trust_ratio
+
+ vel_np = m_np * vel_np + grad_np
+ var_np -= scaled_lr * vel_np
+
+ self.assertAllClose(var_np, post_var)
+ self.assertAllClose(vel_np, post_vel)
+
+ def testLARSGradientMultiStep(self):
+ for _ in range(10):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session() as sess:
+ shape = [3, 3]
+ var_np = np.ones(shape)
+ grad_np = np.ones(shape)
+ lr_np = 0.1
+ m_np = 0.9
+ wd_np = 0.1
+ ep_np = 1e-5
+ eeta = 0.1
+ vel_np = np.zeros(shape)
+
+ var = variables.Variable(var_np, dtype=dtype)
+ grad = variables.Variable(grad_np, dtype=dtype)
+ opt = lo.LARSOptimizer(
+ learning_rate=lr_np,
+ momentum=m_np,
+ eeta=eeta,
+ weight_decay=wd_np,
+ epsilon=ep_np)
+
+ step = opt.apply_gradients([(grad, var)])
+ variables.global_variables_initializer().run()
+
+ pre_var = sess.run(var)
+ pre_vel = sess.run(opt.get_slot(var, 'momentum'))
+ self.assertAllClose(var_np, pre_var)
+ self.assertAllClose(vel_np, pre_vel)
+
+ for _ in range(10):
+ step.run()
+
+ post_var = sess.run(var)
+ post_vel = sess.run(opt.get_slot(var, 'momentum'))
+
+ w_norm = np.linalg.norm(var_np.flatten(), ord=2)
+ g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
+ trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np)
+ scaled_lr = lr_np * trust_ratio
+
+ vel_np = m_np * vel_np + grad_np
+ var_np -= scaled_lr * vel_np
+
+ self.assertAllClose(var_np, post_var)
+ self.assertAllClose(vel_np, post_vel)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index a16857db7d..dc4c462ce4 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -53,7 +53,7 @@ class AdamOptimizerTest(test.TestCase):
def testSparse(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -109,7 +109,7 @@ class AdamOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions.py b/tensorflow/contrib/opt/python/training/matrix_functions.py
new file mode 100644
index 0000000000..baab577638
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions.py
@@ -0,0 +1,155 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Matrix functions contains iterative methods for M^p."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4):
+ """Iterative method to get matrix square root.
+
+ Stable iterations for the matrix square root, Nicholas J. Higham
+
+ Page 231, Eq 2.6b
+ http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf
+
+ Args:
+ mat_a: the symmetric PSD matrix whose matrix square root be computed
+ mat_a_size: size of mat_a.
+ iter_count: Maximum number of iterations.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_a^0.5
+ """
+
+ def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z,
+ unused_old_mat_z, err, old_err):
+ # This method require that we check for divergence every step.
+ return math_ops.logical_and(i < iter_count, err < old_err)
+
+ def _iter_body(i, mat_y, unused_old_mat_y, mat_z, unused_old_mat_z, err,
+ unused_old_err):
+ current_iterate = 0.5 * (3.0 * identity - math_ops.matmul(mat_z, mat_y))
+ current_mat_y = math_ops.matmul(mat_y, current_iterate)
+ current_mat_z = math_ops.matmul(current_iterate, mat_z)
+ # Compute the error in approximation.
+ mat_sqrt_a = current_mat_y * math_ops.sqrt(norm)
+ mat_a_approx = math_ops.matmul(mat_sqrt_a, mat_sqrt_a)
+ residual = mat_a - mat_a_approx
+ current_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm
+ return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_a_size))
+ mat_a = mat_a + ridge_epsilon * identity
+ norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a))
+ mat_init_y = mat_a / norm
+ mat_init_z = identity
+ init_err = norm
+
+ _, _, prev_mat_y, _, _, _, _ = control_flow_ops.while_loop(
+ _iter_condition, _iter_body, [
+ 0, mat_init_y, mat_init_y, mat_init_z, mat_init_z, init_err,
+ init_err + 1.0
+ ])
+ return prev_mat_y * math_ops.sqrt(norm)
+
+
+def matrix_inverse_pth_root(mat_g,
+ mat_g_size,
+ alpha,
+ iter_count=100,
+ epsilon=1e-6,
+ ridge_epsilon=1e-6):
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+
+ We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
+
+ A Schur-Newton Method for the Matrix p-th Root and its Inverse
+ by Chun-Hua Guo and Nicholas J. Higham
+ SIAM Journal on Matrix Analysis and Applications,
+ 2006, Vol. 28, No. 3 : pp. 788-804
+ https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
+
+ Args:
+ mat_g: the symmetric PSD matrix whose power it to be computed
+ mat_g_size: size of mat_g.
+ alpha: exponent, must be -1/p for p a positive integer.
+ iter_count: Maximum number of iterations.
+ epsilon: accuracy indicator, useful for early termination.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_g^alpha
+ """
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
+
+ def mat_power(mat_m, p):
+ """Computes mat_m^p, for p a positive integer.
+
+ Power p is known at graph compile time, so no need for loop and cond.
+ Args:
+ mat_m: a square matrix
+ p: a positive integer
+
+ Returns:
+ mat_m^p
+ """
+ assert p == int(p) and p > 0
+ power = None
+ while p > 0:
+ if p % 2 == 1:
+ power = math_ops.matmul(mat_m, power) if power is not None else mat_m
+ p //= 2
+ mat_m = math_ops.matmul(mat_m, mat_m)
+ return power
+
+ def _iter_condition(i, mat_m, _):
+ return math_ops.logical_and(
+ i < iter_count,
+ math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
+
+ def _iter_body(i, mat_m, mat_x):
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
+ return (i + 1, math_ops.matmul(mat_power(mat_m_i, -1.0 / alpha), mat_m),
+ math_ops.matmul(mat_x, mat_m_i))
+
+ if mat_g_size == 1:
+ mat_h = math_ops.pow(mat_g + ridge_epsilon, alpha)
+ else:
+ damped_mat_g = mat_g + ridge_epsilon * identity
+ z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
+ # The best value for z is
+ # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
+ # (c_max^{1-alpha} - c_min^{1-alpha})
+ # where c_max and c_min are the largest and smallest singular values of
+ # damped_mat_g.
+ # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
+ # Can replace above line by the one below, but it is less accurate,
+ # hence needs more iterations to converge.
+ # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
+ # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
+ # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
+ # extra iterations.
+ _, _, mat_h = control_flow_ops.while_loop(
+ _iter_condition, _iter_body,
+ [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
+ return mat_h
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions_test.py b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
new file mode 100644
index 0000000000..518fa38233
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
@@ -0,0 +1,63 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Matrix functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import matrix_functions
+from tensorflow.python.platform import test
+
+TOLERANCE = 1e-3
+
+
+def np_power(mat_g, alpha):
+ """Computes mat_g^alpha for a square symmetric matrix mat_g."""
+
+ mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
+ diag_d = np.power(diag_d, alpha)
+ return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v)
+
+
+class MatrixFunctionTests(test.TestCase):
+
+ def testMatrixSquareRootFunction(self):
+ """Tests for matrix square roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, 0.5)
+ mat_root = matrix_functions.matrix_square_root(mat, size)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testMatrixInversePthRootFunction(self):
+ """Tests for matrix inverse pth roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, -0.125)
+ mat_root = matrix_functions.matrix_inverse_pth_root(mat, size, -0.125)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
index ac04ad9911..f22e724528 100644
--- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
@@ -46,7 +46,7 @@ class MovingAverageOptimizerTest(test.TestCase):
def _helpTestRun(self, use_resource=False):
for sequential_update in [True, False]:
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
orig_val0 = [1.0, 2.0]
orig_val1 = [3.0, 4.0]
var0 = variable_scope.get_variable(
@@ -165,7 +165,7 @@ class MovingAverageOptimizerTest(test.TestCase):
self.assertLess(avg_val1[i], orig_val1[i])
def testFailWhenSaverCreatedBeforeInitialized(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable([1.0], name='var', dtype=dtypes.float32)
opt = moving_average_optimizer.MovingAverageOptimizer(
gradient_descent.GradientDescentOptimizer(learning_rate=2.0))
@@ -187,7 +187,7 @@ class MovingAverageOptimizerTest(test.TestCase):
self.apply_gradients_called = True
return super(WrapperOptimizer, self).apply_gradients(*args, **kwargs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var = variables.Variable([1.2], name='var', dtype=dtypes.float32)
loss = var ** 2
wrapper_opt = WrapperOptimizer(learning_rate=2.0)
diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
index 618d8eb18d..904aa9ab13 100644
--- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
+++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
@@ -34,7 +34,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
"""
def testWrapper(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32)
@@ -92,7 +92,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
self.evaluate(slot1))
def testGradientClipping(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
var2 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
index 825c08a09a..85e05ce71c 100644
--- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
@@ -53,7 +53,7 @@ class NadamOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -106,7 +106,7 @@ class NadamOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/opt/python/training/powersign.py b/tensorflow/contrib/opt/python/training/powersign.py
index 828f3c51c9..b4aa19264d 100644
--- a/tensorflow/contrib/opt/python/training/powersign.py
+++ b/tensorflow/contrib/opt/python/training/powersign.py
@@ -65,7 +65,7 @@ class PowerSignOptimizer(optimizer.Optimizer):
Example usage for PowerSign-cd (PowerSign with cosine sign decay)
```
decay_steps = 1000
- linear_decay_fn = sign_decays.get_linear_decay_fn(decay_steps)
+ linear_decay_fn = sign_decays.get_cosine_decay_fn(decay_steps)
opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=linear_decay_fn)
```
diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py
index 5214082dd6..0bcf5d230a 100644
--- a/tensorflow/contrib/opt/python/training/powersign_test.py
+++ b/tensorflow/contrib/opt/python/training/powersign_test.py
@@ -216,7 +216,7 @@ class PowerSignTest(test.TestCase):
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
- # Run 3 steps of powersign
+ # Run 7 steps of powersign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
for t in range(1, 8):
diff --git a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
index ea56e1646a..c09e2ac76d 100644
--- a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
@@ -36,7 +36,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def doTestBasic(self, use_locking=False, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
if use_resource:
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
@@ -73,7 +73,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable(
[[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -92,7 +92,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -116,7 +116,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -144,7 +144,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable([[1.0], [2.0]], dtype=dtype)
@@ -170,7 +170,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndicesResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var_repeated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_repeated = math_ops.reduce_sum(
@@ -194,7 +194,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseStability(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
shape = [1, 6]
var0 = variables.Variable(
[[
@@ -230,7 +230,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -263,7 +263,7 @@ class RegAdagradOptimizerTest(test.TestCase):
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testDynamicShapeVariable_Ok(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable(
"v", initializer=constant_op.constant(1.), validate_shape=False)
self.assertFalse(v.shape.is_fully_defined())
@@ -274,7 +274,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSkipUpdatingSlots(self):
iav = 0.130005 # A value that works with float16
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -306,7 +306,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseSkipUpdatingSlots(self):
iav = 0.130005 # A value that works with float16
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
new file mode 100644
index 0000000000..f161521b97
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -0,0 +1,420 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""The Shampoo Optimizer.
+
+Variant of Adagrad using one preconditioner matrix per variable dimension.
+For details, see https://arxiv.org/abs/1802.09568
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib.opt.python.training import matrix_functions
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import optimizer
+
+
+def GetParam(var, timestep):
+ if callable(var):
+ return var(timestep)
+ else:
+ return var
+
+
+class ShampooOptimizer(optimizer.Optimizer):
+ """The Shampoo Optimizer
+
+ Variant of Adagrad using one preconditioner matrix per variable dimension.
+ For details, see https://arxiv.org/abs/1802.09568
+
+ gbar is time-weighted accumulated gradient:
+ gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
+
+ mat_gbar is time-weighted accumulated gradient square:
+ mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
+ + mat_gbar_weight[t] * gg_j[t]
+ where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation)
+
+ Update rule:
+ w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t]
+ Again, mat_gbar_j[t]^(-alpha) gbar[t] is a tensor contraction along the
+ j'th dimension of gbar[t] with the first dimension of
+ mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter,
+ and n = rank of the variable.
+ Prod_j represents doing this contraction for all j in 0..n-1.
+
+ Typically learning_rate is constant, but could be time dependent by passing
+ a lambda function that depends on step.
+ """
+
+ def __init__(self,
+ global_step=0,
+ max_matrix_size=768,
+ gbar_decay=0.0,
+ gbar_weight=1.0,
+ mat_gbar_decay=1.0,
+ mat_gbar_weight=1.0,
+ learning_rate=1.0,
+ svd_interval=1,
+ precond_update_interval=1,
+ epsilon=1e-4,
+ alpha=0.5,
+ use_iterative_root=False,
+ use_locking=False,
+ name="Shampoo"):
+ """Default values of the various hyper-parameters.
+
+ gbar_decay, gbar_weight etc. can be a float or a time varying parameter.
+ For time-varying parameters use e.g. "lambda T: T / (T + 1.0)"
+ where the expression in the lambda is a tensorflow expression
+
+ Args:
+ global_step: tensorflow variable indicating the step.
+ max_matrix_size: We do not perform SVD for matrices larger than this.
+ gbar_decay:
+ gbar_weight: Used to update gbar:
+ gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
+ mat_gbar_decay:
+ mat_gbar_weight: Used to update mat_gbar:
+ mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
+ + mat_gbar_weight[t] * gg_j[t]
+ learning_rate: Similar to SGD
+ svd_interval: We should do SVD after this many steps. Default = 1, i.e.
+ every step. Usually 20 leads to no loss of accuracy, and
+ 50 or 100 is also OK. May also want more often early,
+ and less often later - set in caller as for example:
+ "svd_interval = lambda(T): tf.cond(
+ T < 2000, lambda: 20.0, lambda: 1000.0)"
+ precond_update_interval: We should update the preconditioners after
+ this many steps. Default = 1. Usually less than
+ svd_interval.
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ alpha: total power of the preconditioners.
+ use_iterative_root: should the optimizer use SVD (faster) or the
+ iterative root method (for TPU) for finding the
+ roots of PSD matrices.
+ use_locking:
+ name: name of optimizer.
+ """
+
+ super(ShampooOptimizer, self).__init__(use_locking, name)
+
+ self._global_step = math_ops.to_float(global_step)
+ self._max_matrix_size = max_matrix_size
+ self._gbar_decay = gbar_decay
+ self._gbar_weight = gbar_weight
+ self._mat_gbar_decay = mat_gbar_decay
+ self._mat_gbar_weight = mat_gbar_weight
+ self._learning_rate = learning_rate
+ self._svd_interval = svd_interval
+ self._precond_update_interval = precond_update_interval
+ self._epsilon = epsilon
+ self._alpha = alpha
+ self._use_iterative_root = use_iterative_root
+ self._name = name
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ with ops.colocate_with(v):
+ _ = self._zeros_slot(v, "gbar", self._name)
+ shape = np.array(v.get_shape())
+ for i, d in enumerate(shape):
+ d_tensor = ops.convert_to_tensor(d)
+ if d <= self._max_matrix_size:
+ mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor))
+ if self._svd_interval > 1:
+ _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor),
+ "H_" + str(i), self._name)
+ else:
+ mat_g_init = array_ops.zeros([d_tensor])
+
+ _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i),
+ self._name)
+
+ def _resource_apply_dense(self, grad, var):
+ return self._apply_dense(grad, var)
+
+ def _apply_dense(self, grad, var):
+ return self._apply_gradient(grad, var)
+
+ def _resource_apply_sparse(self, grad_values, var, grad_indices):
+ return self._apply_sparse_shared(grad_values, grad_indices, var)
+
+ def _apply_sparse(self, grad, var):
+ return self._apply_sparse_shared(grad.values, grad.indices, var)
+
+ def _apply_sparse_shared(self, grad_values, grad_indices, var):
+ if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
+ # The dimension is small enough, we can make the variable dense and
+ # do a dense update
+ dense_grad = array_ops.scatter_nd(
+ array_ops.expand_dims(grad_indices, axis=1), grad_values,
+ array_ops.shape(var, out_type=grad_indices.dtype))
+ return self._apply_gradient(dense_grad, var)
+ return self._apply_gradient(grad_values, var, grad_indices)
+
+ def _weighted_average(self, var, weight, weight_t, rest):
+ """Computes exponential weighted average: var = weight_t * var + rest.
+
+ Important to ensure that var does not occur in rest, otherwise
+ we can get race conditions in a distributed setting.
+
+ Args:
+ var: variable to be updated
+ weight: parameter to be checked. If it is a constant, we can optimize.
+ weight_t: current value of parameter, used for weighting
+ rest: the remaining tensor to be added
+
+ Returns:
+ updated variable.
+ """
+ if weight == 0.0:
+ return rest # no need to update var, we will never use it.
+ if weight == 1.0: # common case
+ return state_ops.assign_add(var, rest)
+ # The op below can cause race conditions in a distributed setting,
+ # since computing weight_t * var + rest can take some time, during
+ # which var may be set by another worker. To prevent this, it should
+ # be implemented as a C++ op.
+ return var.assign_add((weight_t - 1) * var + rest)
+
+ def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay,
+ mat_gbar_weight, i):
+ """Updates the cumulative outer products of the gradients.
+
+ Args:
+ mat_g: the matrix to be updated
+ grad: the gradient of the variable
+ axes: a list of k-1 integers 0 to k-1, except i
+ mat_gbar_decay: constant for weighted average:
+ mat_g = mat_g * decay + grad * weight
+ mat_gbar_weight: constant for weighted average
+ i: index of dimension to be updated.
+
+ Returns:
+ updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight
+
+ In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd
+ thus grad_outer is a matrix d_i x d_i, where d_i is the size of the
+ i'th dimension of g.
+ Alternate view: If mat_i(grad) is the flattening of grad to a
+ d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then
+ grad_outer = mat_i(grad) mat_i(grad).transpose
+ """
+ grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes),
+ name="grad_outer_" + str(i))
+ return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay,
+ mat_gbar_weight * grad_outer)
+
+ def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name):
+ """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix.
+
+ Args:
+ var: the variable we are updating.
+ mat_g: the symmetric PSD matrix whose power it to be computed
+ mat_g_size: size of mat_g
+ alpha: a real number
+ mat_h_slot_name: name of slot to store the power, if needed.
+
+ Returns:
+ mat_h = mat_g^alpha
+
+ Stores mat_h in the appropriate slot, if it exists.
+ Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig.
+ """
+ if mat_g_size == 1:
+ mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
+ else:
+ damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size))
+ diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True)
+ mat_h = math_ops.matmul(
+ mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha),
+ array_ops.transpose(mat_u))
+ if mat_h_slot_name is not None:
+ return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
+ return mat_h
+
+ def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
+ iter_count=100, epsilon=1e-6):
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""
+
+ mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
+ iter_count, self._epsilon)
+ mat_h = matrix_functions.matrix_inverse_pth_root(
+ mat_g_sqrt,
+ mat_g_size,
+ 2 * alpha,
+ iter_count,
+ epsilon,
+ ridge_epsilon=0.0)
+
+ if mat_h_slot_name is not None:
+ return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
+ return mat_h
+
+ def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None):
+ """Just a switch between the iterative power vs svd."""
+ with ops.name_scope("matrix_iterative_power"):
+ if self._use_iterative_root:
+ return self._compute_power_iter(var, mat_g, mat_g_size, alpha,
+ mat_h_slot_name)
+ else:
+ return self._compute_power_svd(var, mat_g, mat_g_size, alpha,
+ mat_h_slot_name)
+
+ def _apply_gradient(self, grad, var, indices=None):
+ """The main function to update a variable.
+
+ Args:
+ grad: A Tensor containing gradient to apply.
+ var: A Tensor containing the variable to update.
+ indices: An array of integers, for sparse update.
+
+ Returns:
+ Updated variable var = var - learning_rate * preconditioner * grad
+
+ If the gradient is dense, var and grad have the same shape.
+ If the update is sparse, then the first dimension of the gradient and var
+ may differ, others are all the same. In this case the indices array
+ provides the set of indices of the variable which are to be updated with
+ each row of the gradient.
+ """
+ global_step = self._global_step + 1
+
+ # Update accumulated weighted average of gradients
+ gbar = self.get_slot(var, "gbar")
+ gbar_decay_t = GetParam(self._gbar_decay, global_step)
+ gbar_weight_t = GetParam(self._gbar_weight, global_step)
+ if indices is not None:
+ # Note - the sparse update is not easily implemented, since the
+ # algorithm needs all indices of gbar to be updated
+ # if mat_gbar_decay != 1 or mat_gbar_decay != 0.
+ # One way to make mat_gbar_decay = 1 is by rescaling.
+ # If we want the update:
+ # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
+ # define:
+ # r_{t+1} = a_{t+1} * r_t
+ # h_t = G_t / r_t
+ # Then:
+ # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
+ # So we get the mat_gbar_decay = 1 as desired.
+ # We can implement this in a future version as needed.
+ # However we still need gbar_decay = 0, otherwise all indices
+ # of the variable will need to be updated.
+ if self._gbar_decay != 0.0:
+ tf_logging.warning("Not applying momentum for variable: %s" % var.name)
+ gbar_updated = grad
+ else:
+ gbar_updated = self._weighted_average(gbar, self._gbar_decay,
+ gbar_decay_t,
+ gbar_weight_t * grad)
+
+ # Update the preconditioners and compute the preconditioned gradient
+ shape = var.get_shape()
+ mat_g_list = []
+ for i in range(len(shape)):
+ mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
+ mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
+ mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)
+
+ preconditioned_grad = gbar_updated
+ v_rank = len(mat_g_list)
+ neg_alpha = - GetParam(self._alpha, global_step) / v_rank
+ svd_interval = GetParam(self._svd_interval, global_step)
+ precond_update_interval = GetParam(self._precond_update_interval,
+ global_step)
+ for i, mat_g in enumerate(mat_g_list):
+ # axes is the list of indices to reduce - everything but the current i.
+ axes = list(range(i)) + list(range(i+1, v_rank))
+ if shape[i] <= self._max_matrix_size:
+ # If the tensor size is sufficiently small perform full Shampoo update
+ # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
+ # is not strictly correct. However we will use it for now, and
+ # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)
+
+ # pylint: disable=g-long-lambda,cell-var-from-loop
+ mat_g_updated = control_flow_ops.cond(
+ math_ops.mod(global_step, precond_update_interval) < 1,
+ lambda: self._update_mat_g(
+ mat_g, grad, axes, mat_gbar_decay_t,
+ mat_gbar_weight_t * precond_update_interval, i),
+ lambda: mat_g)
+
+ mat_g_updated = mat_g_updated / float(shape[i].value)
+
+ if self._svd_interval == 1:
+ mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
+ else:
+ mat_h = control_flow_ops.cond(
+ math_ops.mod(global_step, svd_interval) < 1,
+ lambda: self._compute_power(var, mat_g_updated, shape[i],
+ neg_alpha, "H_" + str(i)),
+ lambda: self.get_slot(var, "H_" + str(i)))
+
+ # mat_h is a square matrix of size d_i x d_i
+ # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
+ # After contraction with a d_i x d_i tensor
+ # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
+ # (the first dimension is contracted out, and the second dimension of
+ # mat_h is appended). After going through all the indices, it becomes
+ # a d_0 x ... x d_n tensor again.
+ preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
+ axes=([0], [0]),
+ name="precond_" + str(i))
+ else:
+ # Tensor size is too large -- perform diagonal Shampoo update
+ # Only normalize non-vector cases.
+ if axes:
+ normalizer = 1.0 if indices is not None else float(shape[i].value)
+ grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
+ else:
+ grad_outer = grad * grad
+
+ if i == 0 and indices is not None:
+ assert self._mat_gbar_decay == 1.0
+ mat_g_updated = state_ops.scatter_add(mat_g, indices,
+ mat_gbar_weight_t * grad_outer)
+ mat_h = math_ops.pow(
+ array_ops.gather(mat_g_updated, indices) + self._epsilon,
+ neg_alpha)
+ else:
+ mat_g_updated = self._weighted_average(mat_g,
+ self._mat_gbar_decay,
+ mat_gbar_decay_t,
+ mat_gbar_weight_t * grad_outer)
+ mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha)
+
+ # Need to do the transpose to ensure that the tensor becomes
+ # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
+ preconditioned_grad = array_ops.transpose(
+ preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h
+
+ # Update the variable based on the Shampoo update
+ learning_rate_t = GetParam(self._learning_rate, global_step)
+ if indices is not None:
+ var_updated = state_ops.scatter_add(
+ var, indices, -learning_rate_t * preconditioned_grad)
+ else:
+ var_updated = state_ops.assign_sub(var,
+ learning_rate_t * preconditioned_grad)
+ return var_updated
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
new file mode 100644
index 0000000000..05bcf2cfa3
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -0,0 +1,772 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional tests for AdaMoo optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import shampoo
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+TOLERANCE = 1e-3
+RIDGE_EPSILON = 1e-4
+
+
+def np_power(mat_g, alpha):
+ """Computes mat_g^alpha for a square symmetric matrix mat_g."""
+
+ mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
+ diag_d = np.power(diag_d, alpha)
+ return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v)
+
+
+class ShampooTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(('Var', False), ('ResourceVar', True))
+ def testBasicVector(self, use_resource_var):
+ """Similar to the full Adagrad update."""
+
+ size = 20
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size)
+ grad_np_2 = np.random.rand(size)
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * mat_g^{-0.5} * grad
+ # lr = 1
+ mat_g = np.outer(grad_np, grad_np) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
+ new_val_np = init_var_np - np.dot(mat_h, grad_np)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g += np.outer(grad_np_2, grad_np_2) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
+ new_val_np -= np.dot(mat_h, grad_np_2)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ @parameterized.named_parameters(('Var', False), ('ResourceVar', True))
+ def testBasicMatrix(self, use_resource_var):
+ """Check update when gradient is a matrix."""
+ size = [10, 5]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1])
+ grad_np_2 = np.random.rand(size[0], size[1])
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25}
+ # lr = 1
+ mat_g1 = np.dot(grad_np, grad_np.transpose()) / grad_np.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
+ new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) / grad_np_2.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
+ new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def _testBasicTensor(self, use_iterative_root, use_resource_var):
+ """Check update when gradient is a tensor.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ use_resource_var: use resource var as variables.
+ """
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1], size[2])
+ grad_np_2 = np.random.rand(size[0], size[1], size[2])
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ mat_g1 = (
+ np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) /
+ grad_np.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) /
+ grad_np.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) /
+ grad_np.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
+
+ precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np = init_var_np - precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) /
+ grad_np_2.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) /
+ grad_np_2.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) /
+ grad_np_2.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
+
+ precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ @parameterized.named_parameters(
+ ('SVDWithVar', False, False),
+ ('SVDWithResourceVar', False, True),
+ ('IterRootWithVar', True, False),
+ ('IterRootWithResourceVar', True, True),
+ )
+ def testBasicTensor(self, use_iterative_root, use_resource_var):
+ self._testBasicTensor(use_iterative_root, use_resource_var)
+
+ @parameterized.named_parameters(('Var', False), ('ResourceVar', True))
+ def testLargeVector(self, use_resource_var):
+ """This is just the diagonal Adagrad update."""
+
+ size = 2000
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size)
+ grad_np_2 = np.random.rand(size)
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * gg^{-0.5} * grad
+ # lr = 1
+ mat_g = (grad_np * grad_np)
+ new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np
+
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g += (grad_np_2 * grad_np_2)
+ new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2
+
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
+
+
+ @parameterized.named_parameters(('Var', False), ('ResourceVar', True))
+ def testLargeMatrix(self, use_resource_var):
+ """Gradient is a matrix, one of whose dimensions is large.
+
+ We do diagonal updates for large dimensions.
+
+ Args:
+ use_resource_var: use resource var as variables.
+ """
+
+ size = [2000, 3]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1])
+ grad_np_2 = np.random.rand(size[0], size[1])
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * mat_left * grad * mat_right
+ # where the mat_left * grad is just element-wise product,
+ # with broadcasting
+ # lr = 1
+
+ mat_g1 = np.sum(
+ grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
+ new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += np.sum(
+ grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
+ new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ @parameterized.named_parameters(('Var', False))
+ def testSparseUpdateLarge(self, use_resource_var):
+ """Check update when gradient is of type IndexSlices.
+
+ We do diagonal updates for the first dimension, unless it is very small.
+
+ Args:
+ use_resource_var: use resource var as variables.
+ """
+ size = [2000, 3]
+ sample_size_1 = 100
+ init_var_np = np.zeros(size)
+ grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size_1,
+ replace=False))
+ grad_np = np.random.rand(sample_size_1, size[1])
+
+ sample_size_2 = 7
+ grad_indices_2 = np.sort(np.random.choice(np.arange(size[0]), sample_size_2,
+ replace=False))
+ grad_np_2 = np.random.rand(sample_size_2, size[1])
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = ops.IndexedSlices(
+ constant_op.constant(grad_np, dtype=dtypes.float32),
+ constant_op.constant(grad_indices),
+ constant_op.constant(size))
+ grad_2 = ops.IndexedSlices(
+ constant_op.constant(grad_np_2, dtype=dtypes.float32),
+ constant_op.constant(grad_indices_2),
+ constant_op.constant(size))
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * mat_left * grad * mat_right
+ # where the mat_left * grad is just element-wise product,
+ # with broadcasting
+ # lr = 1
+ # In this case the update lr * mat_left * grad * mat_right is
+ # of size 10 x 2.
+ # So the correct indices of var need to be updated.
+
+ mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
+ mat_g1_acc = np.zeros((size[0], 1))
+ mat_g1_acc[grad_indices] += mat_g1
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
+ new_val_np = init_var_np
+ new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
+ mat_g1_acc[grad_indices_2] += mat_g1
+ mat_left = np.power(mat_g1_acc[grad_indices_2] + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
+ new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def _testSparseUpdateSmall(self, use_iterative_root, use_resource_var):
+ """Gradient is of type IndexSlices, but the first dimension is small.
+
+ We create dense gradient and do the full update with SVD etc.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ use_resource_var: use resource var as variables.
+ """
+
+ size = [100, 3, 5]
+ sample_size = 10
+ init_var_np = np.zeros(size)
+ grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size,
+ replace=False))
+ grad_np = np.random.rand(sample_size, size[1], size[2])
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = ops.IndexedSlices(
+ constant_op.constant(grad_np, dtype=dtypes.float32),
+ constant_op.constant(grad_indices),
+ constant_op.constant(size))
+
+ opt = shampoo.ShampooOptimizer(global_step,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.125} grad
+ # lr = 1
+ grad_dense = np.zeros_like(init_var_np)
+ grad_dense[grad_indices] = grad_np
+
+ mat_g1 = np.tensordot(
+ grad_dense, grad_dense, axes=([1, 2], [1, 2])) / grad_dense.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 2], [0, 2])) / grad_dense.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 1], [0, 1])) / grad_dense.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
+
+ precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np = init_var_np - precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ @parameterized.named_parameters(
+ ('SVDWithVar', False, False),
+ ('SVDWithResourceVar', False, True),
+ ('IterRootWithVar', True, False),
+ ('IterRootWithResourceVar', True, True),
+ )
+ def testSparseUpdateSmall(self, use_iterative_root, use_resource_var):
+ self._testSparseUpdateSmall(use_iterative_root, use_resource_var)
+
+ def _testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var):
+ """Check update with momentum when gradient is a tensor.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ use_resource_var: use resource var as variables.
+ """
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1], size[2])
+ grad_np_2 = np.random.rand(size[0], size[1], size[2])
+ gbar_decay = 0.9
+ gbar_weight = 0.1
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step, gbar_decay=gbar_decay,
+ gbar_weight=gbar_weight,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ mat_g1 = np.tensordot(
+ grad_np, grad_np, axes=([1, 2], [1, 2])) / grad_np.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_np, grad_np, axes=([0, 2], [0, 2])) / grad_np.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_np, grad_np, axes=([0, 1], [0, 1])) / grad_np.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
+
+ gbar_np = gbar_weight * grad_np
+ precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np = init_var_np - precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / grad_np_2.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / grad_np_2.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / grad_np_2.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
+
+ gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2
+ precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ @parameterized.named_parameters(
+ ('SVDWithVar', False, False),
+ ('SVDWithResourceVar', False, True),
+ ('IterRootWithVar', True, False),
+ ('IterRootWithResourceVar', True, True),
+ )
+ def testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var):
+ self._testBasicTensorWithMomentum(use_iterative_root, use_resource_var)
+
+ def _testDelayedSVD(self, use_iterative_root, use_resource_var):
+ """Performing the SVD every nth step.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ use_resource_var: use resource var as variables.
+ """
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size).astype(np.float32)
+ iterations = 20
+ svd_interval = 5
+ grad_np = np.random.rand(
+ iterations, size[0], size[1], size[2]).astype(np.float32)
+ mat_g1_a = np.eye(size[0])
+ mat_g1 = np.zeros_like(mat_g1_a)
+ mat_g2_a = np.eye(size[1])
+ mat_g2 = np.zeros_like(mat_g2_a)
+ mat_g3_a = np.eye(size[2])
+ mat_g3 = np.zeros_like(mat_g3_a)
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = array_ops.placeholder(dtypes.float32, shape=size)
+
+ opt = shampoo.ShampooOptimizer(global_step, svd_interval=svd_interval,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+ new_val_np = init_var_np
+
+ # Run n steps of Shampoo
+ for i in range(iterations):
+ _ = sess.run(update, feed_dict={grad: grad_np[i]})
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ mat_g1 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / grad_np[i].shape[0]
+ mat_g2 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / grad_np[i].shape[1]
+ mat_g3 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / grad_np[i].shape[2]
+ if (i + 1) % svd_interval == 0:
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
+
+ precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ @parameterized.named_parameters(
+ ('SVDWithVar', False, False),
+ ('SVDWithResourceVar', False, True),
+ ('IterRootWithVar', True, False),
+ ('IterRootWithResourceVar', True, True),
+ )
+ def testDelayedSVD(self, use_iterative_root, use_resource_var):
+ self._testDelayedSVD(use_iterative_root, use_resource_var)
+
+ def _testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var):
+ """Update the squared sum every nth step, drop the other steps.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ use_resource_var: use resource var as variables.
+ """
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size).astype(np.float32)
+ iterations = 100
+ grad_np = np.random.rand(
+ iterations, size[0], size[1], size[2]).astype(np.float32)
+ svd_interval = 20
+ precond_update_interval = 5
+ mat_g1_a = np.eye(size[0])
+ mat_g1 = np.zeros_like(mat_g1_a)
+ mat_g2_a = np.eye(size[1])
+ mat_g2 = np.zeros_like(mat_g2_a)
+ mat_g3_a = np.eye(size[2])
+ mat_g3 = np.zeros_like(mat_g3_a)
+
+ with self.cached_session() as sess:
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, use_resource=use_resource_var)
+ var = variables.Variable(
+ init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
+ grad = array_ops.placeholder(dtypes.float32, shape=size)
+
+ opt = shampoo.ShampooOptimizer(
+ global_step, svd_interval=svd_interval,
+ precond_update_interval=precond_update_interval,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+ new_val_np = init_var_np
+
+ # Run n steps of Shampoo
+ for i in range(iterations):
+ _ = sess.run(update, feed_dict={grad: grad_np[i]})
+ new_val = sess.run(var)
+
+ # let up compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ if (i + 1) % precond_update_interval == 0:
+ mat_g1 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) /
+ grad_np[i].shape[0] * precond_update_interval)
+ mat_g2 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) /
+ grad_np[i].shape[1] * precond_update_interval)
+ mat_g3 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) /
+ grad_np[i].shape[2] * precond_update_interval)
+
+ if (i + 1) % svd_interval == 0:
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
+
+ precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ @parameterized.named_parameters(
+ ('SVDWithVar', False, False),
+ ('SVDWithResourceVar', False, True),
+ ('IterRootWithVar', True, False),
+ ('IterRootWithResourceVar', True, True),
+ )
+ def testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var):
+ self._testDelayedPrecondUpdate(use_iterative_root, use_resource_var)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/sign_decay_test.py b/tensorflow/contrib/opt/python/training/sign_decay_test.py
index c31cb924ea..3a84789afd 100644
--- a/tensorflow/contrib/opt/python/training/sign_decay_test.py
+++ b/tensorflow/contrib/opt/python/training/sign_decay_test.py
@@ -66,7 +66,7 @@ class SignDecaysTest(test.TestCase):
linear_decay_fn = sign_decay.get_linear_decay_fn(num_training_steps)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = linear_decay_fn(step).eval()
py_decayed = py_linear_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
@@ -78,7 +78,7 @@ class SignDecaysTest(test.TestCase):
num_training_steps, num_periods=5, zero_after=2)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = cosine_decay_fn(step).eval()
py_decayed = py_cosine_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
@@ -95,7 +95,7 @@ class SignDecaysTest(test.TestCase):
num_training_steps, num_periods=5, zero_after=2)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = restart_decay_fn(step).eval()
py_decayed = py_restart_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
diff --git a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
index fdda86b0b5..ff0ea8d766 100644
--- a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
@@ -158,7 +158,7 @@ class VariableClippingOptimizerTest(test.TestCase):
def testDenseLocal(self):
for dtype in [dtypes.float32, dtypes.float64, dtypes.half]:
- with self.test_session():
+ with self.cached_session():
var0, var1, update_op = self._setupDense(False, dtype)
self._assertDenseCorrect(var0, var1, update_op)
@@ -171,7 +171,7 @@ class VariableClippingOptimizerTest(test.TestCase):
def testSparseLocal(self):
for dtype in [dtypes.float64, dtypes.float32, dtypes.half]:
- with self.test_session():
+ with self.cached_session():
var0, var1, update_op = self._setupSparse(False, dtype)
self._assertSparseCorrect(var0, var1, update_op)
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
new file mode 100644
index 0000000000..200b0d2008
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -0,0 +1,435 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Base class to make optimizers weight decay ready."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.opt.python.training import shampoo
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import adam
+from tensorflow.python.training import momentum as momentum_opt
+from tensorflow.python.training import optimizer
+from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.ops import array_ops
+
+
+class DecoupledWeightDecayExtension(object):
+ """This class allows to extend optimizers with decoupled weight decay.
+
+ It implements the decoupled weight decay described by Loshchilov & Hutter
+ (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
+ decoupled from the optimization steps w.r.t. to the loss function.
+ For SGD variants, this simplifies hyperparameter search since it decouples
+ the settings of weight decay and learning rate.
+ For adaptive gradient algorithms, it regularizes variables with large
+ gradients more than L2 regularization would, which was shown to yield better
+ training loss and generalization error in the paper above.
+
+ This class alone is not an optimizer but rather extends existing
+ optimizers with decoupled weight decay. We explicitly define the two examples
+ used in the above paper (SGDW and AdamW), but in general this can extend
+ any OptimizerX by using
+ `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`.
+ In order for it to work, it must be the first class the Optimizer with
+ weight decay inherits from, e.g.
+
+ ```python
+ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
+ def __init__(self, weight_decay, *args, **kwargs):
+ super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs).
+ ```
+
+ Note that this extension decays weights BEFORE applying the update based
+ on the gradient, i.e. this extension only has the desired behaviour for
+ optimizers which do not depend on the value of'var' in the update step!
+ """
+
+ def __init__(self, weight_decay, **kwargs):
+ """Construct the extension class that adds weight decay to an optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value, the factor by which
+ a variable is decayed in the update step.
+ **kwargs: Optional list or tuple or set of `Variable` objects to
+ decay.
+ """
+ self._decay_var_list = None # is set in minimize or apply_gradients
+ self._weight_decay = weight_decay
+ # The tensors are initialized in call to _prepare
+ self._weight_decay_tensor = None
+ super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
+
+ def minimize(self, loss, global_step=None, var_list=None,
+ gate_gradients=optimizer.Optimizer.GATE_OP,
+ aggregation_method=None, colocate_gradients_with_ops=False,
+ name=None, grad_loss=None, decay_var_list=None):
+ """Add operations to minimize `loss` by updating `var_list` with decay.
+
+ This function is the same as Optimizer.minimize except that it allows to
+ specify the variables that should be decayed using decay_var_list.
+ If decay_var_list is None, all variables in var_list are decayed.
+
+ For more information see the documentation of Optimizer.minimize.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ var_list: Optional list or tuple of `Variable` objects to update to
+ minimize `loss`. Defaults to the list of variables collected in
+ the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ decay_var_list: Optional list of decay variables.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
+ """
+ self._decay_var_list = set(decay_var_list) if decay_var_list else False
+ return super(DecoupledWeightDecayExtension, self).minimize(
+ loss, global_step=global_step, var_list=var_list,
+ gate_gradients=gate_gradients, aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops, name=name,
+ grad_loss=grad_loss)
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None,
+ decay_var_list=None):
+ """Apply gradients to variables and decay the variables.
+
+ This function is the same as Optimizer.apply_gradients except that it
+ allows to specify the variables that should be decayed using
+ decay_var_list. If decay_var_list is None, all variables in var_list
+ are decayed.
+
+ For more information see the documentation of Optimizer.apply_gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the `Optimizer` constructor.
+ decay_var_list: Optional list of decay variables.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+ """
+ self._decay_var_list = set(decay_var_list) if decay_var_list else False
+ return super(DecoupledWeightDecayExtension, self).apply_gradients(
+ grads_and_vars, global_step=global_step, name=name)
+
+ def _prepare(self):
+ weight_decay = self._weight_decay
+ if callable(weight_decay):
+ weight_decay = weight_decay()
+ self._weight_decay_tensor = ops.convert_to_tensor(
+ weight_decay, name="weight_decay")
+ # Call the optimizers _prepare function.
+ super(DecoupledWeightDecayExtension, self)._prepare()
+
+ def _decay_weights_op(self, var):
+ if not self._decay_var_list or var in self._decay_var_list:
+ return var.assign_sub(self._weight_decay * var, self._use_locking)
+ return control_flow_ops.no_op()
+
+ def _decay_weights_sparse_op(self, var, indices, scatter_add):
+ if not self._decay_var_list or var in self._decay_var_list:
+ update = -self._weight_decay * array_ops.gather(var, indices)
+ return scatter_add(var, indices, update, self._use_locking)
+ return control_flow_ops.no_op()
+
+ # Here, we overwrite the apply functions that the base optimizer calls.
+ # super().apply_x resolves to the apply_x function of the BaseOptimizer.
+ def _apply_dense(self, grad, var):
+ with ops.control_dependencies([self._decay_weights_op(var)]):
+ return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var)
+
+ def _resource_apply_dense(self, grad, var):
+ with ops.control_dependencies([self._decay_weights_op(var)]):
+ return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
+ grad, var)
+
+ def _apply_sparse(self, grad, var):
+ scatter_add = state_ops.scatter_add
+ decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
+ with ops.control_dependencies([decay_op]):
+ return super(DecoupledWeightDecayExtension, self)._apply_sparse(
+ grad, var)
+
+ def _resource_scatter_add(self, x, i, v, _=None):
+ # last argument allows for one overflow argument, to have the same function
+ # signature as state_ops.scatter_add
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ scatter_add = self._resource_scatter_add
+ decay_op = self._decay_weights_sparse_op(var, indices, scatter_add)
+ with ops.control_dependencies([decay_op]):
+ return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse(
+ grad, var, indices)
+
+
+def extend_with_decoupled_weight_decay(base_optimizer):
+ """Factory function returning an optimizer class with decoupled weight decay.
+
+ Returns an optimizer class. An instance of the returned class computes the
+ update step of `base_optimizer` and additionally decays the weights.
+ E.g., the class returned by
+ `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to
+ `tf.contrib.opt.AdamWOptimizer`.
+
+ The API of the new optimizer class slightly differs from the API of the
+ base optimizer:
+ - The first argument to the constructor is the weight decay rate.
+ - `minimize` and `apply_gradients` accept the optional keyword argument
+ `decay_var_list`, which specifies the variables that should be decayed.
+ If `None`, all variables that are optimized are decayed.
+
+ Usage example:
+ ```python
+ # MyAdamW is a new class
+ MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
+ # Create a MyAdamW object
+ optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
+ sess.run(optimizer.minimize(loss, decay_variables=[var1, var2]))
+
+ Note that this extension decays weights BEFORE applying the update based
+ on the gradient, i.e. this extension only has the desired behaviour for
+ optimizers which do not depend on the value of'var' in the update step!
+ ```
+
+ Args:
+ base_optimizer: An optimizer class that inherits from tf.train.Optimizer.
+
+ Returns:
+ A new optimizer class that inherits from DecoupledWeightDecayExtension
+ and base_optimizer.
+ """
+
+ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
+ base_optimizer):
+ """Base_optimizer with decoupled weight decay.
+
+ This class computes the update step of `base_optimizer` and
+ additionally decays the variable with the weight decay being decoupled from
+ the optimization steps w.r.t. to the loss function, as described by
+ Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf).
+ For SGD variants, this simplifies hyperparameter search since
+ it decouples the settings of weight decay and learning rate.
+ For adaptive gradient algorithms, it regularizes variables with large
+ gradients more than L2 regularization would, which was shown to yield
+ better training loss and generalization error in the paper above.
+ """
+
+ def __init__(self, weight_decay, *args, **kwargs):
+ # super delegation is necessary here
+ # pylint: disable=useless-super-delegation
+ super(OptimizerWithDecoupledWeightDecay, self).__init__(
+ weight_decay, *args, **kwargs)
+ # pylint: enable=useless-super-delegation
+
+ return OptimizerWithDecoupledWeightDecay
+
+
+@tf_export("contrib.opt.MomentumWOptimizer")
+class MomentumWOptimizer(DecoupledWeightDecayExtension,
+ momentum_opt.MomentumOptimizer):
+ """Optimizer that implements the Momentum algorithm with weight_decay.
+
+ This is an implementation of the SGDW optimizer described in "Fixing
+ Weight Decay Regularization in Adam" by Loshchilov & Hutter
+ (https://arxiv.org/abs/1711.05101)
+ ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
+ It computes the update step of `train.MomentumOptimizer` and additionally
+ decays the variable. Note that this is different from adding
+ L2 regularization on the variables to the loss. Decoupling the weight decay
+ from other hyperparameters (in particular the learning rate) simplifies
+ hyperparameter search.
+
+ For further information see the documentation of the Momentum Optimizer.
+
+ Note that this optimizer can also be instantiated as
+ ```python
+ extend_with_weight_decay(tf.train.MomentumOptimizer,
+ weight_decay=weight_decay)
+ ```
+ """
+
+ def __init__(self, weight_decay, learning_rate, momentum,
+ use_locking=False, name="MomentumW", use_nesterov=False):
+ """Construct a new MomentumW optimizer.
+
+ For further information see the documentation of the Momentum Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ learning_rate: A `Tensor` or a floating point value. The learning rate.
+ momentum: A `Tensor` or a floating point value. The momentum.
+ use_locking: If `True` use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Momentum".
+ use_nesterov: If `True` use Nesterov Momentum.
+ See [Sutskever et al., 2013](
+ http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
+ This implementation always computes gradients at the value of the
+ variable(s) passed to the optimizer. Using Nesterov Momentum makes the
+ variable(s) track the values called `theta_t + mu*v_t` in the paper.
+
+ @compatibility(eager)
+ When eager execution is enabled, learning_rate, weight_decay and momentum
+ can each be a callable that takes no arguments and returns the actual value
+ to use. This can be useful for changing these values across different
+ invocations of optimizer functions.
+ @end_compatibility
+ """
+ super(MomentumWOptimizer, self).__init__(
+ weight_decay, learning_rate=learning_rate, momentum=momentum,
+ use_locking=use_locking, name=name, use_nesterov=use_nesterov)
+
+
+@tf_export("contrib.opt.AdamWOptimizer")
+class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
+ """Optimizer that implements the Adam algorithm with weight decay.
+
+ This is an implementation of the AdamW optimizer described in "Fixing
+ Weight Decay Regularization in Adam" by Loshchilov & Hutter
+ (https://arxiv.org/abs/1711.05101)
+ ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
+
+ It computes the update step of `train.AdamOptimizer` and additionally decays
+ the variable. Note that this is different from adding L2 regularization on
+ the variables to the loss: it regularizes variables with large
+ gradients more than L2 regularization would, which was shown to yield better
+ training loss and generalization error in the paper above.
+
+ For further information see the documentation of the Adam Optimizer.
+
+ Note that this optimizer can also be instantiated as
+ ```python
+ extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay)
+ ```
+ """
+
+ def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999,
+ epsilon=1e-8, use_locking=False, name="AdamW"):
+ """Construct a new AdamW optimizer.
+
+ For further information see the documentation of the Adam Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ beta1: A float value or a constant float tensor.
+ The exponential decay rate for the 1st moment estimates.
+ beta2: A float value or a constant float tensor.
+ The exponential decay rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability. This epsilon is
+ "epsilon hat" in the Kingma and Ba paper (in the formula just before
+ Section 2.1), not the epsilon in Algorithm 1 of the paper.
+ use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+ """
+ super(AdamWOptimizer, self).__init__(
+ weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
+ epsilon=epsilon, use_locking=use_locking, name=name)
+
+
+@tf_export("contrib.opt.ShampooWOptimizer")
+class ShampooWOptimizer(DecoupledWeightDecayExtension,
+ shampoo.ShampooOptimizer):
+ """Optimizer that implements the Shampoo algorithm with weight decay.
+
+ For further information see the documentation of the Shampoo Optimizer.
+ """
+
+ def __init__(self,
+ weight_decay,
+ global_step,
+ max_matrix_size=768,
+ gbar_decay=0.0,
+ gbar_weight=1.0,
+ mat_gbar_decay=1.0,
+ mat_gbar_weight=1.0,
+ learning_rate=1.0,
+ svd_interval=1,
+ precond_update_interval=1,
+ epsilon=1e-4,
+ alpha=0.5,
+ use_iterative_root=False,
+ use_locking=False,
+ name="ShampooW"):
+ """Construct a new ShampooW optimizer.
+
+ For further information see the documentation of the Shampoo Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ global_step: tensorflow variable indicating the step.
+ max_matrix_size: We do not perform SVD for matrices larger than this.
+ gbar_decay:
+ gbar_weight: Used to update gbar: gbar[t] = gbar_decay[t] * gbar[t-1] +
+ gbar_weight[t] * g[t]
+ mat_gbar_decay:
+ mat_gbar_weight: Used to update mat_gbar: mat_gbar_j[t] =
+ mat_gbar_decay[t] * mat_gbar_j[t-1] + mat_gbar_weight[t] * gg_j[t]
+ learning_rate: Similar to SGD
+ svd_interval: We should do SVD after this many steps. Default = 1, i.e.
+ every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
+ also OK. May also want more often early,
+ and less often later - set in caller as for example:
+ "svd_interval = lambda(T): tf.cond(
+ T < 2000, lambda: 20.0, lambda: 1000.0)"
+ precond_update_interval: We should update the preconditioners after this
+ many steps. Default = 1. Usually less than svd_interval.
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ alpha: total power of the preconditioners.
+ use_iterative_root: should the optimizer use SVD (faster) or the iterative
+ root method (for TPU) for finding the roots of PSD matrices.
+ use_locking: If `True` use locks for update operations.
+ name: name of optimizer.
+ """
+ super(ShampooWOptimizer, self).__init__(
+ weight_decay,
+ global_step=global_step,
+ max_matrix_size=max_matrix_size,
+ gbar_decay=gbar_decay,
+ gbar_weight=gbar_weight,
+ mat_gbar_decay=mat_gbar_weight,
+ learning_rate=learning_rate,
+ svd_interval=svd_interval,
+ precond_update_interval=precond_update_interval,
+ epsilon=epsilon,
+ alpha=alpha,
+ use_iterative_root=use_iterative_root,
+ use_locking=use_locking,
+ name=name)
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
new file mode 100644
index 0000000000..9c91078301
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
@@ -0,0 +1,188 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for optimizers with weight decay."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import weight_decay_optimizers
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+
+WEIGHT_DECAY = 0.01
+
+
+def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9,
+ beta2=0.999, epsilon=1e-8):
+ lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) -
+ (param * WEIGHT_DECAY))
+ return param_t, m_t, v_t
+
+
+def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
+ # v, t are not needed for momentum optimizer
+ m = momentum * m + g_t
+ param_t = param - lr * m - param * WEIGHT_DECAY
+ return param_t, m, None
+
+
+class WeightDecayOptimizerTest(test.TestCase):
+
+ def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
+ use_resource=False, do_sparse=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+
+ if do_sparse:
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = ops.IndexedSlices(constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices),
+ constant_op.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices),
+ constant_op.constant([2]))
+ else:
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = optimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Run 3 steps of the optimizer
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0)
+ var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/%s:0" % (i, optimizer_name),
+ opt.get_slot(var=var0, name=slot_name).name)
+
+
+class AdamWOptimizerTest(WeightDecayOptimizerTest):
+
+ @staticmethod
+ def get_optimizer():
+ return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY)
+
+ def testSparse(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=False, do_sparse=True)
+
+ def testResourceSparse(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=True, do_sparse=True)
+
+ def testBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=True)
+
+
+class MomentumWOptimizerTest(WeightDecayOptimizerTest):
+
+ @staticmethod
+ def get_optimizer():
+ return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9)
+
+ def testSparse(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=False, do_sparse=True)
+
+ def testResourceSparse(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=True, do_sparse=True)
+
+ def testBasic(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=True)
+
+
+class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
+
+ @staticmethod
+ def get_optimizer():
+ adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
+ adam.AdamOptimizer)
+ return adamw(WEIGHT_DECAY)
+
+ def testBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
+ use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
+ use_resource=True)
+
+
+if __name__ == "__main__":
+ test.main()