aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conditional_accumulator.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conditional_accumulator.h')
-rw-r--r--tensorflow/core/kernels/conditional_accumulator.h6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index a7836896c7..390db8fe5a 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -51,9 +51,11 @@ class ConditionalAccumulator
// dtype: The datatype of the gradients to be accumulated.
// shape: The shape of the accumulated gradients.
// name: A name to use for the ConditionalAccumulator.
+ // reduction_type: The reduction type, i.e., MEAN or SUM
ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
- const string& name)
- : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+ const string& name, const string& reduction_type)
+ : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+ reduction_type) {}
~ConditionalAccumulator() override{};
protected: