aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-07-06 18:02:00 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-07 21:21:42 -0700
commit: d7416d53ef09052a3dc10ec54ee5bd19e883cf86 (patch)
tree: 7bf2b5cbee0ce8ad768da5116fdce07979b26982
parent: 1caaea99e0156523e7d65f7e54cc2ac7117a8f90 (diff)
[TF:XLA] Add implementation of ResourceApplyPowerSign and ResourceApplyAddSign.
PiperOrigin-RevId: 203547001
-rw-r--r-- tensorflow/compiler/tests/BUILD | 28
-rw-r--r-- tensorflow/compiler/tests/addsign_test.py | 145
-rw-r--r-- tensorflow/compiler/tests/powersign_test.py | 142
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/training_ops.cc | 104
-rw-r--r-- tensorflow/contrib/opt/python/training/addsign_test.py | 6
-rw-r--r-- tensorflow/contrib/opt/python/training/powersign_test.py | 2
6 files changed, 423 insertions, 4 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 3be391ba90..69ff0d99cb 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -112,6 +112,34 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "addsign_test",
+ size = "small",
+ srcs = ["addsign_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
+ name = "powersign_test",
+ size = "small",
+ srcs = ["powersign_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "argminmax_test",
size = "small",
srcs = ["argminmax_test.py"],
diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py
new file mode 100644
index 0000000000..69cf7a0bf7
--- /dev/null
+++ b/tensorflow/compiler/tests/addsign_test.py
@@ -0,0 +1,145 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AddSign."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import addsign
+from tensorflow.contrib.opt.python.training import sign_decay
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def py_linear_decay_fn(decay_steps):
+ """Builds a pure-Python linear decay schedule.
+
+ The returned callable clamps `step` to at most `decay_steps` and returns
+ `(decay_steps - step) / decay_steps` as a float: 1.0 at step 0, falling to
+ 0.0 once `step >= decay_steps`. `testDense` pairs it with
+ `sign_decay.get_linear_decay_fn` built from the same `decay_steps`.
+ """
+ def linear_decay(step):
+ # Clamp so the result never goes below 0.0 for step > decay_steps.
+ step = min(step, decay_steps)
+ return float(decay_steps - step) / decay_steps
+ return linear_decay
+
+
+def addsign_update_numpy(params,
+ g_t,
+ m,
+ lr,
+ alpha=1.0,
+ beta=0.9,
+ py_sign_decay_fn=None,
+ t=None):
+ """NumPy reference for a single AddSign update step.
+
+ Computes
+ m_t = beta * m + (1 - beta) * g_t
+ multiplier = alpha + sign_decayed * sign(g_t) * sign(m_t)
+ params_t = params - lr * multiplier * g_t
+ where `sign_decayed` is 1.0 when `py_sign_decay_fn` is None and
+ `py_sign_decay_fn(t - 1)` otherwise (so `t` must be given in that case).
+
+ Returns:
+ A `(params_t, m_t)` tuple with the updated parameters and moving average.
+ """
+ m_t = beta * m + (1 - beta) * g_t
+ if py_sign_decay_fn is None:
+ sign_decayed = 1.0
+ else:
+ # Callers in this file count steps from 1, so the schedule sees t - 1.
+ sign_decayed = py_sign_decay_fn(t-1)
+ multiplier = alpha + sign_decayed * np.sign(g_t) * np.sign(m_t)
+ params_t = params - lr * multiplier * g_t
+ return params_t, m_t
+
+
+class AddSignTest(xla_test.XLATestCase):
+ """Compares AddSignOptimizer updates with `addsign_update_numpy`."""
+
+ def _testDense(self,
+ learning_rate=0.1,
+ sign_decay_fn=None,
+ py_sign_decay_fn=None,
+ alpha=1.0,
+ beta=0.9):
+ """Runs 7 dense AddSign steps per float dtype and checks each step.
+
+ `sign_decay_fn` is handed to the optimizer and `py_sign_decay_fn` to the
+ NumPy reference; callers are expected to pass a matching pair (as
+ `testDense` does) or leave both as None.
+ """
+ for dtype in self.float_types:
+ # TODO(b/111123982): remove once the bug is fixed.
+ if dtype == dtypes.float16:
+ continue
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ m0, m1 = 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = addsign.AddSignOptimizer(
+ learning_rate=learning_rate,
+ alpha=alpha,
+ beta=beta,
+ sign_decay_fn=sign_decay_fn,
+ )
+ # Two update ops over the same variables: one applies the gradients,
+ # the other applies their negation.
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
+ global_step=global_step)
+ neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 7 steps of AddSign
+ # first 4 steps with positive gradient
+ # last 3 steps with negative gradient (sign(gm) should be -1)
+ for t in range(1, 8):
+ if t < 5:
+ update.run()
+ else:
+ neg_update.run()
+
+ # Mirror the same step (including the gradient sign flip from t = 5
+ # onwards) in the NumPy reference.
+ var0_np, m0 = addsign_update_numpy(
+ var0_np,
+ grads0_np if t < 5 else -grads0_np,
+ m0,
+ learning_rate,
+ alpha=alpha,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+ var1_np, m1 = addsign_update_numpy(
+ var1_np,
+ grads1_np if t < 5 else -grads1_np,
+ m1,
+ learning_rate,
+ alpha=alpha,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testDense(self):
+ """Covers defaults, custom hyperparameters, and a linear sign decay."""
+ decay_steps = 10
+ sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
+ py_sign_decay_fn = py_linear_decay_fn(decay_steps)
+ self._testDense()
+ self._testDense(learning_rate=0.01, alpha=0.1, beta=0.8)
+ self._testDense(
+ sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py
new file mode 100644
index 0000000000..5fa7706d72
--- /dev/null
+++ b/tensorflow/compiler/tests/powersign_test.py
@@ -0,0 +1,142 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for PowerSign."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import powersign
+from tensorflow.contrib.opt.python.training import sign_decay
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def py_linear_decay_fn(decay_steps):
+ """Builds a pure-Python linear decay schedule.
+
+ The returned callable clamps `step` to at most `decay_steps` and returns
+ `(decay_steps - step) / decay_steps` as a float: 1.0 at step 0, falling to
+ 0.0 once `step >= decay_steps`. `testDense` pairs it with
+ `sign_decay.get_linear_decay_fn` built from the same `decay_steps`.
+ """
+ def linear_decay(step):
+ # Clamp so the result never goes below 0.0 for step > decay_steps.
+ step = min(step, decay_steps)
+ return float(decay_steps - step) / decay_steps
+ return linear_decay
+
+
+def powersign_update_numpy(params,
+ g_t,
+ m,
+ lr,
+ base=math.e,
+ beta=0.9,
+ py_sign_decay_fn=None,
+ t=None):
+ """NumPy reference for a single PowerSign update step.
+
+ Computes
+ m_t = beta * m + (1 - beta) * g_t
+ multiplier = base ** (sign_decayed * sign(g_t) * sign(m_t))
+ params_t = params - lr * multiplier * g_t
+ where `sign_decayed` is 1.0 when `py_sign_decay_fn` is None and
+ `py_sign_decay_fn(t - 1)` otherwise (so `t` must be given in that case).
+
+ Returns:
+ A `(params_t, m_t)` tuple with the updated parameters and moving average.
+ """
+ m_t = beta * m + (1 - beta) * g_t
+ if py_sign_decay_fn is None:
+ sign_decayed = 1.0
+ else:
+ # Callers in this file count steps from 1, so the schedule sees t - 1.
+ sign_decayed = py_sign_decay_fn(t-1)
+ multiplier = base ** (sign_decayed * np.sign(g_t) * np.sign(m_t))
+ params_t = params - lr * multiplier * g_t
+ return params_t, m_t
+
+
+class PowerSignTest(xla_test.XLATestCase):
+ """Compares PowerSignOptimizer updates with `powersign_update_numpy`."""
+
+ def _testDense(self,
+ learning_rate=0.1,
+ sign_decay_fn=None,
+ py_sign_decay_fn=None,
+ base=math.e,
+ beta=0.9):
+ """Runs 7 dense PowerSign steps per float dtype and checks each step.
+
+ `sign_decay_fn` is handed to the optimizer and `py_sign_decay_fn` to the
+ NumPy reference; callers are expected to pass a matching pair (as
+ `testDense` does) or leave both as None.
+ """
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ m0, m1 = 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = powersign.PowerSignOptimizer(
+ learning_rate=learning_rate,
+ base=base,
+ beta=beta,
+ sign_decay_fn=sign_decay_fn,
+ )
+ # Two update ops over the same variables: one applies the gradients,
+ # the other applies their negation.
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
+ global_step=global_step)
+ neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
+ global_step=global_step)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 7 steps of powersign
+ # first 4 steps with positive gradient
+ # last 3 steps with negative gradient (sign(gm) should be -1)
+ for t in range(1, 8):
+ if t < 5:
+ update.run()
+ else:
+ neg_update.run()
+
+ # Mirror the same step (including the gradient sign flip from t = 5
+ # onwards) in the NumPy reference.
+ var0_np, m0 = powersign_update_numpy(
+ var0_np,
+ grads0_np if t < 5 else -grads0_np,
+ m0,
+ learning_rate,
+ base=base,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+ var1_np, m1 = powersign_update_numpy(
+ var1_np,
+ grads1_np if t < 5 else -grads1_np,
+ m1,
+ learning_rate,
+ base=base,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testDense(self):
+ """Covers defaults, a custom base/beta, and a linear sign decay."""
+ decay_steps = 10
+ sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
+ py_sign_decay_fn = py_linear_decay_fn(decay_steps)
+ self._testDense()
+ self._testDense(learning_rate=0.1, base=10.0, beta=0.8)
+ self._testDense(
+ sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index b62a6e778d..f9f38897bd 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -640,5 +640,109 @@ class ResourceApplyAdadelta : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
ResourceApplyAdadelta);
+// Shared XLA lowering for the sign-based optimizer kernels in this file
+// (ResourceApplyAddSign and ResourceApplyPowerSign).
+//
+// Inputs, by index, as read in Compile():
+//   0: var (resource)   1: m (resource)   2: lr   3: alpha / logbase
+//   4: sign_decay       5: beta           6: grad
+//
+// The update is
+//   m   <- m * beta + grad * (1 - beta)
+//   var <- var - lr * ComputeGradientScale(input3, decay) * grad
+// with decay = sign(grad) * sign(m) * sign_decay (using the updated m).
+// Subclasses supply ComputeGradientScale and may extend CheckScalarParams.
+class ResourceApplySignBase : public XlaOpKernel {
+ public:
+  explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  // Validates shapes, then emits the moving-average and variable updates and
+  // assigns both results back to resource inputs 0 (var) and 1 (m).
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape var_shape, m_shape;
+    xla::XlaOp var, m;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+                errors::InvalidArgument(
+                    "var and m do not have the same shape: ",
+                    var_shape.DebugString(), " ", m_shape.DebugString()));
+    TensorShape grad_shape = ctx->InputShape(6);
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape: ",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+    // NOTE(review): OP_REQUIRES inside CheckScalarParams only returns from
+    // that helper, so the code below still runs after a failed scalar check;
+    // confirm that relying on the recorded ctx status is intended.
+    CheckScalarParams(ctx);
+
+    xla::XlaOp lr = ctx->Input(2);
+    xla::XlaOp alpha = ctx->Input(3);
+    xla::XlaOp sign_decay = ctx->Input(4);
+    xla::XlaOp beta = ctx->Input(5);
+    xla::XlaOp grad = ctx->Input(6);
+
+    m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
+    // Uses the freshly updated m, not the value read from the variable.
+    xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;
+
+    xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
+    var = var - lr * grad_scale * grad;
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+  }
+
+  // Requires lr (input 2), sign_decay (input 4) and beta (input 5) to be
+  // scalars. Input 3 is left to subclasses, which name it differently.
+  virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
+    TensorShape lr_shape = ctx->InputShape(2);
+    TensorShape sign_decay_shape = ctx->InputShape(4);
+    TensorShape beta_shape = ctx->InputShape(5);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
+                errors::InvalidArgument("sign_decay is not a scalar: ",
+                                        sign_decay_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
+                errors::InvalidArgument("beta is not a scalar: ",
+                                        beta_shape.DebugString()));
+  }
+
+  // Returns the factor multiplied into lr * grad. `alpha` is input 3 and
+  // `decay` is sign(grad) * sign(m) * sign_decay as computed in Compile().
+  virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
+                                          xla::XlaOp decay) = 0;
+
+ private:
+  DataType dtype_;  // Value of the "T" attribute.
+};
+
+// XLA kernel for ResourceApplyAddSign. Reuses ResourceApplySignBase::Compile
+// and scales the gradient by (alpha + decay), where alpha is input 3.
+class ResourceApplyAddSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ // Runs the base scalar checks, then also requires alpha (input 3) to be a
+ // scalar.
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape alpha_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+ }
+
+ // Additive form of the gradient scale: alpha + decay.
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return alpha + decay;
+ }
+};
+// TODO(b/111123982): Use kFloatTypes once the bug is fixed.
+REGISTER_XLA_OP(Name("ResourceApplyAddSign")
+ .TypeConstraint("T", {DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}),
+ ResourceApplyAddSign);
+
+// XLA kernel for ResourceApplyPowerSign. Reuses ResourceApplySignBase::Compile
+// and scales the gradient by exp(logbase * decay), where logbase is input 3.
+class ResourceApplyPowerSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ // Runs the base scalar checks, then also requires logbase (input 3) to be a
+ // scalar.
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape logbase_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
+ errors::InvalidArgument("logbase is not a scalar: ",
+ logbase_shape.DebugString()));
+ }
+
+ // Exponential form of the gradient scale. The parameter keeps the base
+ // class's name `alpha`, but for this op it carries input 3, i.e. logbase.
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return xla::Exp(alpha * decay);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
+ ResourceApplyPowerSign);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py
index 08d45ed73f..628a735e72 100644
--- a/tensorflow/contrib/opt/python/training/addsign_test.py
+++ b/tensorflow/contrib/opt/python/training/addsign_test.py
@@ -214,7 +214,7 @@ class AddSignTest(test.TestCase):
# Run 7 steps of AddSign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
- for t in range(1, 4):
+ for t in range(1, 8):
if t < 5:
update.run()
else:
@@ -222,7 +222,7 @@ class AddSignTest(test.TestCase):
var0_np, m0 = addsign_update_numpy(
var0_np,
- grads0_np,
+ grads0_np if t < 5 else -grads0_np,
m0,
learning_rate,
alpha=alpha,
@@ -232,7 +232,7 @@ class AddSignTest(test.TestCase):
)
var1_np, m1 = addsign_update_numpy(
var1_np,
- grads1_np,
+ grads1_np if t < 5 else -grads1_np,
m1,
learning_rate,
alpha=alpha,
diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py
index 5214082dd6..0bcf5d230a 100644
--- a/tensorflow/contrib/opt/python/training/powersign_test.py
+++ b/tensorflow/contrib/opt/python/training/powersign_test.py
@@ -216,7 +216,7 @@ class PowerSignTest(test.TestCase):
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
- # Run 3 steps of powersign
+ # Run 7 steps of powersign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
for t in range(1, 8):