aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_op_select.cc
diff options
context:
space:
mode:
authorGravatar Maciek Chociej <maciekc@google.com>2017-02-25 08:32:26 -0800
committerGravatar Maciek Chociej <maciekc@google.com>2017-02-25 08:32:26 -0800
commitd9a08acc27e9276dc2c0e2128ed317a12b390237 (patch)
tree3865446b261fed46b31d6f03d1d60661ff4ba63a /tensorflow/core/kernels/cwise_op_select.cc
parent850938b3e36ca2725ade36928d59a2803f716812 (diff)
parent07427d1b51713a085f06c62b203799490591ed80 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc37
1 files changed, 17 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 1612429ddf..709628da13 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -45,9 +45,9 @@ class SelectOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
- if (TensorShapeUtils::IsScalar(cond->shape())){
- ComputeScalar(ctx, cond, then, else_);
- return;
+ if (TensorShapeUtils::IsScalar(cond->shape())) {
+ ComputeScalar(ctx, cond, then, else_);
+ return;
}
bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
@@ -69,13 +69,15 @@ class SelectOp : public OpKernel {
errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
cond->shape().DebugString()));
OP_REQUIRES(
- ctx, FastBoundsCheck(cond->NumElements(),
- std::numeric_limits<Eigen::DenseIndex>::max()),
+ ctx,
+ FastBoundsCheck(cond->NumElements(),
+ std::numeric_limits<Eigen::DenseIndex>::max()),
errors::InvalidArgument("cond vector larger than ",
std::numeric_limits<Eigen::DenseIndex>::max()));
OP_REQUIRES(
- ctx, FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
- std::numeric_limits<Eigen::DenseIndex>::max()),
+ ctx,
+ FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
+ std::numeric_limits<Eigen::DenseIndex>::max()),
errors::InvalidArgument("flat outer dims dim 1 size >= ",
std::numeric_limits<Eigen::DenseIndex>::max()));
@@ -96,10 +98,8 @@ class SelectOp : public OpKernel {
else_->shape().DebugString()));
Tensor* output = nullptr;
- if (!ctx->forward_input_to_output("t", "output", &output).ok() &&
- !ctx->forward_input_to_output("e", "output", &output).ok()) {
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
- }
+ OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+ {"t", "e"}, "output", then->shape(), &output));
if (output->NumElements() > 0) {
functor::BatchSelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
@@ -112,10 +112,8 @@ class SelectOp : public OpKernel {
const Tensor* then, const Tensor* else_) {
if (!ctx->ValidateInputsAreSameShape(this)) return;
Tensor* output = nullptr;
- if (!ctx->forward_input_to_output("t", "output", &output).ok() &&
- !ctx->forward_input_to_output("e", "output", &output).ok()) {
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
- }
+ OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+ {"t", "e"}, "output", then->shape(), &output));
if (output->NumElements() > 0) {
functor::SelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
@@ -124,7 +122,7 @@ class SelectOp : public OpKernel {
}
void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
- const Tensor* then, const Tensor* else_) {
+ const Tensor* then, const Tensor* else_) {
OP_REQUIRES(
ctx, then->shape().IsSameSize(else_->shape()),
errors::InvalidArgument(
@@ -133,10 +131,8 @@ class SelectOp : public OpKernel {
else_->shape().DebugString()));
Tensor* output = nullptr;
- if (!ctx->forward_input_to_output("t", "output", &output).ok() &&
- !ctx->forward_input_to_output("e", "output", &output).ok()) {
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
- }
+ OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+ {"t", "e"}, "output", then->shape(), &output));
if (output->NumElements() > 0) {
functor::SelectScalarFunctor<Device, T> func;
@@ -145,6 +141,7 @@ class SelectOp : public OpKernel {
then->flat<T>(), else_->flat<T>());
}
}
+
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};