diff options
Diffstat (limited to 'tensorflow/core/kernels/conditional_accumulator_base.cc')
-rw-r--r-- | tensorflow/core/kernels/conditional_accumulator_base.cc | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc index 90593c56b8..292cf0cd64 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.cc +++ b/tensorflow/core/kernels/conditional_accumulator_base.cc @@ -14,12 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/conditional_accumulator_base.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { ConditionalAccumulatorBase::ConditionalAccumulatorBase( - const DataType& dtype, const PartialTensorShape& shape, const string& name) - : dtype_(dtype), shape_(shape), name_(name) { + const DataType& dtype, const PartialTensorShape& shape, const string& name, + const string& reduction_type) + : dtype_(dtype), + shape_(shape), + name_(name), + reduction_type_(reduction_type) { counter_ = 0; current_global_step_ = 0; } @@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx, current_global_step_++; // Average the accumulated gradient - DivideAccumGradByCounter(ctx); + if (reduction_type_ == "MEAN") { + DivideAccumGradByCounter(ctx); + } // Set output for accumulated gradient tensor bool successful_set_output = SetOutput(ctx); |