/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>
#include <new>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class MultiOutputFusionTest : public HloTestBase {
 protected:
  MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }

  // Layout assignment assumes that there are no fusions in the input graph.
  // Since the purpose of this test is to send pre-fused graphs to XLA, we have
  // to do layout assignment ourselves.
  DebugOptions GetDebugOptionsForTest() override {
    auto opts = HloTestBase::GetDebugOptionsForTest();
    opts.add_xla_disable_hlo_passes("layout-assignment");
    return opts;
  }
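
  // Builds an f32[size,size] graph in which broadcast(param0 + 8) feeds both
  // an add and a subtract against param1, and the two results feed a dot.
  // With param0 = -9 and param1 filled with 2.5, add2 is 1.5 and sub is 3.5
  // elementwise, so every element of the dot equals size * 1.5 * 3.5. With
  // manual_fusion, {broadcast, sub, add2, tuple} become one multi-output loop
  // fusion whose results reach the dot through get-tuple-elements.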
  void RunTest2D(bool manual_fusion, int64 size) {
    auto builder = HloComputation::Builder(TestName());
    auto hlo_module = CreateNewModule();

    const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
    const Shape elem_shape2 =
        ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0});

    auto const0 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, elem_shape0, "0"));

    auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
        elem_shape0, HloOpcode::kAdd, param0, const0));

    HloInstruction* broadcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(elem_shape2, add1, {}));

    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, elem_shape2, "1"));

    HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary(
        elem_shape2, HloOpcode::kAdd, broadcast, param1));
    HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
        elem_shape2, HloOpcode::kSubtract, param1, broadcast));
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
        elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2)));
    auto computation = hlo_module->AddEntryComputation(builder.Build(dot));

    if (manual_fusion) {
      auto tuple = computation->AddInstruction(
          HloInstruction::CreateTuple({sub, add2}));
      auto gte0 = computation->AddInstruction(
          HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0));
      auto gte1 = computation->AddInstruction(
          HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 1));
      TF_CHECK_OK(dot->ReplaceOperandWith(0, gte0));
      TF_CHECK_OK(dot->ReplaceOperandWith(1, gte1));
      CHECK_NE(
          computation->CreateFusionInstruction(
              {tuple, sub, add2, broadcast}, HloInstruction::FusionKind::kLoop),
          nullptr);
    }

    Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
    arg1.PopulateWithValue<float>(2.5f);

    Literal expect(
        ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
    expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
    Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
    auto actual =
        ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
    EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
  }
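
  // Mixed-precision 1D variant: param0 is f32, param1 is f64, and converts
  // bridge the two types. Elementwise, add = param0 + convert(param1) = 3.5
  // and the f64 subtract produces 1.5, so the final dot is size * 1.5 * 3.5.
  // With manual_fusion, both converts, the f64 subtract, the f32 add, and the
  // tuple form one multi-output loop fusion; the convert of the subtract back
  // to f32 stays outside and consumes a get-tuple-element.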
  void RunTest1D(bool manual_fusion, int size) {
    auto builder = HloComputation::Builder(TestName());
    auto hlo_module = CreateNewModule();

    const Shape elem_shape_F32 =
        ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});
    const Shape elem_shape_F64 =
        ShapeUtil::MakeShapeWithDescendingLayout(F64, {size});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, elem_shape_F32, "0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, elem_shape_F64, "1"));

    HloInstruction* param0_F64 = builder.AddInstruction(
        HloInstruction::CreateConvert(elem_shape_F64, param0));
    HloInstruction* param1_F32 = builder.AddInstruction(
        HloInstruction::CreateConvert(elem_shape_F32, param1));
    HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
        elem_shape_F32, HloOpcode::kAdd, param0, param1_F32));
    HloInstruction* sub_F64 =
        builder.AddInstruction(HloInstruction::CreateBinary(
            elem_shape_F64, HloOpcode::kSubtract, param0_F64, param1));
    HloInstruction* sub = builder.AddInstruction(
        HloInstruction::CreateConvert(elem_shape_F32, sub_F64));

    HloInstruction* reshape =
        builder.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add));
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(0);
    dot_dnums.add_rhs_contracting_dimensions(0);
    HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
        ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
        dot_dnums, DefaultPrecisionConfig(2)));
    auto computation = hlo_module->AddEntryComputation(builder.Build(dot));

    if (manual_fusion) {
      auto tuple = computation->AddInstruction(
          HloInstruction::CreateTuple({sub_F64, add}));
      auto gte0 = computation->AddInstruction(
          HloInstruction::CreateGetTupleElement(elem_shape_F64, tuple, 0));
      auto gte1 = computation->AddInstruction(
          HloInstruction::CreateGetTupleElement(elem_shape_F32, tuple, 1));
      TF_CHECK_OK(sub->ReplaceOperandWith(0, gte0));
      TF_CHECK_OK(reshape->ReplaceOperandWith(0, gte1));
      CHECK_NE(computation->CreateFusionInstruction(
                   {tuple, sub_F64, add, param0_F64, param1_F32},
                   HloInstruction::FusionKind::kLoop),
               nullptr);
    }

    Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}));
    input0.PopulateWithValue(2.5f);
    Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
    input1.PopulateWithValue(1.);

    Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
    auto actual =
        ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
    EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
  }
};

XLA_TEST_F(MultiOutputFusionTest, 2DNoFusion) { RunTest2D(false, 5); }
XLA_TEST_F(MultiOutputFusionTest, 2DFusion) { RunTest2D(true, 5); }
XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); }
XLA_TEST_F(MultiOutputFusionTest, DifferentTypesNoFusion) {
  RunTest1D(false, 8);
}
XLA_TEST_F(MultiOutputFusionTest, DifferentTypesFusion) {
  RunTest1D(true, 8);
}

XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
  const char* testcase = R"(
    HloModule m

    fused_computation {
      x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0)
      gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0
      gte.2 = (s32[]) get-tuple-element(gte.3), index=0
      gte.4 = s32[] get-tuple-element(gte.2), index=0
      copy = s32[] copy(gte.4)
      ROOT tuple = (s32[]) tuple(copy)
    }

    ENTRY thing.v3 {
      x = (((s32[]), f32[]), (f32[], s32[])) parameter(0)
      ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation
    }
  )";
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param = LiteralUtil::MakeTupleOwned(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
          LiteralUtil::CreateR0<float>(1.0)),
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
                                  LiteralUtil::CreateR0<int32>(4)));
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
}
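
// A multi-output loop fusion computing (p < p*p, p*p); the entry computation
// selects the square where the predicate holds and 0 elsewhere, so the input
// {1, 2, 3, -1} maps to {0, 4, 9, 1}.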
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
  const char* testcase = R"(
    HloModule m

    fused_computation {
      p = f32[4] parameter(0)
      multiply = f32[4] multiply(p, p)
      less-than = pred[4] less-than(p, multiply)
      ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
    }

    ENTRY PredFloatMOF {
      p0 = f32[4] parameter(0)
      fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop,
               calls=fused_computation
      gte0 = pred[4] get-tuple-element(fusion), index=0
      gte1 = f32[4] get-tuple-element(fusion), index=1
      const = f32[4] constant({0, 0, 0, 0})
      ROOT select = f32[4] select(gte0, gte1, const)
    })";
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
}

XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
  const char* testcase = R"(
    HloModule m

    fused_computation {
      p = f32[] parameter(0)
      multiply = f32[] multiply(p, p)
      less-than = pred[] less-than(p, multiply)
      ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
    }

    map_computation {
      p0 = f32[] parameter(0)
      fusion = (pred[], f32[]) fusion(p0), kind=kLoop,
               calls=fused_computation
      gte0 = pred[] get-tuple-element(fusion), index=0
      gte1 = f32[] get-tuple-element(fusion), index=1
      const = f32[] constant(0)
      ROOT select = f32[] select(gte0, gte1, const)
    }

    ENTRY MapMOF {
      p1 = f32[3] parameter(0)
      ROOT map = f32[3] map(p1), to_apply=map_computation
    })";
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
}

const char* const kScalarOps = R"(
    HloModule m

    Add {
      lhsadd = f32[] parameter(0)
      rhsadd = f32[] parameter(1)
      ROOT add = f32[] add(lhsadd, rhsadd)
    }

    Max {
      lhsmax = f32[] parameter(0)
      rhsmax = f32[] parameter(1)
      ROOT max = f32[] maximum(lhsmax, rhsmax)
    }
)";

XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
                    calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
          LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
      result));
}

XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
                    calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
          LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
      result));
}
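
// Reduces across both the major and minor dimensions at once, emitting three
// f32[2] outputs from a single kInput fusion: an add-reduce of p0 ({14, 22}),
// a max-reduce of p0*p0 ({36, 64}), and an add-reduce of p0*p0 ({66, 138}).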
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(1.17549e-38)
      r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
      r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
      ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
                    calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
                                  LiteralUtil::CreateR1<float>({36, 64}),
                                  LiteralUtil::CreateR1<float>({66, 138})),
      result));
}

XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
                   tuple(p0, r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
                    fusion(p), kind=kInput, calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
          LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
          LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
      result));
}

XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
                   tuple(r1, mul, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
                    fusion(p), kind=kInput, calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
          LiteralUtil::CreateR3<float>(
              {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
          LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
      result));
}
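
// Besides the cross-dimension add-reduce, this fusion also emits two
// full-rank elementwise products (p0*p0 and 5*p0) as extra outputs.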
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={}
      mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1)
      ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
                   tuple(r1, mul, mul2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
                    fusion(p), kind=kInput, calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR1<float>({14, 22}),
          LiteralUtil::CreateR3<float>(
              {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
          LiteralUtil::CreateR3<float>(
              {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
      result));
}

XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      init1 = f32[] parameter(1)
      init2 = f32[] parameter(2)
      r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
      r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      i = f32[] parameter(1)
      j = f32[] parameter(2)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j),
                    kind=kInput, calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto init1 = LiteralUtil::CreateR0<float>(5);
  auto init2 = LiteralUtil::CreateR0<float>(6);
  Literal result =
      ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
          LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
      result));
}

XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
      p0 = f16[2,2,2]{2,1,0} parameter(0)
      convert = f32[2,2,2]{2,1,0} convert(p0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0})
                   tuple(r1, r2, p0)
    }

    ENTRY reduce {
      p = f16[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0})
                    fusion(p), kind=kInput, calls=fused_reduce
    })");
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
  auto param = LiteralUtil::CreateR3<Eigen::half>(
      {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
       {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
          LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
          LiteralUtil::CreateR3<Eigen::half>(
              {{{Eigen::half(1), Eigen::half(2)},
                {Eigen::half(3), Eigen::half(4)}},
               {{Eigen::half(5), Eigen::half(6)},
                {Eigen::half(7), Eigen::half(8)}}})),
      result));
}

}  // namespace
}  // namespace xla