aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_op_clip.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_clip.cc')
-rw-r--r--tensorflow/core/kernels/cwise_op_clip.cc43
1 files changed, 16 insertions, 27 deletions
diff --git a/tensorflow/core/kernels/cwise_op_clip.cc b/tensorflow/core/kernels/cwise_op_clip.cc
index 14d889e8e3..49b90e855b 100644
--- a/tensorflow/core/kernels/cwise_op_clip.cc
+++ b/tensorflow/core/kernels/cwise_op_clip.cc
@@ -33,52 +33,41 @@ class ClipOp : public OpKernel {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
const Tensor& in2 = ctx->input(2);
+ OP_REQUIRES(ctx, (in0.shape() == in1.shape() ||
+ TensorShapeUtils::IsScalar(in1.shape())) &&
+ (in0.shape() == in2.shape() ||
+ TensorShapeUtils::IsScalar(in2.shape())),
+ errors::InvalidArgument(
+ "clip_value_min and clip_value_max must be either of "
+ "the same shape as input, or a scalar. ",
+ "input shape: ", in0.shape().DebugString(),
+ "clip_value_min shape: ", in1.shape().DebugString(),
+ "clip_value_max shape: ", in2.shape().DebugString()));
+
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(
+ ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
+ if (out->NumElements() == 0) return; // Nothing to do for empty output
auto in0_flat = in0.flat<T>();
auto in1_flat = in1.flat<T>();
auto in2_flat = in2.flat<T>();
+ auto out_flat = out->flat<T>();
const Device& d = ctx->eigen_device<Device>();
- Tensor* out = nullptr;
- OP_REQUIRES_OK(
- ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
- auto out_flat = out->flat<T>();
if (in1.shape() == in2.shape()) {
if (in0.shape() == in1.shape()) {
functor::TernaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
} else {
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in1.shape()),
- errors::InvalidArgument(
- "clip_value_min and clip_value_max must be either of "
- "the same shape as input, or a scalar. ",
- "input shape: ", in0.shape().DebugString(),
- "clip_value_min shape: ", in1.shape().DebugString(),
- "clip_value_max shape: ", in2.shape().DebugString()));
functor::UnaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
}
} else {
if (in0.shape() == in1.shape()) {
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in2.shape()),
- errors::InvalidArgument(
- "clip_value_min and clip_value_max must be either of "
- "the same shape as input, or a scalar. ",
- "input shape: ", in0.shape().DebugString(),
- "clip_value_min shape: ", in1.shape().DebugString(),
- "clip_value_max shape: ", in2.shape().DebugString()));
functor::BinaryLeftClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
} else {
- OP_REQUIRES(ctx,
- (in0.shape() == in2.shape() &&
- TensorShapeUtils::IsScalar(in1.shape())),
- errors::InvalidArgument(
- "clip_value_min and clip_value_max must be either of "
- "the same shape as input, or a scalar. ",
- "input shape: ", in0.shape().DebugString(),
- "clip_value_min shape: ", in1.shape().DebugString(),
- "clip_value_max shape: ", in2.shape().DebugString()));
functor::BinaryRightClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
}