From e69fb48d2b6a5c523c0fe74c14a8121fa3685af6 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 2 Jul 2018 10:40:47 -0700
Subject: [TF:XLA] Add implementation of ResourceApplyAdadelta.

PiperOrigin-RevId: 202975643
---
 tensorflow/compiler/tests/BUILD                    |  13 ++
 tensorflow/compiler/tests/adadelta_test.py         | 134 +++++++++++++++++++++
 tensorflow/compiler/tf2xla/kernels/training_ops.cc |  69 +++++++++++
 3 files changed, 216 insertions(+)
 create mode 100644 tensorflow/compiler/tests/adadelta_test.py

diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 95fda489a1..080b1c9c35 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -70,6 +70,19 @@ py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "adadelta_test",
+    size = "medium",
+    srcs = ["adadelta_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training",
+    ],
+)
+
 tf_xla_py_test(
     name = "adagrad_test",
     size = "small",
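For reference before the test itself: the kernel and the test below both
implement the standard Adadelta recurrence (Zeiler, 2012), restated here from
the code in LaTeX notation, with a = accum, u = accum_update, x = var, g_t the
gradient, and \eta the learning rate:

\begin{aligned}
  a_t      &= \rho\, a_{t-1} + (1 - \rho)\, g_t^2 \\
  \Delta_t &= \frac{\sqrt{u_{t-1} + \epsilon}}{\sqrt{a_t + \epsilon}}\, g_t \\
  u_t      &= \rho\, u_{t-1} + (1 - \rho)\, \Delta_t^2 \\
  x_t      &= x_{t-1} - \eta\, \Delta_t
\end{aligned}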
diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py
new file mode 100644
index 0000000000..3e3c09c66e
--- /dev/null
+++ b/tensorflow/compiler/tests/adadelta_test.py
@@ -0,0 +1,134 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adadelta Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adadelta
+
+
+class AdadeltaOptimizerTest(xla_test.XLATestCase):
+
+  def testBasic(self):
+    num_updates = 4  # number of Adadelta steps to perform
+    for dtype in self.float_types:
+      with self.test_session(), self.test_scope():
+        for grad in [0.2, 0.1, 0.01]:
+          for lr in [1.0, 0.5, 0.1]:
+            var0_init = [1.0, 2.0]
+            var1_init = [3.0, 4.0]
+            var0 = resource_variable_ops.ResourceVariable(
+                var0_init, dtype=dtype)
+            var1 = resource_variable_ops.ResourceVariable(
+                var1_init, dtype=dtype)
+
+            grads = constant_op.constant([grad, grad], dtype=dtype)
+
+            accum = 0.0
+            accum_update = 0.0
+
+            # Adadelta gradient optimizer
+            rho = 0.95
+            epsilon = 1e-8
+            adadelta_opt = adadelta.AdadeltaOptimizer(
+                learning_rate=lr, rho=rho, epsilon=epsilon)
+            adadelta_update = adadelta_opt.apply_gradients(
+                zip([grads, grads], [var0, var1]))
+            self.evaluate(variables.global_variables_initializer())
+            opt_vars = adadelta_opt.variables()
+            self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+            self.assertStartsWith(opt_vars[1].name, var0._shared_name)
+            self.assertStartsWith(opt_vars[2].name, var1._shared_name)
+            self.assertStartsWith(opt_vars[3].name, var1._shared_name)
+            self.assertEqual(4, len(opt_vars))
+            # Assign slots
+            slot = [None] * 2
+            slot_update = [None] * 2
+            self.assertEqual(["accum", "accum_update"],
+                             adadelta_opt.get_slot_names())
+            slot[0] = adadelta_opt.get_slot(var0, "accum")
+            self.assertEquals(slot[0].get_shape(), var0.get_shape())
+            self.assertFalse(slot[0] in variables.trainable_variables())
+
+            slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
+            self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
+            self.assertFalse(slot_update[0] in variables.trainable_variables())
+
+            slot[1] = adadelta_opt.get_slot(var1, "accum")
+            self.assertEquals(slot[1].get_shape(), var1.get_shape())
+            self.assertFalse(slot[1] in variables.trainable_variables())
+
+            slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
+            self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
+            self.assertFalse(slot_update[1] in variables.trainable_variables())
+
+            # Fetch params to validate initial values
+            self.assertAllClose(var0_init, self.evaluate(var0))
+            self.assertAllClose(var1_init, self.evaluate(var1))
+
+            update = [None] * num_updates
+            tot_update = 0
+            for step in range(num_updates):
+              # Run adadelta update for comparison
+              self.evaluate(adadelta_update)
+
+              # Compute the expected accumulator and update values for this
+              # step
+              accum = accum * rho + (grad**2) * (1 - rho)
+              update[step] = (
+                  np.sqrt(accum_update + epsilon) *
+                  (1. / np.sqrt(accum + epsilon)) * grad)
+              accum_update = (
+                  accum_update * rho + (update[step]**2) * (1.0 - rho))
+              tot_update += update[step] * lr
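+
+              # At this point accum and accum_update hold the expected
+              # exponential moving averages of grad**2 and update**2, and
+              # tot_update is the total shift applied to each parameter so
+              # far; the asserts below compare them against the op's slots.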
+
+              # Check that the accumulators have been updated
+              for slot_idx in range(2):
+                self.assertAllCloseAccordingToType(
+                    np.array([accum, accum], dtype=dtype),
+                    self.evaluate(slot[slot_idx]),
+                    rtol=1e-5)
+
+                self.assertAllCloseAccordingToType(
+                    np.array([accum_update, accum_update], dtype=dtype),
+                    self.evaluate(slot_update[slot_idx]),
+                    rtol=1e-5)
+
+              # Check that the parameters have been updated
+              self.assertAllCloseAccordingToType(
+                  np.array(
+                      [var0_init[0] - tot_update, var0_init[1] - tot_update],
+                      dtype=dtype),
+                  self.evaluate(var0),
+                  rtol=1e-5)
+
+              self.assertAllCloseAccordingToType(
+                  np.array(
+                      [var1_init[0] - tot_update, var1_init[1] - tot_update],
+                      dtype=dtype),
+                  self.evaluate(var1),
+                  rtol=1e-5)
+
+
+if __name__ == "__main__":
+  test.main()
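The same recurrence, written out as a self-contained NumPy sketch of a single
ResourceApplyAdadelta step (the function name and signature are illustrative
only; the argument order mirrors the op's inputs):

import numpy as np

def adadelta_step(var, accum, accum_update, lr, rho, epsilon, grad):
  """One Adadelta step, mirroring the XLA kernel below."""
  # Running average of squared gradients.
  accum = rho * accum + (1.0 - rho) * grad**2
  # Scale the gradient by the ratio of the two RMS terms.
  update = np.sqrt(accum_update + epsilon) / np.sqrt(accum + epsilon) * grad
  # Running average of squared updates.
  accum_update = rho * accum_update + (1.0 - rho) * update**2
  # Move the parameters by the update, scaled by the learning rate.
  var = var - lr * update
  return var, accum, accum_update

Iterating this num_updates times with a constant gradient reproduces the
expected values the test accumulates in tot_update.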
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index f3e112c7b3..68b1fce477 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -457,5 +457,74 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
 REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
                 ResourceApplyFtrlV2);
 
+class ResourceApplyAdadelta : public XlaOpKernel {
+ public:
+  explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape var_shape, accum_shape, accum_update_shape;
+    xla::XlaOp var, accum, accum_update;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
+                                               &accum_update));
+
+    TensorShape lr_shape = ctx->InputShape(3);
+    TensorShape rho_shape = ctx->InputShape(4);
+    TensorShape epsilon_shape = ctx->InputShape(5);
+    TensorShape grad_shape = ctx->InputShape(6);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
+                errors::InvalidArgument("rho is not a scalar: ",
+                                        rho_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon_shape.DebugString()));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+                errors::InvalidArgument(
+                    "var and accum do not have the same shape",
+                    var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+    xla::XlaOp lr = ctx->Input(3);
+    xla::XlaOp rho = ctx->Input(4);
+    xla::XlaOp epsilon = ctx->Input(5);
+    xla::XlaOp grad = ctx->Input(6);
+
+    xla::XlaBuilder* b = ctx->builder();
+    xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5);
+    xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
+    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
+    xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+
+    accum = rho * accum + (one - rho) * xla::Pow(grad, two);
+    xla::XlaOp update = xla::Pow(accum_update + epsilon, half) *
+                        xla::Pow(accum + epsilon, neg_half) * grad;
+    accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two);
+    var = var - update * lr;
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
+  }
+
+ private:
+  DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
+                ResourceApplyAdadelta);
+
 }  // namespace
 }  // namespace tensorflow
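A minimal usage sketch, assuming a TF 1.x build with XLA support; the device
string, variable name, and hyperparameters are illustrative. With resource
variables, AdadeltaOptimizer emits ResourceApplyAdadelta, which the kernel
above lowers on XLA devices (note the kernel spells sqrt and rsqrt as
xla::Pow with exponents 0.5 and -0.5, matching the formula stated earlier):

import tensorflow as tf

with tf.device("/device:XLA_CPU:0"):
  # use_resource=True makes the optimizer emit the Resource* training op.
  var = tf.get_variable("w", initializer=[1.0, 2.0], use_resource=True)
  loss = tf.reduce_sum(tf.square(var))
  train_op = tf.train.AdadeltaOptimizer(
      learning_rate=0.1, rho=0.95, epsilon=1e-8).minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)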