diff options
author | 2017-02-25 08:32:26 -0800 | |
---|---|---|
committer | 2017-02-25 08:32:26 -0800 | |
commit | d9a08acc27e9276dc2c0e2128ed317a12b390237 (patch) | |
tree | 3865446b261fed46b31d6f03d1d60661ff4ba63a /tensorflow/core/kernels/cwise_op_select.cc | |
parent | 850938b3e36ca2725ade36928d59a2803f716812 (diff) | |
parent | 07427d1b51713a085f06c62b203799490591ed80 (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_op_select.cc | 37 |
1 files changed, 17 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 1612429ddf..709628da13 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -45,9 +45,9 @@ class SelectOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("t", &then)); OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); - if (TensorShapeUtils::IsScalar(cond->shape())){ - ComputeScalar(ctx, cond, then, else_); - return; + if (TensorShapeUtils::IsScalar(cond->shape())) { + ComputeScalar(ctx, cond, then, else_); + return; } bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) && @@ -69,13 +69,15 @@ class SelectOp : public OpKernel { errors::InvalidArgument("'cond' must be a vector, but saw shape: ", cond->shape().DebugString())); OP_REQUIRES( - ctx, FastBoundsCheck(cond->NumElements(), - std::numeric_limits<Eigen::DenseIndex>::max()), + ctx, + FastBoundsCheck(cond->NumElements(), + std::numeric_limits<Eigen::DenseIndex>::max()), errors::InvalidArgument("cond vector larger than ", std::numeric_limits<Eigen::DenseIndex>::max())); OP_REQUIRES( - ctx, FastBoundsCheck(then->flat_outer_dims<T>().dimension(1), - std::numeric_limits<Eigen::DenseIndex>::max()), + ctx, + FastBoundsCheck(then->flat_outer_dims<T>().dimension(1), + std::numeric_limits<Eigen::DenseIndex>::max()), errors::InvalidArgument("flat outer dims dim 1 size >= ", std::numeric_limits<Eigen::DenseIndex>::max())); @@ -96,10 +98,8 @@ class SelectOp : public OpKernel { else_->shape().DebugString())); Tensor* output = nullptr; - if (!ctx->forward_input_to_output("t", "output", &output).ok() && - !ctx->forward_input_to_output("e", "output", &output).ok()) { - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output)); - } + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {"t", "e"}, "output", then->shape(), &output)); if (output->NumElements() > 0) { functor::BatchSelectFunctor<Device, T> func; func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(), @@ -112,10 +112,8 @@ class SelectOp : public OpKernel { const Tensor* then, const Tensor* else_) { if (!ctx->ValidateInputsAreSameShape(this)) return; Tensor* output = nullptr; - if (!ctx->forward_input_to_output("t", "output", &output).ok() && - !ctx->forward_input_to_output("e", "output", &output).ok()) { - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output)); - } + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {"t", "e"}, "output", then->shape(), &output)); if (output->NumElements() > 0) { functor::SelectFunctor<Device, T> func; func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(), @@ -124,7 +122,7 @@ class SelectOp : public OpKernel { } void ComputeScalar(OpKernelContext* ctx, const Tensor* cond, - const Tensor* then, const Tensor* else_) { + const Tensor* then, const Tensor* else_) { OP_REQUIRES( ctx, then->shape().IsSameSize(else_->shape()), errors::InvalidArgument( @@ -133,10 +131,8 @@ class SelectOp : public OpKernel { else_->shape().DebugString())); Tensor* output = nullptr; - if (!ctx->forward_input_to_output("t", "output", &output).ok() && - !ctx->forward_input_to_output("e", "output", &output).ok()) { - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output)); - } + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {"t", "e"}, "output", then->shape(), &output)); if (output->NumElements() > 0) { functor::SelectScalarFunctor<Device, T> func; @@ -145,6 +141,7 @@ class SelectOp : public OpKernel { then->flat<T>(), else_->flat<T>()); } } + private: TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); }; |