diff options
author | 2018-02-02 23:02:16 -0800 | |
---|---|---|
committer | 2018-02-02 23:06:15 -0800 | |
commit | 34bff30979896879815dd6fc4d77c1a37d9b98a0 (patch) | |
tree | d3949669ed961164340f542950f499ca29611517 | |
parent | a4f4d3131cdd63c871e864d631f8c8466924f10c (diff) |
[XLA] Add tests for Clamp with scalars S32 and U32.
PiperOrigin-RevId: 184376425
-rw-r--r-- | tensorflow/compiler/xla/tests/scalar_computations_test.cc | 60 |
1 files changed, 57 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 43e4d891a1..4da6ee9160 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -737,7 +737,61 @@ XLA_TEST_F(ScalarComputationsTest, PowScalar) { ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0<int32>(-1), // The lower bound. + builder.ConstantR0<int32>(5), // The operand to be clamped. + builder.ConstantR0<int32>(3)); // The upper bound. + + ComputeAndCompareR0<int32>(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0<int32>(-1), // The lower bound. + builder.ConstantR0<int32>(2), // The operand to be clamped. + builder.ConstantR0<int32>(3)); // The upper bound. + + ComputeAndCompareR0<int32>(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0<int32>(-1), // The lower bound. + builder.ConstantR0<int32>(-5), // The operand to be clamped. + builder.ConstantR0<int32>(3)); // The upper bound. + + ComputeAndCompareR0<int32>(&builder, -1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0<uint32>(1), // The lower bound. + builder.ConstantR0<uint32>(5), // The operand to be clamped. + builder.ConstantR0<uint32>(3)); // The upper bound. + + ComputeAndCompareR0<uint32>(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0<uint32>(1), // The lower bound. + builder.ConstantR0<uint32>(2), // The operand to be clamped. + builder.ConstantR0<uint32>(3)); // The upper bound. + + ComputeAndCompareR0<uint32>(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0<uint32>(1), // The lower bound. + builder.ConstantR0<uint32>(0), // The operand to be clamped. + builder.ConstantR0<uint32>(3)); // The upper bound. + + ComputeAndCompareR0<uint32>(&builder, 1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound. builder.ConstantR0<float>(5.0f), // The operand to be clamped. @@ -746,7 +800,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound. builder.ConstantR0<float>(2.5f), // The operand to be clamped. @@ -755,7 +809,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarLow) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound. builder.ConstantR0<float>(-5.0f), // The operand to be clamped. |