aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_op_select.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-03-07 14:34:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-07 15:13:39 -0800
commitd82cc200e027552d46f1a0ff193f9cd48992c3f0 (patch)
treeb92e0bca26a5ee44e41db371ce432d6d4ccae986 /tensorflow/core/kernels/cwise_op_select.cc
parentcc62d992bfaa41744f91ddf00e0c835d375f4f5c (diff)
TensorFlow: select op: add support for empty tensors: select
can propagate empty tensors. Change: 116588842
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc19
1 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 29b98e8bc9..e7e515b94e 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -77,11 +77,12 @@ class SelectOp : public OpKernel {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
-
- functor::BatchSelectFunctor<Device, T> func;
- func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
- cond->vec<bool>(), then->flat_outer_dims<T>(),
- else_->flat_outer_dims<T>());
+ if (output->NumElements() > 0) {
+ functor::BatchSelectFunctor<Device, T> func;
+ func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
+ cond->vec<bool>(), then->flat_outer_dims<T>(),
+ else_->flat_outer_dims<T>());
+ }
}
void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
@@ -89,9 +90,11 @@ class SelectOp : public OpKernel {
if (!ctx->ValidateInputsAreSameShape(this)) return;
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
- functor::SelectFunctor<Device, T> func;
- func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
- then->flat<T>(), else_->flat<T>());
+ if (output->NumElements() > 0) {
+ functor::SelectFunctor<Device, T> func;
+ func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
+ then->flat<T>(), else_->flat<T>());
+ }
}
private: