aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-02 10:40:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-02 10:43:45 -0700
commite69fb48d2b6a5c523c0fe74c14a8121fa3685af6 (patch)
treeff68c76eebce35762e9e0a62e889e4edd1c8226a
parent94212e2b02e7e585e815bb659662253dceee9d55 (diff)
[TF:XLA] Add implementation of ResourceApplyAdadelta.
PiperOrigin-RevId: 202975643
-rw-r--r--tensorflow/compiler/tests/BUILD13
-rw-r--r--tensorflow/compiler/tests/adadelta_test.py134
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc69
3 files changed, 216 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 95fda489a1..080b1c9c35 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -71,6 +71,19 @@ py_test(
)
tf_xla_py_test(
+ name = "adadelta_test",
+ size = "medium",
+ srcs = ["adadelta_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "adagrad_test",
size = "small",
srcs = ["adagrad_test.py"],
diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py
new file mode 100644
index 0000000000..3e3c09c66e
--- /dev/null
+++ b/tensorflow/compiler/tests/adadelta_test.py
@@ -0,0 +1,134 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adadelta Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adadelta
+
+
+class AdadeltaOptimizerTest(xla_test.XLATestCase):
+
+ def testBasic(self):
+ num_updates = 4 # number of ADADELTA steps to perform
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ for grad in [0.2, 0.1, 0.01]:
+ for lr in [1.0, 0.5, 0.1]:
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ adadelta_opt = adadelta.AdadeltaOptimizer(
+ learning_rate=lr, rho=rho, epsilon=epsilon)
+ adadelta_update = adadelta_opt.apply_gradients(
+ zip([grads, grads], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+ opt_vars = adadelta_opt.variables()
+ self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[1].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[2].name, var1._shared_name)
+ self.assertStartsWith(opt_vars[3].name, var1._shared_name)
+ self.assertEqual(4, len(opt_vars))
+ # Assign slots
+ slot = [None] * 2
+ slot_update = [None] * 2
+ self.assertEqual(["accum", "accum_update"],
+ adadelta_opt.get_slot_names())
+ slot[0] = adadelta_opt.get_slot(var0, "accum")
+ self.assertEquals(slot[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot[0] in variables.trainable_variables())
+
+ slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
+ self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot_update[0] in variables.trainable_variables())
+
+ slot[1] = adadelta_opt.get_slot(var1, "accum")
+ self.assertEquals(slot[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot[1] in variables.trainable_variables())
+
+ slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
+ self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot_update[1] in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, self.evaluate(var0))
+ self.assertAllClose(var1_init, self.evaluate(var1))
+
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ self.evaluate(adadelta_update)
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (
+ np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (
+ accum_update * rho + (update[step]**2) * (1.0 - rho))
+ tot_update += update[step] * lr
+
+ # Check that the accumulators have been updated
+ for slot_idx in range(2):
+ self.assertAllCloseAccordingToType(
+ np.array([accum, accum], dtype=dtype),
+ self.evaluate(slot[slot_idx]),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array([accum_update, accum_update], dtype=dtype),
+ self.evaluate(slot_update[slot_idx]),
+ rtol=1e-5)
+
+ # Check that the parameters have been updated
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var0_init[0] - tot_update, var0_init[1] - tot_update],
+ dtype=dtype),
+ self.evaluate(var0),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var1_init[0] - tot_update, var1_init[1] - tot_update],
+ dtype=dtype),
+ self.evaluate(var1),
+ rtol=1e-5)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index f3e112c7b3..68b1fce477 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -457,5 +457,74 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
ResourceApplyFtrlV2);
+class ResourceApplyAdadelta : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape, accum_update_shape;
+ xla::XlaOp var, accum, accum_update;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
+ &accum_update));
+
+ TensorShape lr_shape = ctx->InputShape(3);
+ TensorShape rho_shape = ctx->InputShape(4);
+ TensorShape epsilon_shape = ctx->InputShape(5);
+ TensorShape grad_shape = ctx->InputShape(6);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
+ errors::InvalidArgument("rho is not a scalar: ",
+ rho_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon_shape.DebugString()));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+ xla::XlaOp lr = ctx->Input(3);
+ xla::XlaOp rho = ctx->Input(4);
+ xla::XlaOp epsilon = ctx->Input(5);
+ xla::XlaOp grad = ctx->Input(6);
+
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5);
+ xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+
+ accum = rho * accum + (one - rho) * xla::Pow(grad, two);
+ xla::XlaOp update = xla::Pow(accum_update + epsilon, half) *
+ xla::Pow(accum + epsilon, neg_half) * grad;
+ accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two);
+ var = var - update * lr;
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdadelta);
+
} // namespace
} // namespace tensorflow