path: root/tensorflow/compiler/tests
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-28 19:01:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-28 19:08:59 -0700
commit 20d5683b826be03776978af3b8108fc3b5dc9cb8 (patch)
tree   2742a2f95489ec391db1583a36d193bfc30cc338 /tensorflow/compiler/tests
parent 069f808e5c0462819bcd6c73c75491b00cdd42c2 (diff)
Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator.
PiperOrigin-RevId: 210648271
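
A minimal NumPy sketch of a single FTRL-Proximal step with L2 shrinkage may help illustrate the fix (names and defaults are illustrative, mirroring the hyper-parameters in the test below; this is a sketch of the intended semantics, not TensorFlow's kernel code): the shrinkage term is folded into the gradient that drives the linear/weight update, while the accumulator receives only the plain squared gradient, so the per-coordinate learning-rate schedule is unchanged.

import numpy as np

def ftrl_step(var, accum, linear, grad,
              lr=3.0, l1=0.001, l2=2.0, l2_shrinkage=0.1, lr_power=-0.5):
  # L2 shrinkage contributes to the gradient applied to the weights...
  grad_shrink = grad + 2.0 * l2_shrinkage * var
  # ...but the accumulator, which sets the adaptive learning-rate
  # schedule, only sees the plain squared gradient (the fix).
  new_accum = accum + grad * grad
  sigma = (new_accum ** -lr_power - accum ** -lr_power) / lr
  linear = linear + grad_shrink - sigma * var
  quadratic = new_accum ** -lr_power / lr + 2.0 * l2
  var = np.where(np.abs(linear) > l1,
                 (np.sign(linear) * l1 - linear) / quadratic, 0.0)
  return var, new_accum, linear

Before this change, grad_shrink * grad_shrink was added to the accumulator, coupling the L2 shrinkage strength to the effective step sizes.
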
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--  tensorflow/compiler/tests/ftrl_test.py | 44
1 file changed, 42 insertions(+), 2 deletions(-)
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 7ca50b02d9..b1deb7f6a7 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -259,9 +259,49 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
       # Validate updated params
       self.assertAllCloseAccordingToType(
-          np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4)
+          np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4)
       self.assertAllCloseAccordingToType(
-          np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4)
+          np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4)
+
+  def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+    """Verifies that L2 shrinkage in FTRL does not change the lr schedule."""
+    for dtype in self.float_types:
+      with self.test_session(), self.test_scope():
+        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+        var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+        grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+        opt0 = ftrl.FtrlOptimizer(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.001,
+            l2_regularization_strength=2.0,
+            l2_shrinkage_regularization_strength=0.1)
+        opt1 = ftrl.FtrlOptimizer(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.001,
+            l2_regularization_strength=2.0)
+        update0 = opt0.apply_gradients([(grads0, var0)])
+        update1 = opt1.apply_gradients([(grads1, var1)])
+        variables.global_variables_initializer().run()
+
+        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+        self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval())
+
+        # Run 10 steps of FTRL.
+        for _ in range(10):
+          update0.run()
+          update1.run()
+
+        # var0 undergoes L2 shrinkage, so it should be smaller than var1
+        # in magnitude.
+        self.assertTrue((var0.eval()**2 < var1.eval()**2).all())
+        accum0 = list(opt0._slots["accum"].values())[0].eval()
+        accum1 = list(opt1._slots["accum"].values())[0].eval()
+        # L2 shrinkage should not change how the gradient accumulator
+        # is updated.
+        self.assertAllCloseAccordingToType(accum0, accum1)
   # When variables are initialized with Zero, FTRL-Proximal has two properties:
   # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical