diff options
author | 2018-07-03 11:23:31 -0700 | |
---|---|---|
committer | 2018-07-03 11:26:48 -0700 | |
commit | bdd3a01d20ec6747cc6efc39fe42ed5f29d2c97e (patch) | |
tree | 341a217fc6b428e4e1859068fd4698a59f221f22 | |
parent | 3a118b84b1845092e1900f7aadaea706d18f214b (diff) |
Compare layouts when propagating linear indices through elemental ops in fusions
PiperOrigin-RevId: 203154982
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/fusion_test.cc | 34 |
3 files changed, 43 insertions, 1 deletion
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index ce0951bbe1..21c6f7d358 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1227,7 +1227,14 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If no implicit broadcast is needed for this operand, returns the target // index as the source index. - if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) { + // + // `IrArray::Index` may contain a physical linear which we can propagate to + // our operand only if our layouts match. "only if" is a bit strong since + // e.g. we can still forward the linear index if the operand shape is + // [5,1,1,5]{3,2,1,0} and the HLO shape is [5,1,1,5]{3,1,2,0}, but those cases + // are probably not worth handling here for now. + if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) && + LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) { return target_index; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 77d398e5e2..02f6fc3a27 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1783,6 +1783,7 @@ xla_test( "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index ab470f16a3..f7f9a87413 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -764,6 +765,39 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D<float, 3>(HloOpcode::kClamp); } +// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. +XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { + const string hlo_text = R"( +HloModule Cluster + +fusion_c { + fusion.arg = f32[2,2]{1,0} parameter(0) + bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg) + tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0) + ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0) +} + +ENTRY main { + arg = f32[2,2]{1,0} parameter(0) + ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c +} +)"; + + std::unique_ptr<Literal> operand = + Literal::CreateR2<float>({{0., 0.}, {1., 0.}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(hlo_text, config)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Literal> result, + test_runner_.Execute(std::move(module), {operand.get()}, + /*run_hlo_passes=*/false)); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + *result)); +} + void BM_ParallelFusion(int num_iters) { // Simple element-wise computation to benchmark parallel task partitioning. tensorflow::testing::StopTiming(); |