diff options
author | Nikolaas Steenbergen <nikolaas.steenbergen@googlemail.com> | 2017-02-04 00:15:21 +0100 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2017-02-03 15:15:21 -0800 |
commit | 94f2229c9e4b4324a324330c8f419276eda7e503 (patch) | |
tree | d458b69c476dfd42a214032f2d1e277fa4cd0d18 /tensorflow | |
parent | 084b37a00f3cf2cc89d433528ca63ec1d3b5b313 (diff) |
TF-549 Adds unsorted segment max Op (#6975)
* TF-549 Adds unsorted segment max Op
* Cosmetic change
* Add todo comment
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops.cc | 113 | ||||
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops.h | 30 | ||||
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/ops/math_grad.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 98 | ||||
-rw-r--r-- | tensorflow/core/ops/ops.pbtxt | 53 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 32 | ||||
-rw-r--r-- | tensorflow/python/ops/math_grad.py | 30 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 1 |
9 files changed, 297 insertions, 65 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index fee16cdb78..5bd4362801 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -220,13 +220,15 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128); namespace functor { // UnsortedSegmentSumFunctor implementation for CPUDevice. +// todo: Remove duplicate code in UnsortedSegmentSumFunctor and UnsortedSegmentMaxFunctor. template <typename T, typename Index> -struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> { +struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> + : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> { void operator()(OpKernelContext* ctx, const CPUDevice& d, const Index output_rows, const TensorShape& segment_ids_shape, typename TTypes<Index>::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes<T, 2>::Tensor output) { + typename TTypes<T, 2>::Tensor output) override { output.setZero(); if (data_size == 0) { return; @@ -243,16 +245,44 @@ struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> { } } }; - +// UnsortedSegmentMaxFunctor implementation for CPUDevice. +template <typename T, typename Index> +struct UnsortedSegmentMaxFunctor<CPUDevice, T, Index> + : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> { + void operator()(OpKernelContext* ctx, const CPUDevice& d, + const Index output_rows, const TensorShape& segment_ids_shape, + typename TTypes<Index>::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes<T, 2>::Tensor output) override { + output.setConstant(std::numeric_limits<T>::min()); + if (data_size == 0) { + return; + } + const int64 N = segment_ids.dimension(0); + auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N); + for (int64 i = 0; i < N; ++i) { + Index j = internal::SubtleMustCopy(segment_ids(i)); + OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows), + errors::InvalidArgument( + "segment_ids", SliceDebugString(segment_ids_shape, i), + " = ", j, " is out of range [0, ", output_rows, ")")); + output.template chip<0>(j) = + data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j)); + } + } +}; } // namespace functor -// Similar to SegmentReductionOp but can handle unsorted segment definitions and -// specifying size of output. +// Base class for SegmentReductionOps that can handle unsorted segment +// definitions +// and specifying the size of the output in addition to a reduction function template <typename Device, class T, class Index> -class UnsortedSegmentSumOp : public OpKernel { +class UnsortedSegmentBaseOp : public OpKernel { public: - explicit UnsortedSegmentSumOp(OpKernelConstruction* context) - : OpKernel(context) {} + explicit UnsortedSegmentBaseOp( + OpKernelConstruction* context, + functor::UnsortedSegmentBaseFunctor<Device, T, Index>& functor) + : OpKernel(context), reduction_functor_(functor) {} void Compute(OpKernelContext* context) override { const Tensor& data = context->input(0); @@ -288,27 +318,70 @@ class UnsortedSegmentSumOp : public OpKernel { auto output_flat = output->flat_outer_dims<T>(); auto data_ptr = data.template flat<T>().data(); - functor::UnsortedSegmentSumFunctor<Device, T, Index>()( - context, context->template eigen_device<Device>(), output_rows, - segment_ids.shape(), segment_flat, data.NumElements(), data_ptr, - output_flat); + reduction_functor_(context, context->template eigen_device<Device>(), + output_rows, segment_ids.shape(), segment_flat, + data.NumElements(), data_ptr, output_flat); } + private: + functor::UnsortedSegmentBaseFunctor<Device, T, Index>& reduction_functor_; }; -#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \ +template <typename Device, class T, class Index> +class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp<Device, T, Index> { + public: + explicit UnsortedSegmentSumOp(OpKernelConstruction* context) + : UnsortedSegmentBaseOp<Device, T, Index>( + context, + sum_functor_) {} + private: + functor::UnsortedSegmentSumFunctor<Device, T, Index> sum_functor_; +}; + +template <typename Device, class T, class Index> +class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp<Device, T, Index> { + public: + explicit UnsortedSegmentMaxOp(OpKernelConstruction* context) + : UnsortedSegmentBaseOp<Device, T, Index>( + context, + max_functor_) {} + private: + functor::UnsortedSegmentMaxFunctor<Device, T, Index> max_functor_; +}; + +#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + UnsortedSegmentSumOp<CPUDevice, type, index_type>); \ + REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + UnsortedSegmentMaxOp<CPUDevice, type, index_type>); + +#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \ REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("T") \ .TypeConstraint<index_type>("Tindices"), \ UnsortedSegmentSumOp<CPUDevice, type, index_type>); -#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \ - REGISTER_CPU_UNSORTED_KERNELS(type, int32); \ - REGISTER_CPU_UNSORTED_KERNELS(type, int64); - -TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL); -#undef REGISTER_CPU_UNSORTED_KERNELS -#undef REGISTER_CPU_UNSORTED_KERNELS_ALL +#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64) + +#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL); +REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64); +REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128); +#undef REGISTER_REAL_CPU_UNSORTED_KERNELS +#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS +#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL +#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL #if GOOGLE_CUDA #define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \ diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index 8ed990a1e0..ee09c213b7 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -26,6 +26,17 @@ namespace tensorflow { class OpKernelContext; namespace functor { +// BaseFunctor for definition of UnsorteSegmentReductionOp +// for usage without templates. +template <typename Device, typename T, typename Index> +struct UnsortedSegmentBaseFunctor{ + virtual ~UnsortedSegmentBaseFunctor(){} + virtual void operator()(OpKernelContext* ctx, const Device& d, + const Index output_rows, const TensorShape& segment_ids_shape, + typename TTypes<Index>::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes<T, 2>::Tensor output){}; +}; // Functor for UnsortedSegmentSumOp. // 'output_rows': the number of output segments (unique segment ids in @@ -37,7 +48,7 @@ namespace functor { // 'data': input data tensor. // 'output': output reshaped to {output_rows, output.size/output_rows} template <typename Device, typename T, typename Index> -struct UnsortedSegmentSumFunctor { +struct UnsortedSegmentSumFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> { void operator()(OpKernelContext* ctx, const Device& d, const Index output_rows, const TensorShape& segment_ids_shape, typename TTypes<Index>::ConstFlat segment_ids, @@ -45,6 +56,23 @@ struct UnsortedSegmentSumFunctor { typename TTypes<T, 2>::Tensor output); }; +// Functor for UnsortedSegmentMaxOp. +// 'output_rows': the number of output segments (unique segment ids in +// 'segment_ids'). +// 'segment_ids_shape': shape of 'segment_ids' tensor. +// 'segment_ids': unsorted map from input to output segment ids at which to +// perform segment sum operation. +// 'data_size': size of input data tensor. +// 'data': input data tensor. +// 'output': output reshaped to {output_rows, output.size/output_rows} +template <typename Device, typename T, typename Index> +struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> { + void operator()(OpKernelContext* ctx, const Device& d, + const Index output_rows, const TensorShape& segment_ids_shape, + typename TTypes<Index>::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes<T, 2>::Tensor output); +}; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc index e0659c8bc9..5f53f098aa 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc @@ -56,12 +56,12 @@ namespace functor { // UnsortedSegmentSumFunctor implementation for GPUDevice. template <typename T, typename Index> -struct UnsortedSegmentSumFunctor<GPUDevice, T, Index> { +struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFunctor<GPUDevice, T, Index> { void operator()(OpKernelContext* ctx, const GPUDevice& d, const Index output_rows, const TensorShape& segment_ids_shape, typename TTypes<Index>::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes<T, 2>::Tensor output) { + typename TTypes<T, 2>::Tensor output) override { if (output.size() == 0) { return; } diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index 15bcd53322..a530d286f7 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -588,6 +588,7 @@ REGISTER_OP_GRADIENT("Mean", MeanGrad); // REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad); // REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad); // REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad); +// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad); Status MinMaxGradHelper(const string& op, const AttrSlice& attrs, FunctionDef* g) { diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 92e2823fb2..00876bc18c 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1342,6 +1342,36 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { return Status::OK(); } +Status UnsortedSegmentReductionShapeFn(InferenceContext* c) { + ShapeHandle s_data = c->input(0); + ShapeHandle s_segment_ids = c->input(1); + ShapeHandle s_num_segments = c->input(2); + TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments)); + + ShapeHandle out; + + // Leading dimensions of data must be compatible with dimensions of + // <s_segment_ids>. + if (c->RankKnown(s_segment_ids)) { + TF_RETURN_IF_ERROR( + c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids)); + + // Get the value of the num_segments input tensor. + DimensionHandle num_segments_dim; + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim)); + + // Output is {segment_id_rank} + s_data[segment_id_rank:]. + ShapeHandle s_data_suffix; + TF_RETURN_IF_ERROR( + c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix)); + TF_RETURN_IF_ERROR( + c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out)); + } else { + out = c->UnknownShape(); + } + c->set_output(0, out); + return Status::OK(); +} } // namespace REGISTER_OP("SegmentSum") @@ -1495,36 +1525,7 @@ REGISTER_OP("UnsortedSegmentSum") .Output("output: T") .Attr("T: numbertype") .Attr("Tindices: {int32,int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle s_data = c->input(0); - ShapeHandle s_segment_ids = c->input(1); - ShapeHandle s_num_segments = c->input(2); - TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments)); - - ShapeHandle out; - - // Leading dimensions of data must be compatible with dimensions of - // <s_segment_ids>. - if (c->RankKnown(s_segment_ids)) { - TF_RETURN_IF_ERROR( - c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids)); - - // Get the value of the num_segments input tensor. - DimensionHandle num_segments_dim; - TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim)); - - // Output is {segment_id_rank} + s_data[segment_id_rank:]. - ShapeHandle s_data_suffix; - TF_RETURN_IF_ERROR( - c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix)); - TF_RETURN_IF_ERROR( - c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out)); - } else { - out = c->UnknownShape(); - } - c->set_output(0, out); - return Status::OK(); - }) + .SetShapeFn(UnsortedSegmentReductionShapeFn) .Doc(R"doc( Computes the sum along segments of a tensor. @@ -1554,6 +1555,43 @@ output: Has same shape as data, except for the first `segment_ids.rank` )doc"); + +REGISTER_OP("UnsortedSegmentMax") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: int32") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .SetShapeFn(UnsortedSegmentReductionShapeFn) + .Doc(R"doc( +Computes the Max along segments of a tensor. + +Read [the section on +Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation +of segments. + +This operator is similar to the [unsorted segment sum operator](../../api_docs/python/math_ops.md#UnsortedSegmentSum). +Instead of computing the sum over segments, it computes the maximum +such that: + +\\(output_i = \max_j data_j\\) where max is over `j` such +that `segment_ids[j] == i`. + +If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for specific numeric type, + `output[i] = numeric_limits<T>::min()`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt> +</div> + +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. + +output: Has same shape as data, except for dimension 0 which +has size `num_segments`. + +)doc"); REGISTER_OP("SparseSegmentSum") .Input("data: T") .Input("indices: Tidx") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 805ba2c662..73635d3417 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -25194,6 +25194,59 @@ op { description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>" } op { + name: "UnsortedSegmentSum" + input_arg { + name: "data" + type_attr: "T" + } + input_arg { + name: "segment_ids" + description: "A tensor whose shape is a prefix of `data.shape`." + type_attr: "Tindices" + } + input_arg { + name: "num_segments" + type: DT_INT32 + } + output_arg { + name: "output" + description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`." + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + summary: "Computes the max along segments of a tensor." + description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>" +} +op { name: "Unstage" output_arg { name: "values" diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index d7e3b3e79b..485530d405 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -49,12 +49,21 @@ class SegmentReductionHelper(test.TestCase): slice_shape = x.shape[indices.ndim:] x_flat = x.reshape((indices.size,) + slice_shape) for i, index in enumerate(indices.ravel()): - if output[index] is not None: + if (output[index] is not None) and op1 == np.max: + for j in range(0, output[index].shape[0]): + output[index][j] = op1([output[index][j], x_flat[i][j]]) + elif output[index] is not None: output[index] = op1(output[index], x_flat[i]) else: output[index] = x_flat[i] # zero initialize values that are still uncalcuated. - output = [o if o is not None else np.zeros(slice_shape) for o in output] + # output = [o if o is not None else np.zeros(slice_shape) for o in output] + if not op1 == np.max: + output = [o if o is not None else np.zeros(slice_shape) for o in output] + else: + zeroslice = np.zeros(slice_shape) + zeroslice.fill(dtype.min) + output = [o if o is not None else zeroslice for o in output] if op2 is not None: output = [op2(o) for o in output] output = [o.reshape(slice_shape) for o in output] @@ -245,7 +254,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): self._assertAllClose(indices, np_ans, tf_ans) self.assertShapeEqual(np_ans, s) - def testGradient(self): + def testGradientSegmentSum(self): num_cols = 2 indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) num_segments = max(indices_flat) + 3 @@ -318,6 +327,23 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2) self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype)) + def testGradientSegmentMax(self): + num_cols = 2 + indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) + num_segments = max(indices_flat) + 3 + for indices in indices_flat, indices_flat.reshape(5, 2): + shape = indices.shape + (num_cols,) + with self.test_session(): + tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64) + s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices, + num_segments=num_segments) + jacob_t, jacob_n = gradient_checker.compute_gradient( + tf_x, + shape, + s, + [num_segments, num_cols], + x_init_value=np_x.astype(np.double), delta=1) + self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) class UnsortedSegmentSumGpuTest(UnsortedSegmentSumTest): use_gpu = True diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index e0232b35f0..b1e3e9e749 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -188,35 +188,42 @@ def _SparseSegmentSqrtNGrad(op, grad): dim0), None, None) -def _SegmentMinOrMaxGrad(op, grad): - """Gradient for SegmentMin and SegmentMax. Both share the same code.""" - zeros = array_ops.zeros( - array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype) +def _SegmentMinOrMaxGrad(op, grad, is_sorted): + """Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code.""" + zeros = array_ops.zeros(array_ops.shape(op.inputs[0]), + dtype=op.inputs[0].dtype) # Get the number of selected (minimum or maximum) elements in each segment. gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1]) is_selected = math_ops.equal(op.inputs[0], gathered_outputs) - num_selected = math_ops.segment_sum( - math_ops.cast(is_selected, grad.dtype), op.inputs[1]) + if is_sorted: + num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype), + op.inputs[1]) + else: + num_selected = math_ops.unsorted_segment_sum(math_ops.cast(is_selected, grad.dtype), + op.inputs[1], op.inputs[2]) # Compute the gradient for each segment. The gradient for the ith segment is # divided evenly among the selected elements in that segment. weighted_grads = math_ops.div(grad, num_selected) gathered_grads = array_ops.gather(weighted_grads, op.inputs[1]) - return array_ops.where(is_selected, gathered_grads, zeros), None + if is_sorted: + return array_ops.where(is_selected, gathered_grads, zeros), None + else: + return array_ops.where(is_selected, gathered_grads, zeros), None, None @ops.RegisterGradient("SegmentMin") def _SegmentMinGrad(op, grad): """Gradient for SegmentMin.""" - return _SegmentMinOrMaxGrad(op, grad) + return _SegmentMinOrMaxGrad(op, grad, True) @ops.RegisterGradient("SegmentMax") def _SegmentMaxGrad(op, grad): """Gradient for SegmentMax.""" - return _SegmentMinOrMaxGrad(op, grad) + return _SegmentMinOrMaxGrad(op, grad, True) @ops.RegisterGradient("UnsortedSegmentSum") @@ -225,6 +232,11 @@ def _UnsortedSegmentSumGrad(op, grad): return array_ops.gather(grad, op.inputs[1]), None, None +@ops.RegisterGradient("UnsortedSegmentMax") +def _UnsortedSegmentMaxGrad(op, grad): + return _SegmentMinOrMaxGrad(op, grad, False) + + @ops.RegisterGradient("Abs") def _AbsGrad(op, grad): x = op.inputs[0] diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 11e7d8382f..1378bf8eca 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -196,6 +196,7 @@ tf.segment_sum(c, tf.constant([0, 0, 1])) @@segment_mean @@unsorted_segment_sum +@@unsorted_segment_max @@sparse_segment_sum @@sparse_segment_mean |