aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-09 14:47:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-09 14:52:16 -0700
commit3057b7bf5132eb5a7bae5414925a40c4d2429716 (patch)
treefb172848ed3b3df5b8fd0d07d42497419081fda1
parent435599f5d896d7e1f721ffe6fd092d39efe2b027 (diff)
[TF-XLA] Implement FtrlOptimizer
Change the TF documentation for the operation assigned to `linear` variable in ResourceApplyFtrl training_ops. PiperOrigin-RevId: 158565492
-rw-r--r--tensorflow/compiler/tests/BUILD14
-rw-r--r--tensorflow/compiler/tests/ftrl_test.py253
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc107
-rw-r--r--tensorflow/core/ops/training_ops.cc2
4 files changed, 375 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index d18e51e32c..ef19e23858 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -184,6 +184,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "ftrl_test",
+ size = "small",
+ srcs = ["ftrl_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "function_test",
size = "small",
srcs = ["function_test.py"],
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
new file mode 100644
index 0000000000..6b328fb618
--- /dev/null
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -0,0 +1,253 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Ftrl optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad
+from tensorflow.python.training import ftrl
+from tensorflow.python.training import gradient_descent
+
+
+class FtrlOptimizerTest(XLATestCase):
+
+ def initVariableAndGradient(self, dtype):
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.02, 0.04], dtype=dtype)
+
+ return var0, var1, grads0, grads1
+
+ def equivAdagradTest_FtrlPart(self, steps, dtype):
+ var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ learning_rate_power=-0.5, # using Adagrad learning rate
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run Ftrl for a few steps
+ for _ in range(steps):
+ ftrl_update.run()
+
+ return var0.eval(), var1.eval()
+
+ def equivAdagradTest_AdagradPart(self, steps, dtype):
+ var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
+ opt = adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+ adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run Adagrad for a few steps
+ for _ in range(steps):
+ adagrad_update.run()
+
+ return var0.eval(), var1.eval()
+
+ def equivGradientDescentTest_FtrlPart(self, steps, dtype):
+ var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ learning_rate_power=-0.0, # using Fixed learning rate
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run Ftrl for a few steps
+ for _ in range(steps):
+ ftrl_update.run()
+
+ return var0.eval(), var1.eval()
+
+ def equivGradientDescentTest_GradientDescentPart(self, steps, dtype):
+ var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
+ opt = gradient_descent.GradientDescentOptimizer(3.0, name="sgd")
+ sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run GradientDescent for a few steps
+ for _ in range(steps):
+ sgd_update.run()
+
+ return var0.eval(), var1.eval()
+
+ def testFtrlwithoutRegularization(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run 3 steps FTRL
+ for _ in range(3):
+ ftrl_update.run()
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-2.60260963, -4.29698515]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.28432083, -0.56694895]), var1.eval())
+
+ def testFtrlwithoutRegularization2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 3 steps FTRL
+ for _ in range(3):
+ ftrl_update.run()
+
+ # Validate updated params
+ self.assertAllClose(
+ np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5)
+ self.assertAllClose(
+ np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5)
+
+ def testFtrlWithL1(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=0.0)
+ ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ ftrl_update.run()
+
+ # Validate updated params
+ self.assertAllClose(np.array([-7.66718769, -10.91273689]), var0.eval())
+ self.assertAllClose(np.array([-0.93460727, -1.86147261]), var1.eval())
+
+ def testFtrlWithL1_L2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ ftrl_update.run()
+
+ # Validate updated params
+ self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval())
+ self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval())
+
+ # When variables are intialized with Zero, FTRL-Proximal has two properties:
+ # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical
+ # with GradientDescent.
+ # 2. Without L1&L2 but with adaptive learning rate, FTRL-Proximal is idential
+ # with Adagrad.
+ # So, basing on these two properties, we test if our implementation of
+ # FTRL-Proximal performs same updates as Adagrad or GradientDescent.
+ def testEquivAdagradwithoutRegularization(self):
+ steps = 5
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype)
+ with self.test_session(), self.test_scope():
+ val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype)
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+ def testEquivGradientDescentwithoutRegularization(self):
+ steps = 5
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype)
+ with self.test_session(), self.test_scope():
+ val2, val3 = self.equivGradientDescentTest_GradientDescentPart(
+ steps, dtype)
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index ddd81cb490..e9ac1ee91b 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -364,5 +364,112 @@ class ResourceApplyRMSProp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp);
+class ResourceApplyFtrl : public XlaOpKernel {
+ public:
+ explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ DataType var_type, accum_type, linear_type;
+ TensorShape var_shape, accum_shape, linear_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
+ OP_REQUIRES_OK(ctx,
+ ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
+ OP_REQUIRES_OK(
+ ctx, ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape));
+
+ OP_REQUIRES(
+ ctx,
+ dtype_ == var_type && dtype_ == accum_type && dtype_ == linear_type,
+ errors::InvalidArgument(
+ "Types of variable arguments to ResourceApplyFtrl must match: ",
+ DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " and ",
+ DataTypeString(accum_type), " and ", DataTypeString(linear_type)));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape),
+ errors::InvalidArgument(
+ "var and linear do not have the same shape",
+ var_shape.DebugString(), " ", linear_shape.DebugString()));
+
+ TensorShape grad_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape l1_shape = ctx->InputShape(5);
+ TensorShape l2_shape = ctx->InputShape(6);
+ TensorShape lr_power_shape = ctx->InputShape(7);
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape),
+ errors::InvalidArgument("lr_power is not a scalar: ",
+ lr_power_shape.DebugString()));
+
+ xla::ComputationDataHandle var, accum, linear;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear));
+ xla::ComputationDataHandle grad = ctx->Input(3);
+ xla::ComputationDataHandle lr = ctx->Input(4);
+ xla::ComputationDataHandle l1 = ctx->Input(5);
+ xla::ComputationDataHandle l2 = ctx->Input(6);
+ xla::ComputationDataHandle lr_power = ctx->Input(7);
+
+ // new_accum = accum + grad * grad
+ // linear += grad - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
+ // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2
+ // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
+ // accum = new_accum
+
+ xla::ComputationDataHandle zero_broadcast = b->Broadcast(
+ XlaHelpers::FloatLiteral(b, dtype_, 0.0), var_shape.dim_sizes());
+ xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+
+ xla::ComputationDataHandle new_accum = b->Add(accum, b->Pow(grad, two));
+ xla::ComputationDataHandle new_accum_lr_pow =
+ b->Pow(new_accum, b->Neg(lr_power));
+ xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power));
+ linear = b->Add(
+ linear,
+ b->Sub(grad, b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr),
+ var)));
+ xla::ComputationDataHandle quadratic =
+ b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2));
+ xla::ComputationDataHandle pre_shrink =
+ b->Div(b->Sub(b->Mul(l1, b->Sign(linear)), linear), quadratic);
+ var = b->Select(b->Gt(b->Abs(linear), l1), pre_shrink, zero_broadcast);
+ accum = new_accum;
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, linear));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 1d24ea36a3..5bb93daea2 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -925,7 +925,7 @@ REGISTER_OP("ResourceApplyFtrl")
Update '*var' according to the Ftrl-proximal scheme.
accum_new = accum + grad * grad
-linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
+linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
accum = accum_new