aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc19
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py11
2 files changed, 22 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 29b98e8bc9..e7e515b94e 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -77,11 +77,12 @@ class SelectOp : public OpKernel {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
-
- functor::BatchSelectFunctor<Device, T> func;
- func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
- cond->vec<bool>(), then->flat_outer_dims<T>(),
- else_->flat_outer_dims<T>());
+ if (output->NumElements() > 0) {
+ functor::BatchSelectFunctor<Device, T> func;
+ func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
+ cond->vec<bool>(), then->flat_outer_dims<T>(),
+ else_->flat_outer_dims<T>());
+ }
}
void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
@@ -89,9 +90,11 @@ class SelectOp : public OpKernel {
if (!ctx->ValidateInputsAreSameShape(this)) return;
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
- functor::SelectFunctor<Device, T> func;
- func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
- then->flat<T>(), else_->flat<T>());
+ if (output->NumElements() > 0) {
+ functor::SelectFunctor<Device, T> func;
+ func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
+ then->flat<T>(), else_->flat<T>());
+ }
}
private:
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 0499c5ea07..347b844345 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -967,6 +967,17 @@ class SelectOpTest(tf.test.TestCase):
with self.assertRaises(ValueError):
tf.select(c, xt, yt)
+ def testEmptyTensor(self):
+ c = np.random.randint(0, 3, 0).astype(np.bool).reshape(1, 3, 0)
+ x = np.random.rand(1, 3, 0) * 100
+ y = np.random.rand(1, 3, 0) * 100
+ z_expected = np.zeros((1, 3, 0), dtype=np.float32)
+ with self.test_session():
+ xt = x.astype(np.float32)
+ yt = y.astype(np.float32)
+ z = tf.select(c, xt, yt).eval()
+ self.assertAllEqual(z_expected, z)
+
class BatchSelectOpTest(tf.test.TestCase):
"""Test broadcasting of Select when 'c' is a vec and 't' &'e' are rank2+."""