Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/training_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/training_ops.cc | 582
1 file changed, 511 insertions, 71 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 2e5d61e111..98df730249 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -47,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
var_shape.DebugString(), " vs ",
delta_shape.DebugString()));
- handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2)));
+ handle = handle - ctx->Input(1) * ctx->Input(2);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -55,6 +57,64 @@ REGISTER_XLA_OP(
Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
ResourceApplyGradientDescent);
+xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
+ xla::XlaOp l1, xla::XlaOp l2,
+ xla::XlaOp grad) {
+ xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+ xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+ xla::XlaOp prox_var = var - grad * lr;
+ xla::XlaOp l1_gt_zero = xla::Sign(prox_var) *
+ xla::Max(xla::Abs(prox_var) - lr * l1, zero) /
+ (one + lr * l2);
+ xla::XlaOp l1_le_zero = prox_var / (one + lr * l2);
+ return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+}
+
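For a concrete view of what this helper computes, here is a scalar C++ sketch of the same proximal step. The standalone function, its name, and the sample values are chosen for illustration and are not part of the kernel; std::copysign stands in for the Sign * Max product (the result is identical because the thresholded magnitude is non-negative).

// Scalar sketch of the proximal step above, for illustration only.
#include <algorithm>
#include <cmath>
#include <cstdio>

float ProximalUpdate(float var, float lr, float l1, float l2, float grad) {
  // Plain gradient step first.
  const float prox_var = var - grad * lr;
  if (l1 > 0.0f) {
    // Soft-threshold the stepped value by lr * l1, then apply the l2 shrinkage.
    const float shrunk = std::max(std::fabs(prox_var) - lr * l1, 0.0f);
    return std::copysign(shrunk, prox_var) / (1.0f + lr * l2);
  }
  // No l1 penalty: only the l2 shrinkage remains.
  return prox_var / (1.0f + lr * l2);
}

int main() {
  // var = 1.0, lr = 0.1, l1 = 0.01, l2 = 0.001, grad = 0.5
  std::printf("%f\n", ProximalUpdate(1.0f, 0.1f, 0.01f, 0.001f, 0.5f));
  return 0;
}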
+class ResourceApplyProximalGradientDescent : public XlaOpKernel {
+ public:
+ explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp var;
+ TensorShape var_shape;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+
+ TensorShape alpha_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+ TensorShape l1_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ TensorShape l2_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ TensorShape delta_shape = ctx->InputShape(4);
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(delta_shape),
+ errors::InvalidArgument("var and delta do not have the same shape: ",
+ var_shape.DebugString(), " vs ",
+ delta_shape.DebugString()));
+ xla::XlaOp alpha = ctx->Input(1);
+ xla::XlaOp l1 = ctx->Input(2);
+ xla::XlaOp l2 = ctx->Input(3);
+ xla::XlaOp delta = ctx->Input(4);
+ var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent")
+ .TypeConstraint("T", kFloatTypes),
+ ResourceApplyProximalGradientDescent);
+
class ResourceApplyMomentum : public XlaOpKernel {
public:
explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -94,14 +154,13 @@ class ResourceApplyMomentum : public XlaOpKernel {
xla::XlaOp grad = ctx->Input(3);
xla::XlaOp momentum = ctx->Input(4);
- accum = xla::Add(xla::Mul(accum, momentum), grad);
+ accum = accum * momentum + grad;
if (use_nesterov_) {
// See https://github.com/tensorflow/tensorflow/pull/2798 for an
// explanation of the reparameterization used here.
- var = xla::Sub(var, xla::Add(xla::Mul(grad, lr),
- xla::Mul(xla::Mul(accum, momentum), lr)));
+ var = var - (grad * lr + accum * momentum * lr);
} else {
- var = xla::Sub(var, xla::Mul(accum, lr));
+ var = var - accum * lr;
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
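Written out with a for accum, mu for momentum, eta for lr and theta for var (notation chosen here for readability), the rewritten momentum update is

\[
a_t = \mu\,a_{t-1} + g_t, \qquad
\theta_t =
\begin{cases}
\theta_{t-1} - \eta\,(g_t + \mu\,a_t) & \text{use\_nesterov} \\
\theta_{t-1} - \eta\,a_t & \text{otherwise}
\end{cases}
\]

which is the reparameterized Nesterov form referenced in the pull request linked in the comment.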
@@ -118,8 +177,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
@@ -146,12 +203,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
xla::XlaOp lr = ctx->Input(2);
xla::XlaOp grad = ctx->Input(3);
- accum =
- xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
- var = xla::Sub(
- var,
- xla::Mul(xla::Mul(grad, lr),
- xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
+ accum = accum + xla::Square(grad);
+ var = var - grad * lr * xla::Rsqrt(accum);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
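The Square/Rsqrt form is algebraically the same Adagrad update as the Pow-based code it replaces; with A for accum (notation chosen here),

\[
A_t = A_{t-1} + g_t^2, \qquad \theta_t = \theta_{t-1} - \eta\,\frac{g_t}{\sqrt{A_t}}.
\]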
@@ -159,6 +212,139 @@ class ResourceApplyAdagrad : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
ResourceApplyAdagrad);
+class ResourceApplyProximalAdagrad : public XlaOpKernel {
+ public:
+ explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape;
+ xla::XlaOp var, accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+ TensorShape lr_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ TensorShape l1_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ TensorShape l2_shape = ctx->InputShape(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ TensorShape grad_shape = ctx->InputShape(5);
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape: ",
+ var_shape.DebugString(), " vs ", grad_shape.DebugString()));
+
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp l1 = ctx->Input(3);
+ xla::XlaOp l2 = ctx->Input(4);
+ xla::XlaOp grad = ctx->Input(5);
+ accum = accum + xla::Square(grad);
+ // Adagrad learning rate.
+ xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum);
+ var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(
+ Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes),
+ ResourceApplyProximalAdagrad);
+
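With A for accum (notation chosen here), the proximal Adagrad kernel builds a per-coordinate learning rate and feeds it to the same ProximalGradientDescentUpdate helper defined above:

\[
A_t = A_{t-1} + g_t^2, \qquad \eta_t = \frac{\eta}{\sqrt{A_t}}, \qquad
\theta_t = \operatorname{prox}(\theta_{t-1},\, \eta_t,\, l_1,\, l_2,\, g_t).
\]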
+class ResourceApplyAdagradDA : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape, squared_accum_shape;
+ xla::XlaOp var, accum, squared_accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
+ &squared_accum));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(squared_accum_shape),
+ errors::InvalidArgument(
+ "var and squared accum do not have the same shape",
+ var_shape.DebugString(), " ", squared_accum_shape.DebugString()));
+
+ TensorShape grad_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape l1_shape = ctx->InputShape(5);
+ TensorShape l2_shape = ctx->InputShape(6);
+ TensorShape global_step_shape = ctx->InputShape(7);
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
+ errors::InvalidArgument("global step is not a scalar: ",
+ global_step_shape.DebugString()));
+
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp l1 = ctx->Input(5);
+ xla::XlaOp l2 = ctx->Input(6);
+ xla::XlaBuilder* const b = ctx->builder();
+ xla::XlaOp global_step =
+ XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_);
+
+ accum = accum + grad;
+ squared_accum = squared_accum + xla::Square(grad);
+ xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+ xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
+ xla::XlaOp l1_le_zero = -lr * accum / denominator;
+ xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) *
+ xla::Max(xla::Abs(accum) - global_step * l1, zero) /
+ denominator;
+
+ var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdagradDA);
+
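With G for the gradient accumulator, S for the squared-gradient accumulator and t the (cast) global step (notation chosen here), the dual-averaging update above is

\[
G_t = \textstyle\sum_{s \le t} g_s,\qquad S_t = \textstyle\sum_{s \le t} g_s^2,\qquad
\theta_t =
\begin{cases}
-\,\eta\,\dfrac{\operatorname{sign}(G_t)\,\max(|G_t| - t\,l_1,\,0)}{t\,\eta\,l_2 + \sqrt{S_t}} & l_1 > 0 \\[8pt]
-\,\eta\,\dfrac{G_t}{t\,\eta\,l_2 + \sqrt{S_t}} & \text{otherwise.}
\end{cases}
\]

Note that var is overwritten from the accumulators rather than decremented, which is the defining property of dual averaging.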
class ResourceApplyAdam : public XlaOpKernel {
public:
explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -226,18 +412,12 @@ class ResourceApplyAdam : public XlaOpKernel {
// variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
- xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
- xla::XlaOp alpha =
- xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)),
- xla::Sub(one, beta1_power));
- m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1)));
- v = xla::Add(
- v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2)));
- var = xla::Sub(var, xla::Div(xla::Mul(m, alpha),
- xla::Add(xla::Pow(v, half), epsilon)));
+ xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
+ m = m + (grad - m) * (one - beta1);
+ v = v + (xla::Square(grad) - v) * (one - beta2);
+ var = var - m * alpha / (xla::Sqrt(v) + epsilon);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
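In the bias-corrected form the kernel now computes (eta = lr, with beta1_power and beta2_power holding the running powers of the betas; notation chosen here),

\[
\alpha_t = \eta\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t},\quad
m_t = m_{t-1} + (1-\beta_1)(g_t - m_{t-1}),\quad
v_t = v_{t-1} + (1-\beta_2)(g_t^2 - v_{t-1}),\quad
\theta_t = \theta_{t-1} - \frac{\alpha_t\,m_t}{\sqrt{v_t}+\epsilon}.
\]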
@@ -250,38 +430,112 @@ class ResourceApplyAdam : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
ResourceApplyAdam);
-class ResourceApplyRMSProp : public XlaOpKernel {
+class ResourceApplyAdaMax : public XlaOpKernel {
public:
- explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
+ TensorShape var_shape, m_shape, v_shape;
+ xla::XlaOp var, m, v;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
- DataType type = ctx->input_type(3);
+ TensorShape beta1_power_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape beta1_shape = ctx->InputShape(5);
+ TensorShape beta2_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape(8);
- TensorShape var_shape, ms_shape, mom_shape;
- xla::XlaOp var, ms, mom;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
+ errors::InvalidArgument("beta1_power is not a scalar: ",
+ beta1_power_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
+ errors::InvalidArgument("beta1 is not a scalar: ",
+ beta1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
+ errors::InvalidArgument("beta2 is not a scalar: ",
+ beta2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var_shape.DebugString(), " ",
+ m_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
+ errors::InvalidArgument("var and v do not have the same shape",
+ var_shape.DebugString(), " ",
+ v_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
- TensorShape lr_shape = ctx->InputShape(3);
+ xla::XlaOp beta1_power = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp beta1 = ctx->Input(5);
+ xla::XlaOp beta2 = ctx->Input(6);
+ xla::XlaOp epsilon = ctx->Input(7);
+ xla::XlaOp grad = ctx->Input(8);
+
+ xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+ m = beta1 * m + (one - beta1) * grad;
+ v = xla::Max(beta2 * v, xla::Abs(grad));
+ var = var - lr / (one - beta1_power) * (m / (v + epsilon));
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdaMax);
+
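With u for the infinity-norm accumulator stored in the v variable and beta1_power = beta1^t (notation chosen here), the AdaMax kernel computes

\[
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t,\qquad
u_t = \max(\beta_2 u_{t-1},\, |g_t|),\qquad
\theta_t = \theta_{t-1} - \frac{\eta}{1-\beta_1^t}\cdot\frac{m_t}{u_t+\epsilon}.
\]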
+class ResourceApplyRMSProp : public XlaOpKernel {
+ public:
+ explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, ms_shape, mom_shape, mg_shape;
+ xla::XlaOp var, ms, mom, mg;
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));
+
+ TensorShape lr_shape = ctx->InputShape("lr");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));
- TensorShape rho_shape = ctx->InputShape(4);
+ TensorShape rho_shape = ctx->InputShape("rho");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
errors::InvalidArgument("rho is not a scalar: ",
rho_shape.DebugString()));
- TensorShape momentum_shape = ctx->InputShape(5);
+ TensorShape momentum_shape = ctx->InputShape("momentum");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- TensorShape epsilon_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape("epsilon");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon_shape.DebugString()));
- TensorShape grad_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape("grad");
// var should be the same shape as mom and ms.
OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
@@ -297,11 +551,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::XlaOp lr = ctx->Input(3);
- xla::XlaOp rho = ctx->Input(4);
- xla::XlaOp momentum = ctx->Input(5);
- xla::XlaOp epsilon = ctx->Input(6);
- xla::XlaOp grad = ctx->Input(7);
+ xla::XlaOp lr = ctx->Input("lr");
+ xla::XlaOp rho = ctx->Input("rho");
+ xla::XlaOp momentum = ctx->Input("momentum");
+ xla::XlaOp epsilon = ctx->Input("epsilon");
+ xla::XlaOp grad = ctx->Input("grad");
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
@@ -320,26 +574,46 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::XlaOp new_ms = xla::Add(
- ms, xla::Mul(
- xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)),
- ms),
- xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
- xla::XlaOp new_mom =
- xla::Add(xla::Mul(mom, momentum),
- xla::Mul(xla::Mul(grad, lr),
- xla::Pow(xla::Add(new_ms, epsilon),
- XlaHelpers::FloatLiteral(b, type, -0.5))));
- xla::XlaOp new_var = xla::Sub(var, new_mom);
-
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
+ xla::XlaOp one = xla::ScalarLike(ms, 1.0);
+ xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
+ xla::XlaOp denominator;
+ if (centered_) {
+ mg = grad * (one - rho) + mg * rho;
+ denominator = new_ms - xla::Square(mg) + epsilon;
+ } else {
+ denominator = new_ms + epsilon;
+ }
+ xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
+ xla::XlaOp new_var = var - new_mom;
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
}
+
+ protected:
+ bool centered_ = false;
+
+ private:
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
ResourceApplyRMSProp);
+class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
+ public:
+ explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
+ : ResourceApplyRMSProp(ctx) {
+ centered_ = true;
+ }
+};
+REGISTER_XLA_OP(
+ Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes),
+ ResourceApplyCenteredRMSProp);
+
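With s for ms, m for mom, g-bar for mg, mu for momentum and rho for rho (notation chosen here), the shared Compile body computes

\[
s_t = \rho\,s_{t-1} + (1-\rho)\,g_t^2,\qquad
\bar g_t = \rho\,\bar g_{t-1} + (1-\rho)\,g_t \quad (\text{centered only}),
\]
\[
d_t = \begin{cases} s_t - \bar g_t^2 + \epsilon & \text{centered}\\ s_t + \epsilon & \text{otherwise,}\end{cases}\qquad
m_t = \mu\,m_{t-1} + \frac{\eta\,g_t}{\sqrt{d_t}},\qquad
\theta_t = \theta_{t-1} - m_t,
\]

so the only difference in the centered variant is the variance estimate s_t minus g-bar_t squared in place of the raw second moment.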
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
xla::XlaBuilder* b = ctx->builder();
@@ -425,23 +699,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
xla::XlaOp grad_to_use;
if (has_l2_shrinkage) {
- grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var)));
+ grad_to_use = grad + two * l2_shrinkage * var;
} else {
grad_to_use = grad;
}
- xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two));
- xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power));
- xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power));
- linear = xla::Add(
- linear,
- xla::Sub(grad_to_use,
- xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr),
- var)));
- xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1);
- xla::XlaOp quadratic =
- xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2));
- var = xla::Div(xla::Sub(linear_clipped, linear), quadratic);
+ xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+ xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
+ xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
+ linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
+ xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1);
+ xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2;
+ var = (linear_clipped - linear) / quadratic;
accum = new_accum;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
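With n for accum, z for linear, p for lr_power and eta = lr (notation chosen here), and with g_t replaced by g_t + 2 l2_shrinkage theta_{t-1} in the V2 kernel, the FTRL update above is

\[
n_t = n_{t-1} + g_t^2,\qquad
\sigma_t = \frac{n_t^{-p} - n_{t-1}^{-p}}{\eta},\qquad
z_t = z_{t-1} + g_t - \sigma_t\,\theta_{t-1},
\]
\[
\theta_t = \frac{\operatorname{clip}(z_t,\,-l_1,\,l_1) - z_t}{\,n_t^{-p}/\eta + 2\,l_2\,}.
\]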
@@ -481,5 +750,176 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
ResourceApplyFtrlV2);
+class ResourceApplyAdadelta : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape, accum_update_shape;
+ xla::XlaOp var, accum, accum_update;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
+ &accum_update));
+
+ TensorShape lr_shape = ctx->InputShape(3);
+ TensorShape rho_shape = ctx->InputShape(4);
+ TensorShape epsilon_shape = ctx->InputShape(5);
+ TensorShape grad_shape = ctx->InputShape(6);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
+ errors::InvalidArgument("rho is not a scalar: ",
+ rho_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon_shape.DebugString()));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+ xla::XlaOp lr = ctx->Input(3);
+ xla::XlaOp rho = ctx->Input(4);
+ xla::XlaOp epsilon = ctx->Input(5);
+ xla::XlaOp grad = ctx->Input(6);
+
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5);
+ xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+
+ accum = rho * accum + (one - rho) * xla::Pow(grad, two);
+ xla::XlaOp update = xla::Pow(accum_update + epsilon, half) *
+ xla::Pow(accum + epsilon, neg_half) * grad;
+ accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two);
+ var = var - update * lr;
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdadelta);
+
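Here accum holds the running average of squared gradients and accum_update the running average of squared updates; in that notation (chosen here) the Adadelta kernel computes

\[
E[g^2]_t = \rho\,E[g^2]_{t-1} + (1-\rho)\,g_t^2,\qquad
\Delta_t = \frac{\sqrt{E[\Delta^2]_{t-1}+\epsilon}}{\sqrt{E[g^2]_t+\epsilon}}\;g_t,
\]
\[
E[\Delta^2]_t = \rho\,E[\Delta^2]_{t-1} + (1-\rho)\,\Delta_t^2,\qquad
\theta_t = \theta_{t-1} - \eta\,\Delta_t.
\]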
+class ResourceApplySignBase : public XlaOpKernel {
+ public:
+ explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, m_shape;
+ xla::XlaOp var, m;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var_shape.DebugString(), " ",
+ m_shape.DebugString()));
+ TensorShape grad_shape = ctx->InputShape(6);
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+ CheckScalarParams(ctx);
+
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp alpha = ctx->Input(3);
+ xla::XlaOp sign_decay = ctx->Input(4);
+ xla::XlaOp beta = ctx->Input(5);
+ xla::XlaOp grad = ctx->Input(6);
+
+ m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
+ xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;
+
+ xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
+ var = var - lr * grad_scale * grad;
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+ }
+
+ virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
+ TensorShape lr_shape = ctx->InputShape(2);
+ TensorShape sign_decay_shape = ctx->InputShape(4);
+ TensorShape beta_shape = ctx->InputShape(5);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
+ errors::InvalidArgument("sign_decay is not a scalar: ",
+ sign_decay_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
+ errors::InvalidArgument("beta is not a scalar: ",
+ beta_shape.DebugString()));
+ }
+
+ virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
+ xla::XlaOp decay) = 0;
+
+ private:
+ DataType dtype_;
+};
+
+class ResourceApplyAddSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape alpha_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+ }
+
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return alpha + decay;
+ }
+};
+REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAddSign);
+
+class ResourceApplyPowerSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape logbase_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
+ errors::InvalidArgument("logbase is not a scalar: ",
+ logbase_shape.DebugString()));
+ }
+
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return xla::Exp(alpha * decay);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
+ ResourceApplyPowerSign);
+
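Both sign-based kernels share the update in ResourceApplySignBase; with s the sign_decay input and beta the decay (notation chosen here),

\[
m_t = \beta\,m_{t-1} + (1-\beta)\,g_t,\qquad
d_t = \operatorname{sign}(g_t)\,\operatorname{sign}(m_t)\,s,\qquad
\theta_t = \theta_{t-1} - \eta\,u_t\,g_t,
\]

where the gradient scale is \(u_t = \alpha + d_t\) for AddSign and \(u_t = e^{\alpha d_t}\) for PowerSign (alpha being the alpha or logbase input, respectively).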
} // namespace
} // namespace tensorflow