aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conditional_accumulator_base.h
diff options
context:
space:
mode:
authorGravatar Zhenyu Tan <tanzheny@google.com>2018-09-06 10:01:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 10:06:12 -0700
commitd17016a8dfd9b9bd92a55fc1fddee4fd1c29bdbe (patch)
treea96cc2bb410e2dbd4b42c04270750a1daf59d31d /tensorflow/core/kernels/conditional_accumulator_base.h
parentbfff3425e0938c6bcc635edce2673252c4762a99 (diff)
Extend ConditionalAccumulator with SUM functionality.
Previously take_grad represents the average gradients being aggregated. However this does not cover other use cases such as summing quantiles, or summing probability distributions from parallel workers. This change extends the functionality. PiperOrigin-RevId: 211824519
Diffstat (limited to 'tensorflow/core/kernels/conditional_accumulator_base.h')
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.h3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index b7b7482a00..4a5ec6f0fb 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
// name: A name to use for the ConditionalAccumulator.
ConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name);
+ const string& name, const string& reduction_type);
typedef AsyncOpKernel::DoneCallback DoneCallback;
@@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
const DataType dtype_;
const PartialTensorShape shape_;
const string name_;
+ const string reduction_type_;
mutex mu_;
int counter_ GUARDED_BY(mu_);
int64 current_global_step_ GUARDED_BY(mu_);