aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Zhenyu Tan <tanzheny@google.com>2018-09-06 10:01:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 10:06:12 -0700
commitd17016a8dfd9b9bd92a55fc1fddee4fd1c29bdbe (patch)
treea96cc2bb410e2dbd4b42c04270750a1daf59d31d
parentbfff3425e0938c6bcc635edce2673252c4762a99 (diff)
Extend ConditionalAccumulator with SUM functionality.
Previously take_grad represents the average gradients being aggregated. However this does not cover other use cases such as summing quantiles, or summing probability distributions from parallel workers. This change extends the functionality. PiperOrigin-RevId: 211824519
-rw-r--r--tensorflow/core/kernels/conditional_accumulator.h6
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.cc13
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.h3
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base_op.h3
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_op.cc3
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator.h4
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator_op.cc4
-rw-r--r--tensorflow/core/kernels/typed_conditional_accumulator_base.h5
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc2
-rw-r--r--tensorflow/python/kernel_tests/conditional_accumulator_test.py88
-rw-r--r--tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py83
-rw-r--r--tensorflow/python/ops/data_flow_ops.py20
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt2
16 files changed, 207 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index a7836896c7..390db8fe5a 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -51,9 +51,11 @@ class ConditionalAccumulator
// dtype: The datatype of the gradients to be accumulated.
// shape: The shape of the accumulated gradients.
// name: A name to use for the ConditionalAccumulator.
+ // reduction_type: The reduction type, i.e., MEAN or SUM
ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
- const string& name)
- : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+ const string& name, const string& reduction_type)
+ : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+ reduction_type) {}
~ConditionalAccumulator() override{};
protected:
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc
index 90593c56b8..292cf0cd64 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_base.cc
@@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/conditional_accumulator_base.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
ConditionalAccumulatorBase::ConditionalAccumulatorBase(
- const DataType& dtype, const PartialTensorShape& shape, const string& name)
- : dtype_(dtype), shape_(shape), name_(name) {
+ const DataType& dtype, const PartialTensorShape& shape, const string& name,
+ const string& reduction_type)
+ : dtype_(dtype),
+ shape_(shape),
+ name_(name),
+ reduction_type_(reduction_type) {
counter_ = 0;
current_global_step_ = 0;
}
@@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
current_global_step_++;
// Average the accumulated gradient
- DivideAccumGradByCounter(ctx);
+ if (reduction_type_ == "MEAN") {
+ DivideAccumGradByCounter(ctx);
+ }
// Set output for accumulated gradient tensor
bool successful_set_output = SetOutput(ctx);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index b7b7482a00..4a5ec6f0fb 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
// name: A name to use for the ConditionalAccumulator.
ConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name);
+ const string& name, const string& reduction_type);
typedef AsyncOpKernel::DoneCallback DoneCallback;
@@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
const DataType dtype_;
const PartialTensorShape shape_;
const string name_;
+ const string reduction_type_;
mutex mu_;
int counter_ GUARDED_BY(mu_);
int64 current_global_step_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index 012a0dcc12..ca24d690f8 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
&accumulator_handle_, nullptr));
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("reduction_type", &reduction_type_));
}
void Compute(OpKernelContext* ctx) override {
@@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
DataType dtype_;
PartialTensorShape shape_;
ContainerInfo cinfo_;
+ string reduction_type_;
private:
Status SetAccumulatorHandle(OpKernelContext* ctx)
diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc
index e13bf8a4c6..52ac51a9b6 100644
--- a/tensorflow/core/kernels/conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_op.cc
@@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
ConditionalAccumulator<Device, T>* accumulator =
- new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
+ new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+ reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h
index 11149c4d16..a4453bd7ab 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator.h
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -50,10 +50,10 @@ class SparseConditionalAccumulator
public:
SparseConditionalAccumulator(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
+ const string& name, const string& reduction_type)
: TypedConditionalAccumulatorBase<
std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
- dtype, shape, name) {
+ dtype, shape, name, reduction_type) {
accum_idx_vec_ = nullptr;
count_element_ = nullptr;
accum_val_ = nullptr;
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
index 80bc1f1934..1e542a26a7 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
@@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
SparseConditionalAccumulator<Device, T>* accumulator =
- new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
- cinfo_.name());
+ new SparseConditionalAccumulator<Device, T>(
+ dtype_, shape_, cinfo_.name(), reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
index 9dedb618f9..ca341e511e 100644
--- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
@@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
public:
TypedConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
- : ConditionalAccumulatorBase(dtype, shape, name) {}
+ const string& name,
+ const string& reduction_type)
+ : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
/**
* Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index eed0bce174..ffab8ad661 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
@@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 7570523495..86802664d1 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -42,14 +42,22 @@ class ConditionalAccumulatorTest(test.TestCase):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
+ def testConstructorWithInvalidArg(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
def testConstructorWithShape(self):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(
@@ -57,7 +65,8 @@ class ConditionalAccumulatorTest(test.TestCase):
name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 }
@@ -67,6 +76,7 @@ class ConditionalAccumulatorTest(test.TestCase):
} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
@@ -237,12 +247,11 @@ class ConditionalAccumulatorTest(test.TestCase):
extract_t.op.run()
self.assertEqual(q.num_accumulated().eval(), 0)
- def testAccumulatorTakeGrad(self):
+ def testAccumulatorTakeGradMean(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0]
- elems_ave = sum(elems) / len(elems)
accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
takeg_t = q.take_grad(1)
@@ -251,7 +260,7 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run()
val = takeg_t.eval()
- self.assertEqual(elems_ave, val)
+ self.assertEqual(15.0, val)
accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
takeg_t = q.take_grad(constant_op.constant(1))
@@ -260,7 +269,42 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run()
val = takeg_t.eval()
- self.assertEqual(elems_ave, val)
+ self.assertEqual(15.0, val)
+
+ def testAccumulatorTakeGradSum(self):
+ with self.test_session():
+ q = data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="SUM")
+ elems = [10.0, 20.0]
+
+ accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(30.0, val)
+
+ accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+ takeg_t = q.take_grad(constant_op.constant(1))
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(30.0, val)
+
+ def testAccumulatorTakeGradInvalidReductionType(self):
+ with self.assertRaises(ValueError):
+ data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="Invalid")
def testAccumulatorInvalidTakeGrad(self):
with self.test_session():
@@ -277,7 +321,7 @@ class ConditionalAccumulatorTest(test.TestCase):
with self.assertRaises(errors_impl.InvalidArgumentError):
takeg_t.eval()
- def testAccumulatorRepeatedTakeGrad(self):
+ def testAccumulatorRepeatedTakeGradMean(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -304,6 +348,36 @@ class ConditionalAccumulatorTest(test.TestCase):
val = takeg_t.eval()
self.assertEqual(elems_ave + 0.0, val)
+ def testAccumulatorRepeatedTakeGradSum(self):
+ with self.test_session():
+ q = data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="SUM")
+
+ elems = [10.0, 20.0]
+ elems_sum = 30.0
+ accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(elems_sum, val)
+
+ elems = [20.0, 30.0]
+ elems_sum = 50.0
+ accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(elems_sum, val)
+
def testAccumulatorIncrementGlobalStep(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index d749843410..3bb5e899fe 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -61,14 +61,22 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
+ def testConstructorWithInvalidArg(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
def testConstructorWithShape(self):
with ops.Graph().as_default():
q = data_flow_ops.SparseConditionalAccumulator(
@@ -76,7 +84,8 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 }
@@ -86,6 +95,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
@@ -164,7 +174,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
result = sess.run(accums[i].take_indexed_slices_grad(1))
self._assertEqual_indexedslices(expected_tensors[i], result)
- def testAccumulatorTakeGrad(self):
+ def testAccumulatorTakeGradMean(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=())
@@ -180,9 +190,34 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
takeg_t = q.take_indexed_slices_grad(1)
val = sess.run(takeg_t)
- self.assertAllEqual(val.indices, [0, 1, 2])
- self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]])
- self.assertAllEqual(val.dense_shape, [-1, 2])
+ self.assertAllEqual([0, 1, 2], val.indices)
+ self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values)
+ self.assertAllEqual([-1, 2], val.dense_shape)
+
+ def testAccumulatorTakeGradSum(self):
+ with self.test_session() as sess:
+ q = data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
+
+ grad_indexed_slices = ops.IndexedSlices(
+ indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32))
+ accum_op = q.apply_indexed_slices_grad(grad_indexed_slices)
+ accum_op.run()
+ accum_op = q.apply_grad([0, 2],
+ np.array([[0, 1], [3, 0]]).astype(np.float32),
+ [3, 2])
+ accum_op.run()
+
+ takeg_t = q.take_indexed_slices_grad(1)
+ val = sess.run(takeg_t)
+ self.assertAllEqual([0, 1, 2], val.indices)
+ self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values)
+ self.assertAllEqual([-1, 2], val.dense_shape)
+
+ def testAccumulatorTakeGradInvalidReductionType(self):
+ with self.assertRaises(ValueError):
+ data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")
def testAccumulatorRepeatedTakeGrad(self):
with self.test_session() as sess:
@@ -222,7 +257,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]])
self.assertAllEqual(val.dense_shape, [-1, 2])
- def testParallelApplyGrad(self):
+ def testParallelApplyGradMean(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -253,6 +288,40 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
val, sess)
+ def testParallelApplyGradSum(self):
+ with self.test_session() as sess:
+ q = data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([2, 2]),
+ reduction_type="SUM")
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ accum_ops = []
+ for x in elems:
+ x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
+ accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
+ takeg_t = q.take_indexed_slices_grad(1)
+
+ def apply_indexed_slices_grad(accum_op):
+ sess.run(accum_op)
+
+ threads = [
+ self.checkedThread(target=apply_indexed_slices_grad, args=(o,))
+ for o in accum_ops
+ ]
+
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ val = sess.run(takeg_t)
+
+ expected_val = 550.0
+ self._assertEqual_nparray(
+ np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
+ val, sess)
+
def testParallelTakeGrad(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 7af2ca56be..69c0fcbbee 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1229,7 +1229,8 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
dtype,
shape=None,
shared_name=None,
- name="conditional_accumulator"):
+ name="conditional_accumulator",
+ reduction_type="MEAN"):
"""Creates a new ConditionalAccumulator.
Args:
@@ -1238,9 +1239,14 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions.
name: Optional name for the accumulator.
+ reduction_type: Reduction type to use when taking the gradient.
"""
accumulator_ref = gen_data_flow_ops.conditional_accumulator(
- dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+ dtype=dtype,
+ shape=shape,
+ shared_name=shared_name,
+ name=name,
+ reduction_type=reduction_type)
super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
def apply_grad(self, grad, local_step=0, name=None):
@@ -1312,15 +1318,21 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions.
name: Optional name for the accumulator.
+ reduction_type: Reduction type to use when taking the gradient.
"""
def __init__(self,
dtype,
shape=None,
shared_name=None,
- name="sparse_conditional_accumulator"):
+ name="sparse_conditional_accumulator",
+ reduction_type="MEAN"):
accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
- dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+ dtype=dtype,
+ shape=shape,
+ shared_name=shared_name,
+ name=name,
+ reduction_type=reduction_type)
super(SparseConditionalAccumulator, self).__init__(dtype, shape,
accumulator_ref)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd0ca..15e0ab76b6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279ad2..39ff336c4f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd0ca..15e0ab76b6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279ad2..39ff336c4f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"