diff options
Diffstat (limited to 'tensorflow/core/kernels/conditional_accumulator.h')
-rw-r--r-- | tensorflow/core/kernels/conditional_accumulator.h | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h index a7836896c7..390db8fe5a 100644 --- a/tensorflow/core/kernels/conditional_accumulator.h +++ b/tensorflow/core/kernels/conditional_accumulator.h @@ -51,9 +51,11 @@ class ConditionalAccumulator // dtype: The datatype of the gradients to be accumulated. // shape: The shape of the accumulated gradients. // name: A name to use for the ConditionalAccumulator. + // reduction_type: The reduction type, i.e., MEAN or SUM ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {} + const string& name, const string& reduction_type) + : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name, + reduction_type) {} ~ConditionalAccumulator() override{}; protected: |