diff options
-rw-r--r-- | tensorflow/core/kernels/cwise_op_select.cc | 19 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/cwise_ops_test.py | 11 |
2 files changed, 22 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 29b98e8bc9..e7e515b94e 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -77,11 +77,12 @@ class SelectOp : public OpKernel { Tensor* output = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output)); - - functor::BatchSelectFunctor<Device, T> func; - func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(), - cond->vec<bool>(), then->flat_outer_dims<T>(), - else_->flat_outer_dims<T>()); + if (output->NumElements() > 0) { + functor::BatchSelectFunctor<Device, T> func; + func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(), + cond->vec<bool>(), then->flat_outer_dims<T>(), + else_->flat_outer_dims<T>()); + } } void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond, @@ -89,9 +90,11 @@ class SelectOp : public OpKernel { if (!ctx->ValidateInputsAreSameShape(this)) return; Tensor* output = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output)); - functor::SelectFunctor<Device, T> func; - func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(), - then->flat<T>(), else_->flat<T>()); + if (output->NumElements() > 0) { + functor::SelectFunctor<Device, T> func; + func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(), + then->flat<T>(), else_->flat<T>()); + } } private: diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 0499c5ea07..347b844345 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -967,6 +967,17 @@ class SelectOpTest(tf.test.TestCase): with self.assertRaises(ValueError): tf.select(c, xt, yt) + def testEmptyTensor(self): + c = np.random.randint(0, 3, 0).astype(np.bool).reshape(1, 3, 0) + x = np.random.rand(1, 3, 0) * 100 + y = np.random.rand(1, 3, 0) * 100 + z_expected = np.zeros((1, 3, 0), dtype=np.float32) + with self.test_session(): + xt = x.astype(np.float32) + yt = y.astype(np.float32) + z = tf.select(c, xt, yt).eval() + self.assertAllEqual(z_expected, z) + class BatchSelectOpTest(tf.test.TestCase): """Test broadcasting of Select when 'c' is a vec and 't' &'e' are rank2+.""" |