author    A. Unique TensorFlower <gardener@tensorflow.org>   2018-06-18 14:53:21 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-18 14:57:29 -0700
commit    24e9804217a450fc0f8e8f2c4a98e1a593aa77f8 (patch)
tree      2629065f4b99e06cd453dff01888e4d36a77accc /tensorflow/contrib/opt
parent    3fa0009cbdb8ef95593ffaf63d97e05bf1835cb8 (diff)
This is an initial submission of GGT to TensorFlow contrib.

Paper link: https://arxiv.org/pdf/1806.02958.pdf

PiperOrigin-RevId: 201063723
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--  tensorflow/contrib/opt/BUILD                          22
-rw-r--r--  tensorflow/contrib/opt/__init__.py                     4
-rw-r--r--  tensorflow/contrib/opt/python/training/ggt.py        312
-rw-r--r--  tensorflow/contrib/opt/python/training/ggt_test.py   183
4 files changed, 520 insertions(+), 1 deletion(-)
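The new optimizer is exported as `tf.contrib.opt.GGTOptimizer` (see the `__init__.py` change below). A minimal usage sketch, assuming a TensorFlow 1.x graph-mode session with contrib available; the toy variable and quadratic loss are illustrative and not part of this commit. `minimize` is the standard Optimizer entry point, while the unit test in this commit exercises `apply_gradients` directly.

```
# Hypothetical usage sketch of the optimizer added by this commit
# (TF 1.x graph mode assumed; the variable and loss are made up).
import tensorflow as tf

x = tf.Variable([1.0, 2.0], name="x")
loss = tf.reduce_sum(tf.square(x))  # simple quadratic toy loss

# window controls how many first-moment vectors the preconditioner keeps.
opt = tf.contrib.opt.GGTOptimizer(learning_rate=0.001, window=10)
train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(5):
    sess.run(train_op)
  print(sess.run(x))
```

Note that GGT flattens all trainable variables into a single vector internally and does not support sparse gradient updates.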
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 13aa1d7e7a..4f35de4e5d 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -19,6 +19,7 @@ py_library(
"python/training/drop_stale_gradient_optimizer.py",
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
+ "python/training/ggt.py",
"python/training/lazy_adam_optimizer.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
@@ -31,12 +32,15 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
@@ -302,3 +306,21 @@ py_test(
"//third_party/py/numpy",
],
)
+
+py_test(
+ name = "ggt_test",
+ srcs = ["python/training/ggt_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 4c13c8e247..b41148329d 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
+from tensorflow.contrib.opt.python.training.ggt import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -53,7 +54,8 @@ _allowed_symbols = [
'ElasticAverageOptimizer',
'ElasticAverageCustomGetter',
'ModelAverageOptimizer',
- 'ModelAverageCustomGetter'
+ 'ModelAverageCustomGetter',
+ 'GGTOptimizer',
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py
new file mode 100644
index 0000000000..928c453517
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/ggt.py
@@ -0,0 +1,312 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GGT for Tensorflow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+from tensorflow.contrib.optimizer_v2 import optimizer_v2
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+
+
+class GGTOptimizer(optimizer_v2.OptimizerV2):
+ """Optimizer that implements the GGT algorithm.
+
+ GGT has an advantage over SGD and Adam on large models with poor
+ conditioning, for example language models and CNNs; see
+ [ABCHSZZ 2018]([pdf](https://arxiv.org/pdf/1806.02958.pdf)).
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta1=0.9,
+ use_locking=False,
+ name="GGT",
+ window=10,
+ eps=1e-4,
+ svd_eps=1e-6,
+ sigma_eps=1e-2):
+ """Construct a new GGT optimizer.
+
+ Initialization:
+
+ ```
+ t <- 0 (Initialize timestep)
+ grad_buffer <- 0 (Initialize buffer for keeping past gradients)
+ flat_grad <- 0 (Initialize flattened gradient that contains gradients of all
+ variables)
+ m_0 <- 0 (Initialize 1st moment vector)
+ ```
+
+ Suppose all variables and their gradients are concatenated into vectors
+ `flat_vars` and `flat_grad`. The update rule for `flat_vars`
+ uses an optimization described at the beginning of section 2 of the paper:
+
+ ```
+ t <- t + 1
+
+ m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad
+ grad_buffer[(t-1) % window, :] <- m_t
+
+ M <- grad_buffer^T / sqrt(min(t, window))
+ U, sigma, _ <- SVD(M^TM + I * svd_eps)
+
+ sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3)
+ sigma_sqrt_min <- min(sqrt(sigma))
+
+ if sigma_sqrt_min > eps:
+ new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t +
+ (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min
+ else:
+ new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t
+
+ flat_vars <- flat_vars - learning_rate * new_step
+ ```
+
+ GGT provides the power of full-matrix adaptive regularization at a cost not
+ much larger than SGD's. As a result it is suited for large models where the
+ gradient covariance matrix has a poor condition number that slows down
+ first-order methods.
+ GGT uses the preconditioner from full-matrix AdaGrad, with gradient history
+ attenuated exponentially as in Adam, and truncated to a window parameter.
+ It has provable guarantees even for non-convex optimization: it is never
+ significantly worse than SGD and is in some cases better.
+
+ Args:
+ learning_rate: A float hyperparameter. The learning rate.
+ beta1: A float hyperparameter. The exponential decay rate for the 1st
+ moment estimates.
+ use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "GGT".
+ window: An integer hyperparameter. The number of first moments to keep in
+ computing the adaptive preconditioner.
+ eps: A float hyperparameter. Used to truncate small eigenvalues of the
+ gradient covariance matrix.
+ svd_eps: A float hyperparameter. Used to stabilize SVD.
+ sigma_eps: A float hyperparameter. Used to regularize matrix inversion.
+ """
+ super(GGTOptimizer, self).__init__(use_locking, name)
+ self._set_hyper("lr", learning_rate)
+ self._set_hyper("beta1", beta1)
+ self._set_hyper("window", window)
+ self._set_hyper("eps", eps)
+ self._set_hyper("svd_eps", svd_eps)
+ self._set_hyper("sigma_eps", sigma_eps)
+
+ self.index_dict = {}
+ self.shape_dict = {}
+
+ def _create_vars(self, var_list, state):
+ # Construct ordered dictionary for variable dimensions, sorted by name.
+ shape_dict = {}
+ for v in var_list:
+ shape_dict[v.name] = np.prod(v.get_shape()).value
+ self.shape_dict = collections.OrderedDict(
+ sorted(shape_dict.items(), key=lambda t: t[0]))
+
+ # Assign each variable its location in flat_grad. The locations are based on
+ # the order of sorted names.
+ idx = 0
+ for v_name, v_dim in self.shape_dict.items():
+ self.index_dict[v_name] = idx
+ idx += v_dim
+
+ state.create_non_slot(
+ initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype),
+ name="global_step")
+
+ # Buffer for keeping past gradients.
+ window = state.get_hyper("window")
+ grad_buffer_init = array_ops.zeros(
+ [window, idx], dtype=var_list[0].dtype.base_dtype)
+ state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer")
+
+ state.create_non_slot(
+ initial_value=array_ops.zeros(
+ (idx,), dtype=var_list[0].dtype.base_dtype),
+ name="moment1")
+
+ # Flattened gradient that contains gradients for all variables in the model.
+ state.create_non_slot(
+ initial_value=array_ops.zeros(
+ (idx,), dtype=var_list[0].dtype.base_dtype),
+ name="flat_grad")
+
+ def _get_global_step(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("global_step")
+
+ def _get_moment1(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("moment1")
+
+ def _get_grad_buffer(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("grad_buffer")
+
+ def _get_flat_grad(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("flat_grad")
+
+ def _apply_sparse(self, grad, var):
+ raise NotImplementedError("Sparse gradient updates are not supported.")
+
+ def _prepare(self, state):
+ self._variables = []
+
+ def _apply_dense(self, grad, var, state):
+ self._variables.append(var)
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+
+ # Update flat_gradient at the index associated with the variable.
+ flat_grad = self._get_flat_grad(state)
+ new_flat_grad = array_ops.reshape(grad, [-1])
+ flat_grad_updated = state_ops.scatter_update(
+ flat_grad, math_ops.range(start_index, end_index), new_flat_grad)
+
+ return flat_grad_updated
+
+ def _resource_apply_dense(self, grad, var, state):
+ self._variables.append(var)
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+
+ # Update flat_gradient at the index associated with the variable.
+ flat_grad = self._get_flat_grad(state)
+ new_flat_grad = array_ops.reshape(grad, [-1])
+ flat_grad_updated = state_ops.scatter_update(
+ flat_grad, math_ops.range(start_index, end_index), new_flat_grad)
+
+ return flat_grad_updated
+
+ def _finish(self, state):
+ var_dtype = self._variables[0].dtype.base_dtype
+ # Update global step.
+ global_step = self._get_global_step(state)
+ update_global_step = state_ops.assign_add(global_step, 1.)
+
+ # Update the first moment estimate.
+ beta1 = state.get_hyper("beta1", dtype=var_dtype)
+ moment1 = self._get_moment1(state)
+ flat_grad = self._get_flat_grad(state)
+ # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t
+ update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad)
+
+ # Update the gradient buffer.
+ window = state.get_hyper("window")
+ grad_buffer = self._get_grad_buffer(state)
+ next_grad_index = math_ops.floormod(
+ math_ops.to_int32(update_global_step - 1.), window)
+ # grad_buffer[(t-1) % window] := moment1_t
+ update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index,
+ update_moment1)
+
+ # Compute the update step.
+ eps = state.get_hyper("eps", dtype=var_dtype)
+ svd_eps = state.get_hyper("svd_eps", dtype=var_dtype)
+ sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype)
+ lr = state.get_hyper("lr", dtype=var_dtype)
+ denom = math_ops.sqrt(
+ math_ops.minimum(
+ ops.convert_to_tensor(update_global_step),
+ ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype))))
+ moment1_2d = array_ops.expand_dims(update_moment1, -1)
+
+ # m = grad_buffer^T / sqrt(min(t, window))
+ # m has shape [model dimension, window], where model dimension is the sum
+ # of the dimensions of the flattened variables.
+ m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom))
+
+ # sigma, u, _ = SVD(m^Tm + I * svd_eps)
+ mm = math_ops.matmul(m, m, transpose_a=True)
+ damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps
+ sigma, u, _ = linalg_ops.svd(mm + damping)
+ sigma_sqrt = math_ops.sqrt(sigma)
+ sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt)
+
+ # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3
+ # We add sigma_eps to alleviate numerical instability.
+ # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T.
+ sigma_sqrt_inv = math_ops.divide(
+ math_ops.cast(1.0, dtype=var_dtype),
+ math_ops.pow(sigma_sqrt + sigma_eps, 3))
+
+ # In full-matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, which
+ # requires inverting a matrix whose size is the model dimension. To speed
+ # up this computation we instead calculate
+ # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1.
+ new_step = array_ops.expand_dims(
+ array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1)
+ head = math_ops.matmul(
+ m,
+ math_ops.matmul(
+ u,
+ math_ops.matmul(
+ array_ops.diag(sigma_sqrt_inv),
+ math_ops.matmul(
+ u,
+ math_ops.matmul(m, moment1_2d, transpose_a=True),
+ transpose_a=True))))
+
+ # When inverting (mm^T)^(1/2), we also add epsilon * I regularization for
+ # degenerate cases. We expand ((mm^T)^(1/2) + epsilon * I)^(-1) using
+ # Woodbury's identity.
+ # For the full derivation please see the paper at
+ # https://arxiv.org/pdf/1806.02958.pdf
+ tail = moment1_2d - math_ops.matmul(
+ m,
+ math_ops.matmul(
+ u,
+ math_ops.matmul(
+ array_ops.diag(
+ math_ops.divide(math_ops.cast(1.0, dtype=var_dtype),
+ sigma)),
+ math_ops.matmul(
+ u,
+ math_ops.matmul(m, moment1_2d, transpose_a=True),
+ transpose_a=True))))
+ scaled_tail = math_ops.divide(tail, sigma_sqrt_min)
+
+ update_new_step = control_flow_ops.cond(
+ sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail),
+ lambda: math_ops.add(new_step, head))
+
+ # Update each variable.
+ update_step = []
+ for var in self._variables:
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+ var_update_correct_shape = array_ops.reshape(
+ update_new_step[start_index:end_index], var.get_shape())
+ var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape)
+ update_step.append(var_updated)
+
+ return control_flow_ops.group(update_step)
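The comments in `_finish` above describe the key computational trick: instead of forming the [model dimension, model dimension] matrix (mm^T)^(-1/2), GGT only decomposes the [window, window] Gram matrix m^Tm. A small NumPy sketch of that identity with made-up dimensions (not part of the commit), checking that m(m^Tm)^(-3/2)m^T g matches the full-matrix preconditioned step (mm^T)^(-1/2)g for a gradient g in the column space of m:

```
# Illustrative check (not from the commit): the window-sized SVD route used by
# GGT agrees with the full-matrix AdaGrad step on the span of the buffer.
import numpy as np

rng = np.random.RandomState(0)
d, w = 6, 3                       # toy model dimension and window
m = rng.randn(d, w)               # stands in for grad_buffer^T / sqrt(min(t, window))
g = m.dot(rng.randn(w))           # a gradient lying in the column space of m

# Small route: only a w x w SVD, as in _finish above (eps terms omitted).
u, s, _ = np.linalg.svd(m.T.dot(m))
small = m.dot(u).dot(np.diag(s ** -1.5)).dot(u.T).dot(m.T).dot(g)

# Naive route: (mm^T)^(-1/2) g via the thin SVD of m (pseudo-inverse root).
U, S, _ = np.linalg.svd(m, full_matrices=False)
full = U.dot(np.diag(1.0 / S)).dot(U.T).dot(g)

print(np.allclose(small, full))   # True
```

The eps, svd_eps, and sigma_eps terms in the real implementation only regularize this computation; they do not change the identity being exploited.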
diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py
new file mode 100644
index 0000000000..42162960b0
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/ggt_test.py
@@ -0,0 +1,183 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GGTOptimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def ggt_update_numpy(param,
+ g_t,
+ lr,
+ grad_buffer,
+ m,
+ window,
+ t,
+ beta1=0.9,
+ eps=1e-4,
+ svd_eps=1e-6,
+ sigma_eps=1e-2):
+ """Tests the correctness of one step of GGT."""
+ m_t = m * beta1 + (1 - beta1) * g_t
+ grad_buffer[((t - 1) % window), :] = m_t
+ m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window)))
+ mm = np.dot(np.transpose(m_matrix), m_matrix)
+ damping = np.eye(window) * svd_eps
+ u, sigma, _ = np.linalg.svd(mm + damping)
+
+ sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3)
+ new_step = np.linalg.multi_dot([
+ m_matrix, u,
+ np.diag(sigma_sqrt_inv),
+ np.transpose(u),
+ np.transpose(m_matrix), m_t
+ ])
+
+ sigma_sqrt_min = np.sqrt(sigma).min()
+
+ if sigma_sqrt_min > eps:
+ new_step += (m_t - np.linalg.multi_dot([
+ m_matrix, u,
+ np.diag(1.0 / sigma),
+ np.transpose(u),
+ np.transpose(m_matrix), m_t
+ ])) * (1.0 / sigma_sqrt_min)
+
+ param_t = param - lr * new_step
+ return param_t, m_t, grad_buffer
+
+
+class GGTOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ # SVD does not support float16
+ for i, dtype in enumerate([dtypes.float32, dtypes.float64]):
+ with self.test_session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0 = 0.0
+ window = 3
+ grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype)
+ lr = 0.001
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np, name="var0")
+ var1 = variables.Variable(var1_np, name="var1")
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = GGTOptimizer(learning_rate=lr, window=window)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+
+ m_t = opt._get_moment1()
+ grad_buffer_t = opt._get_grad_buffer()
+ g_t = opt._get_flat_grad()
+ self.assertTrue(m_t is not None)
+ self.assertTrue(grad_buffer_t is not None)
+ self.assertTrue(g_t is not None)
+ self.assertIn(m_t, opt_variables)
+ self.assertIn(grad_buffer_t, opt_variables)
+ self.assertIn(g_t, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ m_t = opt._get_moment1()
+ grad_buffer_t = opt._get_grad_buffer()
+ g_t = opt._get_flat_grad()
+
+ # Run 3 steps of GGT
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ if t == 1:
+ self.assertAllCloseAccordingToType(
+ np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.],
+ [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t))
+ elif t == 2:
+ self.assertAllCloseAccordingToType(
+ np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001, 0.001],
+ [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]),
+ self.evaluate(grad_buffer_t))
+ else:
+ self.assertAllCloseAccordingToType(
+ np.array([0.0271, 0.0271, 0.00271, 0.00271]),
+ self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001,
+ 0.001], [0.019, 0.019, 0.0019, 0.0019],
+ [0.0271, 0.0271, 0.00271, 0.00271]]),
+ self.evaluate(grad_buffer_t))
+
+ self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01],
+ self.evaluate(g_t))
+
+ var_np = np.append(var0_np, var1_np)
+ grads_np = np.append(grads0_np, grads1_np)
+ var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr,
+ grad_buffer, m0, window, t)
+
+ var0_np = var_np[:2]
+ var1_np = var_np[2:]
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
+ def testBasic(self):
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+
+if __name__ == "__main__":
+ test.main()
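As a sanity check on the values asserted at t == 1 above: the first moment after one step is just (1 - beta1) times the flattened gradient, with beta1 defaulting to 0.9 in the optimizer. A standalone NumPy sketch, independent of the test harness:

```
# Standalone check of the t == 1 assertions above (plain NumPy, no TensorFlow).
import numpy as np

beta1 = 0.9
window = 3
m0 = np.zeros(4)
flat_grad = np.array([0.1, 0.1, 0.01, 0.01])   # grads0 and grads1 concatenated

m1 = beta1 * m0 + (1.0 - beta1) * flat_grad
print(m1)            # [0.01  0.01  0.001 0.001], as asserted for m_t at t == 1

grad_buffer = np.zeros((window, 4))
grad_buffer[(1 - 1) % window, :] = m1          # row 0 holds the newest moment
print(grad_buffer)   # first row equals m1, remaining rows stay zero
```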