aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index f3cbc01323..77d1c019f3 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -151,6 +151,19 @@ XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
ComputeAndCompareR0<int32>(&builder, -3, {});
}
+XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a");
+ builder.ConvertElementType(a, F32);
+
+ int64 value = 3LL << 32;
+ std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
+ {a_data.get()});
+}
+
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
ComputationBuilder builder(client_, TestName());
builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),