path: root/tensorflow/core/kernels/training_ops.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-18 19:40:58 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-08-18 20:47:26 -0700
commit    5da979895d63ed6c0933e96c9dacdc30bf05a09c (patch)
tree      2e7e45c23ede25a407a1d192a9ebcc7c13b54c09 /tensorflow/core/kernels/training_ops.cc
parent    0eae3401b731295342af3387bb8586a6b1f70274 (diff)
Adagrad Dual Averaging optimizer for sparse linear models that takes care of lazy updates correctly.
Change: 130714247
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r--  tensorflow/core/kernels/training_ops.cc | 363
1 file changed, 339 insertions(+), 24 deletions(-)
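
For reference, the update implemented by the new kernels (restated from the in-kernel comment, where g is the gradient accumulator, gg the gradient-squared accumulator, k its initial value, T the global step, lr the learning rate, and l1, l2 the regularization strengths) is

  w = \dfrac{sign(-g) \cdot lr \cdot \max(|g| - l_1 T, 0)}{l_2 \, T \, lr + \sqrt{k + gg}}

which for l1 = 0 reduces to w = \dfrac{-lr \cdot g}{l_2 \, T \, lr + \sqrt{k + gg}}.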
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index ca7231e6cc..bac8d8c80d 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,15 +18,13 @@ limitations under the License.
#include "tensorflow/core/kernels/training_ops.h"
#include <algorithm>
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/kernels/bounds_check.h"
-
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
namespace {
template <class T>
@@ -95,6 +93,43 @@ struct ApplyProximalGradientDescent<CPUDevice, T> {
};
template <typename T>
+struct ApplyAdagradDA<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat gradient_accum,
+ typename TTypes<T>::Flat gradient_squared_accum,
+ typename TTypes<T>::ConstScalar lr, int64 global_step,
+ typename TTypes<T>::ConstScalar l1,
+ typename TTypes<T>::ConstScalar l2,
+ typename TTypes<T>::ConstFlat grad) {
+ // Accumulate the gradient and the squared gradient.
+ gradient_accum.device(d) += grad;
+ gradient_squared_accum.device(d) += grad.square();
+
+ // AdagradDA update:
+ // Let g be the gradient accumulator, gg the gradient squared accumulator,
+ // T the global step, lr the learning rate, and k the initial
+ // gradient squared accumulator value.
+ // w = \dfrac{sign(-g) * lr * max(|g| - l1 * T, 0)}{l2 * T * lr + \sqrt{k + gg}}
+ if (l1() > 0) {
+ var.device(d) =
+ lr() * var.constant(-1.0) * gradient_accum.sign() *
+ (gradient_accum.abs() -
+ var.constant(static_cast<float>(global_step)) * var.constant(l1()))
+ .cwiseMax(T(0.0)) /
+ (var.constant(l2()) *
+ var.constant(static_cast<float>(global_step) * lr()) +
+ gradient_squared_accum.sqrt());
+ } else {
+ var.device(d) =
+ lr() * gradient_accum * var.constant(-1.0) /
+ (var.constant(l2()) *
+ var.constant(static_cast<float>(global_step) * lr()) +
+ gradient_squared_accum.sqrt());
+ }
+ }
+};
+
+template <typename T>
struct ApplyAdagrad<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
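
The CPU functor above expresses this update with Eigen tensor expressions. As a reading aid only, the following is a minimal scalar C++ sketch of the same element-wise update; the function and variable names are illustrative and are not part of TensorFlow.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Dense AdagradDA update, applied one scalar element at a time.
void ApplyAdagradDADense(std::vector<double>& var,
                         std::vector<double>& grad_accum,
                         std::vector<double>& grad_sq_accum,
                         const std::vector<double>& grad,
                         double lr, double l1, double l2,
                         std::int64_t global_step) {
  const double t = static_cast<double>(global_step);
  for (std::size_t i = 0; i < var.size(); ++i) {
    // Accumulate the gradient and the squared gradient.
    grad_accum[i] += grad[i];
    grad_sq_accum[i] += grad[i] * grad[i];
    const double g = grad_accum[i];
    const double denom = l2 * t * lr + std::sqrt(grad_sq_accum[i]);
    if (l1 > 0) {
      // Soft-threshold the accumulated gradient by l1 * global_step.
      const double thresholded = std::max(std::abs(g) - l1 * t, 0.0);
      var[i] = -std::copysign(1.0, g) * lr * thresholded / denom;
    } else {
      var[i] = -lr * g / denom;
    }
  }
}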
@@ -413,8 +448,8 @@ class ApplyAdadeltaOp : public OpKernel {
}
};
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
@@ -698,6 +733,9 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
ctx, grad.dim_size(0) == N,
errors::InvalidArgument(
"grad must be the same size as indices in the first dimension."));
+ OP_REQUIRES(ctx, inner_dim > 0,
+ errors::InvalidArgument(
+ "Inner dimension should be greater than zero."));
if (N > 0) {
if (inner_dim > 1) {
@@ -735,7 +773,6 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
}
}
} else {
- CHECK_EQ(1, inner_dim);
auto indices_vec = indices.vec<Tindex>();
auto var_flat = var.flat<T>();
auto grad_flat = grad.flat<T>();
@@ -834,8 +871,8 @@ class ApplyAdagradOp : public OpKernel {
bool use_exclusive_lock_;
};
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
@@ -928,8 +965,8 @@ class ApplyProximalAdagradOp : public OpKernel {
bool use_exclusive_lock_;
};
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
@@ -1010,6 +1047,10 @@ class SparseApplyAdagradOp : public OpKernel {
errors::InvalidArgument(
"grad must be the same size as indices in the first dimension."));
+ OP_REQUIRES(ctx, inner_dim > 0,
+ errors::InvalidArgument(
+ "Inner dimension should be greater than zero."));
+
if (N > 0) {
if (inner_dim > 1) {
const Tindex first_dim_size = var.dim_size(0);
@@ -1034,7 +1075,6 @@ class SparseApplyAdagradOp : public OpKernel {
v -= g.constant(lr_scalar) * g * a.rsqrt();
}
} else {
- CHECK_EQ(1, inner_dim);
auto indices_vec = indices.vec<Tindex>();
auto var_flat = var.flat<T>();
auto accum_flat = accum.flat<T>();
@@ -1142,6 +1182,10 @@ class SparseApplyProximalAdagradOp : public OpKernel {
errors::InvalidArgument(
"grad must be the same size as indices in the first dimension."));
+ OP_REQUIRES(ctx, inner_dim > 0,
+ errors::InvalidArgument(
+ "Inner dimension should be greater than zero."));
+
if (N > 0) {
if (inner_dim > 1) {
const Tindex first_dim_size = var.dim_size(0);
@@ -1180,7 +1224,6 @@ class SparseApplyProximalAdagradOp : public OpKernel {
}
}
} else {
- CHECK_EQ(1, inner_dim);
auto indices_vec = indices.vec<Tindex>();
auto var_flat = var.flat<T>();
auto accum_flat = accum.flat<T>();
@@ -1237,6 +1280,275 @@ REGISTER_KERNELS(double, int64);
#undef REGISTER_KERNELS
template <typename Device, typename T>
+class ApplyAdagradDAOp : public OpKernel {
+ public:
+ explicit ApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, gradient_accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ OP_REQUIRES(
+ ctx, gradient_squared_accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(2)));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(gradient_accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ gradient_accum.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(gradient_squared_accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ gradient_squared_accum.shape().DebugString()));
+
+ const Tensor& grad = ctx->input(3);
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and grad do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+
+ const Tensor& lr = ctx->input(4);
+ OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ const Tensor& l1 = ctx->input(5);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(l1.shape()),
+ errors::InvalidArgument("l1 regularization strength is not a scalar: ",
+ l1.shape().DebugString()));
+ const Tensor& l2 = ctx->input(6);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(l2.shape()),
+ errors::InvalidArgument("l2 regularization strength is not a scalar: ",
+ l2.shape().DebugString()));
+ const Tensor& global_step = ctx->input(7);
+ OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
+ errors::InvalidArgument("global_step is not a scalar: ",
+ global_step.shape().DebugString()));
+
+ const Device& device = ctx->template eigen_device<Device>();
+ functor::ApplyAdagradDA<Device, T>()(
+ device, var.flat<T>(), gradient_accum.flat<T>(),
+ gradient_squared_accum.flat<T>(), lr.scalar<T>(),
+ global_step.scalar<int64>()(), l1.scalar<T>(), l2.scalar<T>(),
+ grad.flat<T>());
+
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+};
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdagradDA").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdagradDAOp<D##Device, T>);
+
+REGISTER_KERNELS(CPU, float);
+REGISTER_KERNELS(CPU, double);
+#undef REGISTER_KERNELS
+
+// Note: this op works on the CPU only.
+template <typename T, typename Tindex>
+class SparseApplyAdagradDAOp : public OpKernel {
+ public:
+ explicit SparseApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+ auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, gradient_accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ OP_REQUIRES(
+ ctx, gradient_squared_accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(2)));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(gradient_accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ gradient_accum.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(gradient_squared_accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ gradient_squared_accum.shape().DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
+ errors::InvalidArgument("var must be at least 1 dimensional"));
+
+ const Tensor& grad = ctx->input(3);
+ const Tensor& indices = ctx->input(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+ errors::InvalidArgument("indices must be one-dimensional"));
+
+ const Tensor& lr = ctx->input(5);
+ OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+
+ const Tensor& l1 = ctx->input(6);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(l1.shape()),
+ errors::InvalidArgument("l1 regularization strength is not a scalar: ",
+ l1.shape().DebugString()));
+
+ const Tensor& l2 = ctx->input(7);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(l2.shape()),
+ errors::InvalidArgument("l2 regularization strength is not a scalar: ",
+ l2.shape().DebugString()));
+
+ const Tensor& global_step = ctx->input(8);
+ OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
+ errors::InvalidArgument("global_step is not a scalar: ",
+ global_step.shape().DebugString()));
+
+ int64 inner_dim = 1;
+ for (int d = 1; d < var.dims(); d++) {
+ OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
+ errors::InvalidArgument(strings::StrCat(
+ "var and grad must match in dimension ", d)));
+ inner_dim *= grad.dim_size(d);
+ }
+ const Tindex N = indices.dim_size(0);
+ OP_REQUIRES(
+ ctx, grad.dim_size(0) == N,
+ errors::InvalidArgument(
+ "grad must be the same size as indices in the first dimension."));
+
+ OP_REQUIRES(ctx, inner_dim > 0,
+ errors::InvalidArgument(
+ "Inner dimension should be greater than zero."));
+
+ // AdagradDA update:
+ // Let g be the gradient accumulator, gg the gradient squared accumulator,
+ // T the global step, lr the learning rate, and k the initial
+ // gradient squared accumulator value.
+ // w = \dfrac{sign(-g) * lr * max(|g| - l1 * T, 0)}{l2 * T * lr + \sqrt{k + gg}}
+ if (N > 0) {
+ if (inner_dim > 1) {
+ const Tindex first_dim_size = var.dim_size(0);
+ auto indices_vec = indices.vec<Tindex>();
+ auto var_flat = var.flat_outer_dims<T>();
+ auto gradient_accum_flat = gradient_accum.flat_outer_dims<T>();
+ auto gradient_squared_accum_flat =
+ gradient_squared_accum.flat_outer_dims<T>();
+ auto grad_flat = grad.flat_outer_dims<T>();
+ T lr_scalar = lr.scalar<T>()();
+ T global_step_scalar = global_step.scalar<int64>()();
+ T l1_scalar = l1.scalar<T>()();
+ T l2_scalar = l2.scalar<T>()();
+ const double gs_lr = global_step_scalar * lr_scalar;
+
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+ OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in indices is out of range")));
+ auto ga = gradient_accum_flat.template chip<0>(index);
+ auto da = gradient_squared_accum_flat.template chip<0>(index);
+ auto g = grad_flat.template chip<0>(i);
+ auto v = var_flat.template chip<0>(index);
+ ga += g;
+ da += g.square();
+ if (l1_scalar > 0) {
+ v = (ga.abs() / ga.constant(global_step_scalar)) -
+ ga.constant(l1_scalar);
+ v = ga.constant(-1.0) * ga.sign() *
+ v.cwiseMax(static_cast<T>(0.0)) /
+ (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
+ } else {
+ v = ga.constant(-1.0) * (ga / ga.constant(global_step_scalar)) /
+ (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
+ }
+ }
+ } else {
+ auto indices_vec = indices.vec<Tindex>();
+ auto var_flat = var.flat<T>();
+ auto gradient_accum_flat = gradient_accum.flat<T>();
+ auto gradient_squared_accum_flat = gradient_squared_accum.flat<T>();
+ auto grad_flat = grad.flat<T>();
+ const double lr_scalar = lr.scalar<T>()();
+ const int64 global_step_scalar = global_step.scalar<int64>()();
+ const double l1_scalar = l1.scalar<T>()();
+ const double l2_scalar = l2.scalar<T>()();
+ const Tindex first_dim_size = var_flat.size();
+ const double gs_l1 = global_step_scalar * l1_scalar;
+ const double gs_l2_lr = global_step_scalar * l2_scalar * lr_scalar;
+
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+ OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in indices is out of range")));
+ T& ga = gradient_accum_flat(index);
+ T& da = gradient_squared_accum_flat(index);
+ const double g = grad_flat(i);
+ ga += g;
+ da += g * g;
+ if (l1_scalar > 0) {
+ var_flat(index) = sgn(-ga) * lr_scalar *
+ std::max((std::abs(ga) - gs_l1), 0.0) /
+ (gs_l2_lr + std::sqrt(da));
+ } else {
+ var_flat(index) = (-ga * lr_scalar) / (gs_l2_lr + std::sqrt(da));
+ }
+ }
+ }
+ }
+
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradDA") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyAdagradDAOp<T, Tindices>);
+
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
class ApplyFtrlOp : public OpKernel {
public:
explicit ApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
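
SparseApplyAdagradDAOp in the hunk above applies the same rule lazily: only the rows named in indices have their accumulators and weights touched. As a reading aid only, here is a minimal C++ sketch of the inner_dim == 1 path with illustrative names; it is not TensorFlow code and omits the bounds checking the kernel performs.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Lazy AdagradDA update: grad[i] is the gradient for row indices[i], and var
// holds one scalar per row (the inner_dim == 1 case).
void SparseApplyAdagradDA(std::vector<double>& var,
                          std::vector<double>& grad_accum,
                          std::vector<double>& grad_sq_accum,
                          const std::vector<double>& grad,
                          const std::vector<std::int64_t>& indices,
                          double lr, double l1, double l2,
                          std::int64_t global_step) {
  const double gs_l1 = static_cast<double>(global_step) * l1;
  const double gs_l2_lr = static_cast<double>(global_step) * l2 * lr;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const std::int64_t row = indices[i];  // assumed to lie in [0, var.size())
    // Accumulate the gradient and the squared gradient for this row only.
    grad_accum[row] += grad[i];
    grad_sq_accum[row] += grad[i] * grad[i];
    const double ga = grad_accum[row];
    const double denom = gs_l2_lr + std::sqrt(grad_sq_accum[row]);
    if (l1 > 0) {
      var[row] = -std::copysign(1.0, ga) * lr *
                 std::max(std::abs(ga) - gs_l1, 0.0) / denom;
    } else {
      var[row] = -ga * lr / denom;
    }
  }
}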
@@ -1312,8 +1624,8 @@ class ApplyFtrlOp : public OpKernel {
bool use_exclusive_lock_;
};
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
@@ -1403,6 +1715,10 @@ class SparseApplyFtrlOp : public OpKernel {
errors::InvalidArgument(
"grad must be the same size as indices in the first dimension."));
+ OP_REQUIRES(ctx, inner_dim > 0,
+ errors::InvalidArgument(
+ "Inner dimension should be greater than zero."));
+
if (N > 0) {
if (inner_dim > 1) {
const Tindex first_dim_size = var.dim_size(0);
@@ -1453,7 +1769,6 @@ class SparseApplyFtrlOp : public OpKernel {
accum += grad.square();
}
} else {
- CHECK_EQ(1, inner_dim);
auto indices_vec = indices.vec<Tindex>();
auto var_flat = var.flat<T>();
auto accum_flat = accum.flat<T>();
@@ -1567,8 +1882,8 @@ class ApplyMomentumOp : public OpKernel {
bool use_nesterov_;
};
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
@@ -1793,8 +2108,8 @@ class ApplyAdamOp : public OpKernel {
bool use_exclusive_lock_;
};
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
@@ -1910,8 +2225,8 @@ class ApplyRMSPropOp : public OpKernel {
bool use_exclusive_lock_;
};
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \