author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-24 09:44:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-04-24 09:47:31 -0700
commit     cfedd67f5881ae3697638e9b74eccb7da9818a0e (patch)
tree       598e7e458e1c9e6ff60fac5dcdb9a2f0ec62192d
parent     b7f957ceedb6f47e4d68c506389bff210c35ef6a (diff)
Add an update_slots attr to the apply_adagrad ops that allows them to skip updating the accumulators.
PiperOrigin-RevId: 194100678
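
For context, here is a minimal sketch of the update rule this attr controls (plain standalone C++, not TensorFlow's actual kernel code; the function name and flat loop are illustrative). It mirrors the scalar path in SparseApplyAdagradOp below: when update_slots is false the accumulator is frozen, so the effective per-coordinate step size lr / sqrt(accum) stops decaying.

// Illustrative sketch only, not TF kernel code.
#include <cmath>
#include <cstddef>
#include <vector>

void apply_adagrad_sketch(std::vector<float>& var, std::vector<float>& accum,
                          float lr, const std::vector<float>& grad,
                          bool update_slots) {
  for (std::size_t i = 0; i < var.size(); ++i) {
    if (update_slots) {
      accum[i] += grad[i] * grad[i];  // accumulate squared gradients
    }
    // With update_slots == false the accumulator keeps its current value,
    // so the step size lr / sqrt(accum[i]) no longer shrinks over time.
    var[i] -= lr * grad[i] / std::sqrt(accum[i]);
  }
}
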
-rw-r--r--  tensorflow/core/kernels/training_ops.cc         23
-rw-r--r--  tensorflow/core/kernels/training_ops.h           2
-rw-r--r--  tensorflow/core/kernels/training_ops_gpu.cu.cc   6
-rw-r--r--  tensorflow/core/ops/training_ops.cc              4
4 files changed, 26 insertions(+), 9 deletions(-)
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 5b13b10937..271329599f 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -153,8 +153,10 @@ struct ApplyAdagrad<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
- typename TTypes<T>::ConstFlat grad) {
- accum.device(d) += grad.square();
+ typename TTypes<T>::ConstFlat grad, bool update_slots) {
+ if (update_slots) {
+ accum.device(d) += grad.square();
+ }
var.device(d) -= grad * lr() * accum.rsqrt();
}
};
@@ -1074,6 +1076,7 @@ class ApplyAdagradOp : public OpKernel {
public:
explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
}
void Compute(OpKernelContext* ctx) override {
@@ -1111,13 +1114,15 @@ class ApplyAdagradOp : public OpKernel {
const Device& device = ctx->template eigen_device<Device>();
functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
- lr.scalar<T>(), grad.flat<T>());
+ lr.scalar<T>(), grad.flat<T>(),
+ update_slots_);
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
+ bool update_slots_;
};
#define REGISTER_KERNELS(D, T) \
@@ -1145,7 +1150,7 @@ namespace functor {
void ApplyAdagrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::Flat var, \
typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
- typename TTypes<T>::ConstFlat grad); \
+ typename TTypes<T>::ConstFlat grad, bool update_slots); \
extern template struct ApplyAdagrad<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
@@ -1266,6 +1271,7 @@ class SparseApplyAdagradOp : public OpKernel {
public:
explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
@@ -1339,7 +1345,9 @@ class SparseApplyAdagradOp : public OpKernel {
auto a = accum_flat.template chip<0>(index);
auto g = grad_flat.template chip<0>(i);
auto v = var_flat.template chip<0>(index);
- a += g.square();
+ if (update_slots_) {
+ a += g.square();
+ }
v -= g.constant(lr_scalar) * g * a.rsqrt();
}
} else {
@@ -1358,7 +1366,9 @@ class SparseApplyAdagradOp : public OpKernel {
" in indices is out of range")));
T& a = accum_flat(index);
const T& g = grad_flat(i);
- a += g * g;
+ if (update_slots_) {
+ a += g * g;
+ }
var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
}
}
@@ -1369,6 +1379,7 @@ class SparseApplyAdagradOp : public OpKernel {
private:
bool use_exclusive_lock_;
+ bool update_slots_;
};
#define REGISTER_KERNELS(T, Tindices) \
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index f536a61eb0..495a94f1a1 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -68,7 +68,7 @@ struct ApplyAdagrad {
void operator()(const Device& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
- typename TTypes<T>::ConstFlat grad);
+ typename TTypes<T>::ConstFlat grad, bool update_slots);
};
template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 2aa17f2a0f..4bd32592db 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -42,8 +42,10 @@ struct ApplyAdagrad<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
- typename TTypes<T>::ConstFlat grad) {
- accum.device(d) += grad.square();
+ typename TTypes<T>::ConstFlat grad, bool update_slots) {
+ if (update_slots) {
+ accum.device(d) += grad.square();
+ }
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index dc7b588898..94ff092a85 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -253,6 +253,7 @@ REGISTER_OP("ApplyAdagrad")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, false /* sparse */);
});
@@ -264,6 +265,7 @@ REGISTER_OP("ResourceApplyAdagrad")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, false /* sparse */);
});
@@ -320,6 +322,7 @@ REGISTER_OP("SparseApplyAdagrad")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, true /* sparse */);
});
@@ -333,6 +336,7 @@ REGISTER_OP("ResourceSparseApplyAdagrad")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, true /* sparse */);
});
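
Note that all four registrations above declare .Attr("update_slots: bool = true"), so existing graphs that never set the attr keep the original Adagrad behavior; only callers that explicitly pass update_slots=false skip the accumulator update.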