diff options
author | 2016-10-28 10:29:28 -0800 | |
---|---|---|
committer | 2016-10-28 11:38:26 -0700 | |
commit | e2d51a87f0727f8537b46048d8241aeebb6e48d6 (patch) | |
tree | 64c075f59bae00706a009e5d1ed15aaff6adc6ff /tensorflow/core/kernels/cwise_op_select.cc | |
parent | f80ef2d696456c970956f47e7d5aa88bc7ccbdce (diff) |
Merge changes from github.
Change: 137532946
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_op_select.cc | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index fbfde88e61..8160fb74c2 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -41,6 +41,11 @@ class SelectOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("t", &then)); OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); + if (TensorShapeUtils::IsScalar(cond->shape())){ + ComputeScalar(ctx, cond, then, else_); + return; + } + bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) && !TensorShapeUtils::IsVector(then->shape())); @@ -108,6 +113,25 @@ class SelectOp : public OpKernel { } } + void ComputeScalar(OpKernelContext* ctx, const Tensor* cond, + const Tensor* then, const Tensor* else_) { + OP_REQUIRES( + ctx, then->shape().IsSameSize(else_->shape()), + errors::InvalidArgument( + "'then' and 'else' must have the same size. but received: ", + then->shape().DebugString(), " vs. ", + else_->shape().DebugString())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output)); + + if (output->NumElements() > 0) { + functor::SelectScalarFunctor<Device, T> func; + TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>(); + func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar, + then->flat<T>(), else_->flat<T>()); + } + } private: TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); }; @@ -152,6 +176,17 @@ struct SelectFunctor<CPUDevice, T> { } }; +// CPU Specializations of Select functors with scalar +template <typename T> +struct SelectScalarFunctor<CPUDevice, T> { + void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, + TTypes<bool>::ConstScalar cond, + typename TTypes<T>::ConstFlat then_flat, + typename TTypes<T>::ConstFlat else_flat) { + out.device(d) = cond() ? then_flat : else_flat; + } +}; + template <typename T> struct BatchSelectFunctor<CPUDevice, T> { void operator()(const CPUDevice& d, |