diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-24 09:44:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-24 09:47:31 -0700 |
commit | cfedd67f5881ae3697638e9b74eccb7da9818a0e (patch) | |
tree | 598e7e458e1c9e6ff60fac5dcdb9a2f0ec62192d | |
parent | b7f957ceedb6f47e4d68c506389bff210c35ef6a (diff) |
Add an attr to apply_adagrad op that allows it to skip updating the accumulators.
PiperOrigin-RevId: 194100678
-rw-r--r-- | tensorflow/core/kernels/training_ops.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/kernels/training_ops.h | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/training_ops_gpu.cu.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/ops/training_ops.cc | 4 |
4 files changed, 26 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 5b13b10937..271329599f 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -153,8 +153,10 @@ struct ApplyAdagrad<CPUDevice, T> { void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, - typename TTypes<T>::ConstFlat grad) { - accum.device(d) += grad.square(); + typename TTypes<T>::ConstFlat grad, bool update_slots) { + if (update_slots) { + accum.device(d) += grad.square(); + } var.device(d) -= grad * lr() * accum.rsqrt(); } }; @@ -1074,6 +1076,7 @@ class ApplyAdagradOp : public OpKernel { public: explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_)); } void Compute(OpKernelContext* ctx) override { @@ -1111,13 +1114,15 @@ class ApplyAdagradOp : public OpKernel { const Device& device = ctx->template eigen_device<Device>(); functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(), - lr.scalar<T>(), grad.flat<T>()); + lr.scalar<T>(), grad.flat<T>(), + update_slots_); MaybeForwardRefInputToRefOutput(ctx, 0, 0); } private: bool use_exclusive_lock_; + bool update_slots_; }; #define REGISTER_KERNELS(D, T) \ @@ -1145,7 +1150,7 @@ namespace functor { void ApplyAdagrad<GPUDevice, T>::operator()( \ const GPUDevice& d, typename TTypes<T>::Flat var, \ typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \ - typename TTypes<T>::ConstFlat grad); \ + typename TTypes<T>::ConstFlat grad, bool update_slots); \ extern template struct ApplyAdagrad<GPUDevice, T>; DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); @@ -1266,6 +1271,7 @@ class SparseApplyAdagradOp : public OpKernel { public: explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_)); } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { @@ -1339,7 +1345,9 @@ class SparseApplyAdagradOp : public OpKernel { auto a = accum_flat.template chip<0>(index); auto g = grad_flat.template chip<0>(i); auto v = var_flat.template chip<0>(index); - a += g.square(); + if (update_slots_) { + a += g.square(); + } v -= g.constant(lr_scalar) * g * a.rsqrt(); } } else { @@ -1358,7 +1366,9 @@ class SparseApplyAdagradOp : public OpKernel { " in indices is out of range"))); T& a = accum_flat(index); const T& g = grad_flat(i); - a += g * g; + if (update_slots_) { + a += g * g; + } var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a); } } @@ -1369,6 +1379,7 @@ class SparseApplyAdagradOp : public OpKernel { private: bool use_exclusive_lock_; + bool update_slots_; }; #define REGISTER_KERNELS(T, Tindices) \ diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index f536a61eb0..495a94f1a1 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -68,7 +68,7 @@ struct ApplyAdagrad { void operator()(const Device& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, - typename TTypes<T>::ConstFlat grad); + typename TTypes<T>::ConstFlat grad, bool update_slots); }; template <typename Device, typename T> diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index 2aa17f2a0f..4bd32592db 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -42,8 +42,10 @@ struct ApplyAdagrad<GPUDevice, T> { void operator()(const GPUDevice& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr,
- accum.device(d) += grad.square(); + typename TTypes<T>::ConstFlat grad, bool update_slots) { + if (update_slots) { + accum.device(d) += grad.square(); + } Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; bcast[0] = grad.dimension(0); Eigen::Sizes<1> single; diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index dc7b588898..94ff092a85 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -253,6 +253,7 @@ REGISTER_OP("ApplyAdagrad") .Output("out: Ref(T)") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("update_slots: bool = true") .SetShapeFn([](InferenceContext* c) { return ApplyAdagradShapeFn(c, false /* sparse */); }); @@ -264,6 +265,7 @@ REGISTER_OP("ResourceApplyAdagrad") .Input("grad: T") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("update_slots: bool = true") .SetShapeFn([](InferenceContext* c) { return ApplyAdagradShapeFn(c, false /* sparse */); }); @@ -320,6 +322,7 @@ REGISTER_OP("SparseApplyAdagrad") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .Attr("update_slots: bool = true") .SetShapeFn([](InferenceContext* c) { return ApplyAdagradShapeFn(c, true /* sparse */); }); @@ -333,6 +336,7 @@ REGISTER_OP("ResourceSparseApplyAdagrad") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .Attr("update_slots: bool = true") .SetShapeFn([](InferenceContext* c) { return ApplyAdagradShapeFn(c, true /* sparse */); }); |