aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_op_select.cc
diff options
context:
space:
mode:
authorGravatar Xiaoqiang Zheng <zhengxq@google.com>2016-10-28 10:29:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-28 11:38:26 -0700
commite2d51a87f0727f8537b46048d8241aeebb6e48d6 (patch)
tree64c075f59bae00706a009e5d1ed15aaff6adc6ff /tensorflow/core/kernels/cwise_op_select.cc
parentf80ef2d696456c970956f47e7d5aa88bc7ccbdce (diff)
Merge changes from github.
Change: 137532946
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index fbfde88e61..8160fb74c2 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -41,6 +41,11 @@ class SelectOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+ if (TensorShapeUtils::IsScalar(cond->shape())){
+ ComputeScalar(ctx, cond, then, else_);
+ return;
+ }
+
bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
!TensorShapeUtils::IsVector(then->shape()));
@@ -108,6 +113,25 @@ class SelectOp : public OpKernel {
}
}
+ void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
+ const Tensor* then, const Tensor* else_) {
+ OP_REQUIRES(
+ ctx, then->shape().IsSameSize(else_->shape()),
+ errors::InvalidArgument(
+ "'then' and 'else' must have the same size. but received: ",
+ then->shape().DebugString(), " vs. ",
+ else_->shape().DebugString()));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+
+ if (output->NumElements() > 0) {
+ functor::SelectScalarFunctor<Device, T> func;
+ TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
+ func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
+ then->flat<T>(), else_->flat<T>());
+ }
+ }
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
@@ -152,6 +176,17 @@ struct SelectFunctor<CPUDevice, T> {
}
};
+// CPU Specializations of Select functors with scalar
+template <typename T>
+struct SelectScalarFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ TTypes<bool>::ConstScalar cond,
+ typename TTypes<T>::ConstFlat then_flat,
+ typename TTypes<T>::ConstFlat else_flat) {
+ out.device(d) = cond() ? then_flat : else_flat;
+ }
+};
+
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
void operator()(const CPUDevice& d,