about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-28 19:01:26 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-28 19:08:59 -0700
commit: 20d5683b826be03776978af3b8108fc3b5dc9cb8 (patch)
tree: 2742a2f95489ec391db1583a36d193bfc30cc338 /tensorflow/python/training
parent: 069f808e5c0462819bcd6c73c75491b00cdd42c2 (diff)
Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator.
PiperOrigin-RevId: 210648271
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- tensorflow/python/training/ftrl_test.py 101
1 file changed, 85 insertions, 16 deletions
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 775bdb3f60..76ca5b45c9 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -117,8 +117,7 @@ class FtrlOptimizerTest(test.TestCase):
# Run 1 step of sgd
sgd_op.run()
# Validate updated params
- self.assertAllCloseAccordingToType(
- [[0, 1]], var0.eval(), atol=0.01)
+ self.assertAllCloseAccordingToType([[0, 1]], var0.eval(), atol=0.01)
def testFtrlWithL1(self):
for dtype in [dtypes.half, dtypes.float32]:
@@ -212,24 +211,96 @@ class FtrlOptimizerTest(test.TestCase):
v0_val, v1_val = sess.run([var0, var1])
self.assertAllCloseAccordingToType(
- np.array([-0.22078767, -0.41378114]), v0_val)
+ np.array([-0.22578995, -0.44345796]), v0_val)
self.assertAllCloseAccordingToType(
- np.array([-0.02919818, -0.07343706]), v1_val)
+ np.array([-0.14378493, -0.13229476]), v1_val)
+
+ def testFtrlWithL1_L2_L2ShrinkageSparse(self):
+ """Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.test_session() as sess:
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], v0_val)
+ self.assertAllCloseAccordingToType([[4.0], [3.0]], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([[-0.22578995], [2.]], v0_val)
+ self.assertAllCloseAccordingToType([[4.], [-0.13229476]], v1_val)
+
+ def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+ """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.test_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([1.0, 2.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+ opt0 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ opt1 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update0 = opt0.apply_gradients([(grads0, var0)])
+ update1 = opt1.apply_gradients([(grads1, var1)])
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([1.0, 2.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update0.run()
+ update1.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ # var0 is experiencing L2 shrinkage so it should be smaller than var1
+ # in magnitude.
+ self.assertTrue((v0_val**2 < v1_val**2).all())
+ accum0 = list(sess.run(opt0._slots)["accum"].values())[0]
+ accum1 = list(sess.run(opt1._slots)["accum"].values())[0]
+ # L2 shrinkage should not change how we update grad accumulator.
+ self.assertAllCloseAccordingToType(accum0, accum1)
def applyOptimizer(self, opt, dtype, steps=5, is_sparse=False):
if is_sparse:
var0 = variables.Variable([[0.0], [0.0]], dtype=dtype)
var1 = variables.Variable([[0.0], [0.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
- constant_op.constant(
- [0.1], shape=[1, 1], dtype=dtype),
- constant_op.constant([0]),
- constant_op.constant([2, 1]))
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
grads1 = ops.IndexedSlices(
- constant_op.constant(
- [0.02], shape=[1, 1], dtype=dtype),
- constant_op.constant([1]),
- constant_op.constant([2, 1]))
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
else:
var0 = variables.Variable([0.0, 0.0], dtype=dtype)
var1 = variables.Variable([0.0, 0.0], dtype=dtype)
@@ -277,8 +348,7 @@ class FtrlOptimizerTest(test.TestCase):
with self.test_session():
val2, val3 = self.applyOptimizer(
- adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1), dtype)
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
self.assertAllCloseAccordingToType(val0, val2)
self.assertAllCloseAccordingToType(val1, val3)
@@ -299,8 +369,7 @@ class FtrlOptimizerTest(test.TestCase):
with self.test_session():
val2, val3 = self.applyOptimizer(
- adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1),
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
dtype,
is_sparse=True)