Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/training_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/training_ops.cc | 582 |
1 file changed, 511 insertions, 71 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 2e5d61e111..98df730249 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -47,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));
-    handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2)));
+    handle = handle - ctx->Input(1) * ctx->Input(2);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
@@ -55,6 +57,64 @@ REGISTER_XLA_OP(
    Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
    ResourceApplyGradientDescent);

+xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
+                                         xla::XlaOp l1, xla::XlaOp l2,
+                                         xla::XlaOp grad) {
+  xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+  xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+  xla::XlaOp prox_var = var - grad * lr;
+  xla::XlaOp l1_gt_zero = xla::Sign(prox_var) *
+                          xla::Max(xla::Abs(prox_var) - lr * l1, zero) /
+                          (one + lr * l2);
+  xla::XlaOp l1_le_zero = prox_var / (one + lr * l2);
+  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+}
+
+class ResourceApplyProximalGradientDescent : public XlaOpKernel {
+ public:
+  explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::XlaOp var;
+    TensorShape var_shape;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+
+    TensorShape alpha_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+                errors::InvalidArgument("alpha is not a scalar: ",
+                                        alpha_shape.DebugString()));
+    TensorShape l1_shape = ctx->InputShape(2);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+                errors::InvalidArgument("l1 is not a scalar: ",
+                                        l1_shape.DebugString()));
+    TensorShape l2_shape = ctx->InputShape(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+                errors::InvalidArgument("l2 is not a scalar: ",
+                                        l2_shape.DebugString()));
+    TensorShape delta_shape = ctx->InputShape(4);
+    OP_REQUIRES(
+        ctx, var_shape.IsSameSize(delta_shape),
+        errors::InvalidArgument("var and delta do not have the same shape: ",
+                                var_shape.DebugString(), " vs ",
+                                delta_shape.DebugString()));
+    xla::XlaOp alpha = ctx->Input(1);
+    xla::XlaOp l1 = ctx->Input(2);
+    xla::XlaOp l2 = ctx->Input(3);
+    xla::XlaOp delta = ctx->Input(4);
+    var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta);
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+  }
+
+ private:
+  DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent")
+                    .TypeConstraint("T", kFloatTypes),
+                ResourceApplyProximalGradientDescent);
+
class ResourceApplyMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -94,14 +154,13 @@ class ResourceApplyMomentum : public XlaOpKernel {
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

-    accum = xla::Add(xla::Mul(accum, momentum), grad);
+    accum = accum * momentum + grad;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
-      var = xla::Sub(var, xla::Add(xla::Mul(grad, lr),
-                                   xla::Mul(xla::Mul(accum, momentum), lr)));
+      var = var - (grad * lr + accum * momentum * lr);
    } else {
-      var = xla::Sub(var, xla::Mul(accum, lr));
+      var = var - accum * lr;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
@@ -118,8 +177,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaBuilder* b = ctx->builder();
-
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
@@ -146,12 +203,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);

-    accum =
-        xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
-    var = xla::Sub(
-        var,
-        xla::Mul(xla::Mul(grad, lr),
-                 xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
+    accum = accum + xla::Square(grad);
+    var = var - grad * lr * xla::Rsqrt(accum);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }
@@ -159,6 +212,139 @@ REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagrad);

+class ResourceApplyProximalAdagrad : public XlaOpKernel {
+ public:
+  explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape var_shape, accum_shape;
+    xla::XlaOp var, accum;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+                errors::InvalidArgument(
+                    "var and accum do not have the same shape",
+                    var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+    TensorShape lr_shape = ctx->InputShape(2);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+    TensorShape l1_shape = ctx->InputShape(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+                errors::InvalidArgument("l1 is not a scalar: ",
+                                        l1_shape.DebugString()));
+    TensorShape l2_shape = ctx->InputShape(4);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+                errors::InvalidArgument("l2 is not a scalar: ",
+                                        l2_shape.DebugString()));
+    TensorShape grad_shape = ctx->InputShape(5);
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape: ",
+                    var_shape.DebugString(), " vs ", grad_shape.DebugString()));
+
+    xla::XlaOp lr = ctx->Input(2);
+    xla::XlaOp l1 = ctx->Input(3);
+    xla::XlaOp l2 = ctx->Input(4);
+    xla::XlaOp grad = ctx->Input(5);
+    accum = accum + xla::Square(grad);
+    // Adagrad learning rate.
+    xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum);
+    var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad);
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+  }
+
+ private:
+  DataType dtype_;
+};
+REGISTER_XLA_OP(
+    Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes),
+    ResourceApplyProximalAdagrad);
+
+class ResourceApplyAdagradDA : public XlaOpKernel {
+ public:
+  explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape var_shape, accum_shape, squared_accum_shape;
+    xla::XlaOp var, accum, squared_accum;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
+                                               &squared_accum));
+    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+                errors::InvalidArgument(
+                    "var and accum do not have the same shape",
+                    var_shape.DebugString(), " ", accum_shape.DebugString()));
+    OP_REQUIRES(
+        ctx, var_shape.IsSameSize(squared_accum_shape),
+        errors::InvalidArgument(
+            "var and squared accum do not have the same shape",
+            var_shape.DebugString(), " ", squared_accum_shape.DebugString()));
+
+    TensorShape grad_shape = ctx->InputShape(3);
+    TensorShape lr_shape = ctx->InputShape(4);
+    TensorShape l1_shape = ctx->InputShape(5);
+    TensorShape l2_shape = ctx->InputShape(6);
+    TensorShape global_step_shape = ctx->InputShape(7);
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+                errors::InvalidArgument("l1 is not a scalar: ",
+                                        l1_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+                errors::InvalidArgument("l2 is not a scalar: ",
+                                        l2_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
+                errors::InvalidArgument("global step is not a scalar: ",
+                                        global_step_shape.DebugString()));
+
+    xla::XlaOp grad = ctx->Input(3);
+    xla::XlaOp lr = ctx->Input(4);
+    xla::XlaOp l1 = ctx->Input(5);
+    xla::XlaOp l2 = ctx->Input(6);
+    xla::XlaBuilder* const b = ctx->builder();
+    xla::XlaOp global_step =
+        XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_);
+
+    accum = accum + grad;
+    squared_accum = squared_accum + xla::Square(grad);
+    xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+    xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
+    xla::XlaOp l1_le_zero = -lr * accum / denominator;
+    xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) *
+                            xla::Max(xla::Abs(accum) - global_step * l1, zero) /
+                            denominator;
+
+    var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
+  }
+
+ private:
+  DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
+                ResourceApplyAdagradDA);
+
class ResourceApplyAdam : public XlaOpKernel {
 public:
  explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -226,18 +412,12 @@ class ResourceApplyAdam : public XlaOpKernel {
    //   variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)

    xla::XlaBuilder* b = ctx->builder();
-    xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
-    xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
-
-    xla::XlaOp alpha =
-        xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)),
-                 xla::Sub(one, beta1_power));
-    m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1)));
-    v = xla::Add(
-        v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2)));
-    var = xla::Sub(var, xla::Div(xla::Mul(m, alpha),
-                                 xla::Add(xla::Pow(v, half), epsilon)));
+    xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
+    m = m + (grad - m) * (one - beta1);
+    v = v + (xla::Square(grad) - v) * (one - beta2);
+    var = var - m * alpha / (xla::Sqrt(v) + epsilon);

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
@@ -250,38 +430,112 @@ class ResourceApplyAdam : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdam);

-class ResourceApplyRMSProp : public XlaOpKernel {
+class ResourceApplyAdaMax : public XlaOpKernel {
 public:
-  explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }

  void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaBuilder* b = ctx->builder();
+    TensorShape var_shape, m_shape, v_shape;
+    xla::XlaOp var, m, v;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

-    DataType type = ctx->input_type(3);
+    TensorShape beta1_power_shape = ctx->InputShape(3);
+    TensorShape lr_shape = ctx->InputShape(4);
+    TensorShape beta1_shape = ctx->InputShape(5);
+    TensorShape beta2_shape = ctx->InputShape(6);
+    TensorShape epsilon_shape = ctx->InputShape(7);
+    TensorShape grad_shape = ctx->InputShape(8);

-    TensorShape var_shape, ms_shape, mom_shape;
-    xla::XlaOp var, ms, mom;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
+                errors::InvalidArgument("beta1_power is not a scalar: ",
+                                        beta1_power_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar : ",
+                                        lr_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
+                errors::InvalidArgument("beta1 is not a scalar: ",
+                                        beta1_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
+                errors::InvalidArgument("beta2 is not a scalar: ",
+                                        beta2_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon_shape.DebugString()));
+    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+                errors::InvalidArgument("var and m do not have the same shape",
+                                        var_shape.DebugString(), " ",
+                                        m_shape.DebugString()));
+    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
+                errors::InvalidArgument("var and v do not have the same shape",
+                                        var_shape.DebugString(), " ",
+                                        v_shape.DebugString()));
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));

-    TensorShape lr_shape = ctx->InputShape(3);
+    xla::XlaOp beta1_power = ctx->Input(3);
+    xla::XlaOp lr = ctx->Input(4);
+    xla::XlaOp beta1 = ctx->Input(5);
+    xla::XlaOp beta2 = ctx->Input(6);
+    xla::XlaOp epsilon = ctx->Input(7);
+    xla::XlaOp grad = ctx->Input(8);
+
+    xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+    m = beta1 * m + (one - beta1) * grad;
+    v = xla::Max(beta2 * v, xla::Abs(grad));
+    var = var - lr / (one - beta1_power) * (m / (v + epsilon));
+
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
+  }
+
+ private:
+  DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
+                ResourceApplyAdaMax);
+
+class ResourceApplyRMSProp : public XlaOpKernel {
+ public:
+  explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape var_shape, ms_shape, mom_shape, mg_shape;
+    xla::XlaOp var, ms, mom, mg;
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
+    if (centered_) {
+      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
+    }
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));
+
+    TensorShape lr_shape = ctx->InputShape("lr");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
-    TensorShape rho_shape = ctx->InputShape(4);
+    TensorShape rho_shape = ctx->InputShape("rho");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));
-    TensorShape momentum_shape = ctx->InputShape(5);
+    TensorShape momentum_shape = ctx->InputShape("momentum");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));
-    TensorShape epsilon_shape = ctx->InputShape(6);
+    TensorShape epsilon_shape = ctx->InputShape("epsilon");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
-    TensorShape grad_shape = ctx->InputShape(7);
+    TensorShape grad_shape = ctx->InputShape("grad");

    // var should be the same shape as mom and ms.
    OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
@@ -297,11 +551,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

-    xla::XlaOp lr = ctx->Input(3);
-    xla::XlaOp rho = ctx->Input(4);
-    xla::XlaOp momentum = ctx->Input(5);
-    xla::XlaOp epsilon = ctx->Input(6);
-    xla::XlaOp grad = ctx->Input(7);
+    xla::XlaOp lr = ctx->Input("lr");
+    xla::XlaOp rho = ctx->Input("rho");
+    xla::XlaOp momentum = ctx->Input("momentum");
+    xla::XlaOp epsilon = ctx->Input("epsilon");
+    xla::XlaOp grad = ctx->Input("grad");

    // ms <- rho * ms_{t-1} + (1-rho) * grad * grad
    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
@@ -320,26 +574,46 @@ class ResourceApplyRMSProp : public XlaOpKernel {
    //    ms <- grad**2 (1 - rho) + ms * rho
    //
    // Which is the equation listed above.
-    xla::XlaOp new_ms = xla::Add(
-        ms, xla::Mul(
-                xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)),
-                         ms),
-                xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
-    xla::XlaOp new_mom =
-        xla::Add(xla::Mul(mom, momentum),
-                 xla::Mul(xla::Mul(grad, lr),
-                          xla::Pow(xla::Add(new_ms, epsilon),
-                                   XlaHelpers::FloatLiteral(b, type, -0.5))));
-    xla::XlaOp new_var = xla::Sub(var, new_mom);
-
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
+    xla::XlaOp one = xla::ScalarLike(ms, 1.0);
+    xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
+    xla::XlaOp denominator;
+    if (centered_) {
+      mg = grad * (one - rho) + mg * rho;
+      denominator = new_ms - xla::Square(mg) + epsilon;
+    } else {
+      denominator = new_ms + epsilon;
+    }
+    xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
+    xla::XlaOp new_var = var - new_mom;
+
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
+    if (centered_) {
+      OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
+    }
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
  }
+
+ protected:
+  bool centered_ = false;
+
+ private:
+  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
                ResourceApplyRMSProp);

+class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
+ public:
+  explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
+      : ResourceApplyRMSProp(ctx) {
+    centered_ = true;
+  }
+};
+REGISTER_XLA_OP(
+    Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes),
+    ResourceApplyCenteredRMSProp);
+
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
                 bool has_l2_shrinkage) {
  xla::XlaBuilder* b = ctx->builder();
@@ -425,23 +699,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
  xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
  xla::XlaOp grad_to_use;
  if (has_l2_shrinkage) {
-    grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var)));
+    grad_to_use = grad + two * l2_shrinkage * var;
  } else {
    grad_to_use = grad;
  }

-  xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two));
-  xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power));
-  xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power));
-  linear = xla::Add(
-      linear,
-      xla::Sub(grad_to_use,
-               xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr),
-                        var)));
-  xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1);
-  xla::XlaOp quadratic =
-      xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2));
-  var = xla::Div(xla::Sub(linear_clipped, linear), quadratic);
+  xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+  xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
+  xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
+  linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
+  xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1);
+  xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2;
+  var = (linear_clipped - linear) / quadratic;
  accum = new_accum;

  OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
@@ -481,5 +750,176 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
                ResourceApplyFtrlV2);

+class ResourceApplyAdadelta : public XlaOpKernel {
+ public:
+  explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape var_shape, accum_shape, accum_update_shape;
+    xla::XlaOp var, accum, accum_update;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
+                                               &accum_update));
+
+    TensorShape lr_shape = ctx->InputShape(3);
+    TensorShape rho_shape = ctx->InputShape(4);
+    TensorShape epsilon_shape = ctx->InputShape(5);
+    TensorShape grad_shape = ctx->InputShape(6);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
+                errors::InvalidArgument("rho is not a scalar: ",
+                                        rho_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon_shape.DebugString()));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+                errors::InvalidArgument(
+                    "var and accum do not have the same shape",
+                    var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+    xla::XlaOp lr = ctx->Input(3);
+    xla::XlaOp rho = ctx->Input(4);
+    xla::XlaOp epsilon = ctx->Input(5);
+    xla::XlaOp grad = ctx->Input(6);
+
+    xla::XlaBuilder* b = ctx->builder();
+    xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5);
+    xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
+    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
+    xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+
+    accum = rho * accum + (one - rho) * xla::Pow(grad, two);
+    xla::XlaOp update = xla::Pow(accum_update + epsilon, half) *
+                        xla::Pow(accum + epsilon, neg_half) * grad;
+    accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two);
+    var = var - update * lr;
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
+  }
+
+ private:
+  DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
+                ResourceApplyAdadelta);
+
+class ResourceApplySignBase : public XlaOpKernel {
+ public:
+  explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape var_shape, m_shape;
+    xla::XlaOp var, m;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+                errors::InvalidArgument("var and m do not have the same shape",
+                                        var_shape.DebugString(), " ",
+                                        m_shape.DebugString()));
+    TensorShape grad_shape = ctx->InputShape(6);
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+    CheckScalarParams(ctx);
+
+    xla::XlaOp lr = ctx->Input(2);
+    xla::XlaOp alpha = ctx->Input(3);
+    xla::XlaOp sign_decay = ctx->Input(4);
+    xla::XlaOp beta = ctx->Input(5);
+    xla::XlaOp grad = ctx->Input(6);
+
+    m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
+    xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;
+
+    xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
+    var = var - lr * grad_scale * grad;
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+  }
+
+  virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
+    TensorShape lr_shape = ctx->InputShape(2);
+    TensorShape sign_decay_shape = ctx->InputShape(4);
+    TensorShape beta_shape = ctx->InputShape(5);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
+                errors::InvalidArgument("sign_decay is not a scalar: ",
+                                        sign_decay_shape.DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
+                errors::InvalidArgument("beta is not a scalar: ",
+                                        beta_shape.DebugString()));
+  }
+
+  virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
+                                          xla::XlaOp decay) = 0;
+
+ private:
+  DataType dtype_;
+};
+
+class ResourceApplyAddSign : public ResourceApplySignBase {
+ public:
+  explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
+      : ResourceApplySignBase(ctx) {}
+
+  void CheckScalarParams(XlaOpKernelContext* ctx) override {
+    ResourceApplySignBase::CheckScalarParams(ctx);
+    TensorShape alpha_shape = ctx->InputShape(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+                errors::InvalidArgument("alpha is not a scalar: ",
+                                        alpha_shape.DebugString()));
+  }
+
+  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+    return alpha + decay;
+  }
+};
+REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
+                ResourceApplyAddSign);
+
+class ResourceApplyPowerSign : public ResourceApplySignBase {
+ public:
+  explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
+      : ResourceApplySignBase(ctx) {}
+
+  void CheckScalarParams(XlaOpKernelContext* ctx) override {
+    ResourceApplySignBase::CheckScalarParams(ctx);
+    TensorShape logbase_shape = ctx->InputShape(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
+                errors::InvalidArgument("logbase is not a scalar: ",
+                                        logbase_shape.DebugString()));
+  }
+
+  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+    return xla::Exp(alpha * decay);
+  }
+};
+REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
+                ResourceApplyPowerSign);
+
}  // namespace
}  // namespace tensorflow
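
As a reading aid for the kernels above, the update rules of two of them can be written out in plain scalar form. The sketch below is illustrative only and not part of the commit: the function names, the use of plain double instead of tensors, and the example values in main are assumptions; the real kernels emit XLA ops (xla::Sign, xla::Rsqrt, etc.) over whole tensors.

// Minimal scalar sketch (assumed names/values), mirroring two update rules
// from the diff above. Compile with any C++11 compiler.
#include <algorithm>
#include <cmath>
#include <cstdio>

// Mirrors ProximalGradientDescentUpdate: one gradient step followed by the
// proximal operator of the l1/l2 regularizer.
double ProximalGradientDescentUpdate(double var, double lr, double l1,
                                     double l2, double grad) {
  double prox_var = var - grad * lr;
  if (l1 > 0.0) {
    return std::copysign(1.0, prox_var) *
           std::max(std::fabs(prox_var) - lr * l1, 0.0) / (1.0 + lr * l2);
  }
  return prox_var / (1.0 + lr * l2);
}

// Mirrors the non-centered RMSProp branch: ms is a decayed average of grad^2,
// mom is a momentum term scaled by lr / sqrt(ms + epsilon).
void RmsPropStep(double* var, double* ms, double* mom, double lr, double rho,
                 double momentum, double epsilon, double grad) {
  *ms = grad * grad * (1.0 - rho) + *ms * rho;
  *mom = *mom * momentum + grad * lr / std::sqrt(*ms + epsilon);
  *var = *var - *mom;
}

int main() {
  // One proximal step with hypothetical values.
  double v = ProximalGradientDescentUpdate(/*var=*/1.0, /*lr=*/0.1,
                                           /*l1=*/0.01, /*l2=*/0.001,
                                           /*grad=*/0.5);
  std::printf("proximal step: %f\n", v);

  // A few RMSProp steps on the quadratic loss 0.5 * x^2 (gradient = x).
  double var = 1.0, ms = 0.0, mom = 0.0;
  for (int i = 0; i < 3; ++i) {
    RmsPropStep(&var, &ms, &mom, /*lr=*/0.1, /*rho=*/0.9, /*momentum=*/0.0,
                /*epsilon=*/1e-10, /*grad=*/var);
    std::printf("rmsprop step %d: var=%f\n", i, var);
  }
  return 0;
}

The centered RMSProp variant in the diff differs only in the denominator: it also keeps a decayed mean gradient mg and subtracts its square before the Rsqrt.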