diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 19:01:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 19:08:59 -0700 |
commit | 20d5683b826be03776978af3b8108fc3b5dc9cb8 (patch) | |
tree | 2742a2f95489ec391db1583a36d193bfc30cc338 /tensorflow/compiler/tests | |
parent | 069f808e5c0462819bcd6c73c75491b00cdd42c2 (diff) |
Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator.
PiperOrigin-RevId: 210648271
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/ftrl_test.py | 44 |
1 file changed, 42 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 7ca50b02d9..b1deb7f6a7 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -259,9 +259,49 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + + def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): + """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.1, 0.2], dtype=dtype) + + opt0 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0, + l2_shrinkage_regularization_strength=0.1) + opt1 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update0 = opt0.apply_gradients([(grads0, var0)]) + update1 = opt1.apply_gradients([(grads1, var1)]) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + update0.run() + update1.run() + + # var0 is experiencing L2 shrinkage so it should be smaller than var1 + # in magnitude. 
+ self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + accum0 = list(opt0._slots["accum"].values())[0].eval() + accum1 = list(opt1._slots["accum"].values())[0].eval() + # L2 shrinkage should not change how we update grad accumulator. + self.assertAllCloseAccordingToType(accum0, accum1) # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical |