-rw-r--r--  tensorflow/compiler/tests/BUILD                      27
-rw-r--r--  tensorflow/compiler/tests/adagrad_da_test.py        165
-rw-r--r--  tensorflow/compiler/tests/adamax_test.py            139
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/training_ops.cc  148
4 files changed, 479 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 69ff0d99cb..8b25147899 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -98,6 +98,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "adagrad_da_test",
+ size = "small",
+ srcs = ["adagrad_da_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "adam_test",
size = "small",
srcs = ["adam_test.py"],
@@ -112,6 +125,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "adamax_test",
+ size = "small",
+ srcs = ["adamax_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "addsign_test",
size = "small",
srcs = ["addsign_test.py"],
diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py
new file mode 100644
index 0000000000..dc1625793a
--- /dev/null
+++ b/tensorflow/compiler/tests/adagrad_da_test.py
@@ -0,0 +1,165 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AdagradDA optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad_da
+
+
+class AdagradDAOptimizerTest(xla_test.XLATestCase):
+
+ def testAdagradDAWithoutRegularizationBasic1(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ # Let g be the gradient accumulator, gg the gradient-squared
+ # accumulator, T the global step, lr the learning rate, and k the
+ # initial gradient-squared accumulator value.
+ # w = \dfrac{sign(-g)*lr*\max(|g| - l1*T, 0)}{l2*T*lr + \sqrt{k + gg}}
+ # e.g. sign(-0.1)*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534,
+ # and similarly for the others.
+ self.assertAllCloseAccordingToType(
+ np.array([-0.904534, -1.603567]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.094821, -0.189358]), var1.eval())
+
+ def testAdagradDAWithoutRegularizationBasic2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.904534, -1.603567]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.094821, -0.189358]), var1.eval())
+
+ def testAdagradDAWithL1(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.895489, -1.59555]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.085339, -0.17989]), var1.eval())
+
+ def testAdagradDAWithL1_L2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.046907, -0.093659]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.004275, -0.009023]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
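The closed-form comment in testAdagradDAWithoutRegularizationBasic1 can be checked outside of TensorFlow. Here is a minimal NumPy sketch (a standalone illustration; none of the names below are TensorFlow API) that reproduces the expected var0 values after one step:

```python
import numpy as np

# Constants taken from the test: lr=3.0, k=0.1 is the initial squared
# accumulator, no L1/L2 regularization, and one global step.
lr, k, l1, l2, T = 3.0, 0.1, 0.0, 0.0, 1.0
g = np.array([0.1, 0.2])   # gradient accumulator after one step (= grads0)
gg = g * g                 # gradient-squared accumulator after one step

# w = sign(-g) * lr * max(|g| - l1*T, 0) / (l2*T*lr + sqrt(k + gg))
w = (np.sign(-g) * lr * np.maximum(np.abs(g) - l1 * T, 0.0)
     / (l2 * T * lr + np.sqrt(k + gg)))
print(w)  # ~[-0.904534, -1.603567], the values asserted for var0
```

The var1 values follow in the same way from grads1 = [0.01, 0.02].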
diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py
new file mode 100644
index 0000000000..c4fdbc5974
--- /dev/null
+++ b/tensorflow/compiler/tests/adamax_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AdaMax optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import adamax
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def adamax_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ alpha=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8):
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = np.maximum(beta2 * v, np.abs(g_t))
+ param_t = param - (alpha / (1 - beta1**t)) * (m_t / (v_t + epsilon))
+ return param_t, m_t, v_t
+
+
+class AdaMaxOptimizerTest(xla_test.XLATestCase):
+
+ def testBasic(self):
+ for i, dtype in enumerate(self.float_types):
+ with self.test_session(), self.test_scope():
+ variable_scope.get_variable_scope().set_use_resource(True)
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = adamax.AdaMaxOptimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power = opt._get_beta_accumulators()
+ self.assertIsNotNone(beta1_power)
+ self.assertIn(beta1_power, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of AdaMax
+ for t in range(1, 4):
+ update.run()
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+
+ var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2)
+ self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2)
+ self.assertEqual("var0_%d/AdaMax:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testTensorLearningRate(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ variable_scope.get_variable_scope().set_use_resource(True)
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adamax.AdaMaxOptimizer(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of AdaMax
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+if __name__ == "__main__":
+ test.main()
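For reference, the update that adamax_update_numpy encodes, and that the ResourceApplyAdaMax kernel below mirrors, is the AdaMax rule of Kingma & Ba (2015), written here with the epsilon placement this code uses:

```latex
m_t      = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t      = \max(\beta_2\, v_{t-1},\; |g_t|)
\theta_t = \theta_{t-1} - \frac{\alpha}{1 - \beta_1^{t}} \cdot \frac{m_t}{v_t + \epsilon}
```

Unlike Adam, there is no beta2 bias correction: v_t is a running max rather than an exponential average, so only the first-moment correction 1/(1 - beta1^t) appears, which is why the tests track a single beta1_power accumulator.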
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index a1877ebf7a..03902f012c 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -268,6 +268,83 @@ REGISTER_XLA_OP(
Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes),
ResourceApplyProximalAdagrad);
+class ResourceApplyAdagradDA : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape, squared_accum_shape;
+ xla::XlaOp var, accum, squared_accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
+ &squared_accum));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(squared_accum_shape),
+ errors::InvalidArgument(
+ "var and squared accum do not have the same shape",
+ var_shape.DebugString(), " ", squared_accum_shape.DebugString()));
+
+ TensorShape grad_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape l1_shape = ctx->InputShape(5);
+ TensorShape l2_shape = ctx->InputShape(6);
+ TensorShape global_step_shape = ctx->InputShape(7);
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
+ errors::InvalidArgument("global step is not a scalar: ",
+ global_step_shape.DebugString()));
+
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp l1 = ctx->Input(5);
+ xla::XlaOp l2 = ctx->Input(6);
+ xla::XlaBuilder* const b = ctx->builder();
+ xla::XlaOp global_step =
+ XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_);
+
+ accum = accum + grad;
+ squared_accum = squared_accum + xla::Square(grad);
+ xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+ xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
+ xla::XlaOp l1_le_zero = -lr * accum / denominator;
+ xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) *
+ xla::Max(xla::Abs(accum) - global_step * l1, zero) /
+ denominator;
+
+ var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdagradDA);
+
class ResourceApplyAdam : public XlaOpKernel {
public:
explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -353,6 +430,77 @@ class ResourceApplyAdam : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
ResourceApplyAdam);
+class ResourceApplyAdaMax : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, m_shape, v_shape;
+ xla::XlaOp var, m, v;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
+
+ TensorShape beta1_power_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape beta1_shape = ctx->InputShape(5);
+ TensorShape beta2_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape(8);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
+ errors::InvalidArgument("beta1_power is not a scalar: ",
+ beta1_power_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar : ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
+ errors::InvalidArgument("beta1 is not a scalar: ",
+ beta1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
+ errors::InvalidArgument("beta2 is not a scalar: ",
+ beta2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var_shape.DebugString(), " ",
+ m_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
+ errors::InvalidArgument("var and v do not have the same shape",
+ var_shape.DebugString(), " ",
+ v_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+ xla::XlaOp beta1_power = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp beta1 = ctx->Input(5);
+ xla::XlaOp beta2 = ctx->Input(6);
+ xla::XlaOp epsilon = ctx->Input(7);
+ xla::XlaOp grad = ctx->Input(8);
+
+ xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+ m = beta1 * m + (one - beta1) * grad;
+ v = xla::Max(beta2 * v, xla::Abs(grad));
+ var = var - lr / (one - beta1_power) * (m / (v + epsilon));
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdaMax);
+
class ResourceApplyRMSProp : public XlaOpKernel {
public:
explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
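The element-wise body of ResourceApplyAdaMax is small enough to cross-check by hand against the Python reference in adamax_test.py. A minimal NumPy sketch (a standalone illustration, not TensorFlow code) of the kernel's three update statements:

```python
import numpy as np

def apply_ada_max(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad):
    # The same three statements as the XLA kernel, in the same order.
    m = beta1 * m + (1.0 - beta1) * grad        # first-moment estimate
    v = np.maximum(beta2 * v, np.abs(grad))     # infinity-norm accumulator
    var = var - lr / (1.0 - beta1_power) * (m / (v + epsilon))
    return var, m, v

var, m, v = apply_ada_max(
    var=np.array([1.0, 2.0]), m=np.zeros(2), v=np.zeros(2),
    beta1_power=0.9, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
    grad=np.array([0.1, 0.1]))
print(var)  # ~[0.999, 1.999]
```

For the first step beta1_power equals beta1 itself, so the result agrees with adamax_update_numpy(param, g_t, t=1, m=0, v=0) from the test above.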