author    Sanjoy Das <sanjoy@google.com>    2018-07-03 11:23:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-07-03 11:26:48 -0700
commit  bdd3a01d20ec6747cc6efc39fe42ed5f29d2c97e (patch)
tree    341a217fc6b428e4e1859068fd4698a59f221f22
parent  3a118b84b1845092e1900f7aadaea706d18f214b (diff)
Compare layouts when propagating linear indices through elemental ops in fusions
PiperOrigin-RevId: 203154982
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc   9
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                        1
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc              34
3 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index ce0951bbe1..21c6f7d358 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1227,7 +1227,14 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
// If no implicit broadcast is needed for this operand, returns the target
// index as the source index.
-  if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
+  //
+  // `IrArray::Index` may contain a physical linear index which we can
+  // propagate to our operand only if our layouts match. "only if" is a bit
+  // strong since e.g. we can still forward the linear index if the operand
+  // shape is [5,1,1,5]{3,2,1,0} and the HLO shape is [5,1,1,5]{3,1,2,0}, but
+  // those cases are probably not worth handling here for now.
+  if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) &&
+      LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) {
return target_index;
}
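A side note on why the layout check matters: below is a minimal, standalone C++ sketch (deliberately not using XLA's own helpers; `LinearOffset` is a hypothetical stand-in for the stride arithmetic a layout implies). It shows that the same logical index maps to different physical linear offsets once the minor-to-major orders differ, and that the degenerate-dimension caveat mentioned in the new comment really is layout-equivalent.

#include <cstdint>
#include <iostream>
#include <vector>

// Dimensions are visited from most-minor to most-major, as in XLA's
// {d0,d1,...} layout suffix.
int64_t LinearOffset(const std::vector<int64_t>& dims,
                     const std::vector<int64_t>& minor_to_major,
                     const std::vector<int64_t>& index) {
  int64_t offset = 0;
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    offset += index[dim] * stride;  // this dimension's contribution
    stride *= dims[dim];            // next dimension is strided by this extent
  }
  return offset;
}

int main() {
  // Same logical shape f32[2,2,1], two different layouts (as in the test
  // added below): logical element (1,0,0) lands at different linear offsets,
  // so a linear index computed for one layout cannot be forwarded to the other.
  std::vector<int64_t> dims = {2, 2, 1};
  std::vector<int64_t> idx = {1, 0, 0};
  std::cout << LinearOffset(dims, {2, 1, 0}, idx) << " vs "
            << LinearOffset(dims, {0, 2, 1}, idx) << "\n";  // prints "2 vs 1"

  // The "only if is a bit strong" caveat: with the degenerate (size-1)
  // dimensions of [5,1,1,5], layouts {3,2,1,0} and {3,1,2,0} give identical
  // offsets for every index, so forwarding would actually be safe there.
  std::vector<int64_t> dims2 = {5, 1, 1, 5};
  std::vector<int64_t> idx2 = {3, 0, 0, 4};
  std::cout << LinearOffset(dims2, {3, 2, 1, 0}, idx2) << " vs "
            << LinearOffset(dims2, {3, 1, 2, 0}, idx2) << "\n";  // "19 vs 19"
  return 0;
}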
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 77d398e5e2..02f6fc3a27 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1783,6 +1783,7 @@ xla_test(
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index ab470f16a3..f7f9a87413 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -764,6 +765,39 @@ XLA_TEST_F(FusionTest, Clamp2D) {
TestElementwise2D<float, 3>(HloOpcode::kClamp);
}
+// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast.
+XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) {
+  const string hlo_text = R"(
+HloModule Cluster
+
+fusion_c {
+  fusion.arg = f32[2,2]{1,0} parameter(0)
+  bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg)
+  tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0)
+  ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0)
+}
+
+ENTRY main {
+  arg = f32[2,2]{1,0} parameter(0)
+  ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c
+}
+)";
+
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<float>({{0., 0.}, {1., 0.}});
+  HloModuleConfig config;
+  config.set_debug_options(GetDebugOptionsForTest());
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_text, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<Literal> result,
+      test_runner_.Execute(std::move(module), {operand.get()},
+                           /*run_hlo_passes=*/false));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *Literal::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+      *result));
+}
+
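For reference, the only non-trivial value in the expected literal above is tanh(1); a quick standalone check of that constant (plain C++, nothing XLA-specific) is sketched here:

#include <cmath>
#include <cstdio>

int main() {
  // tanh(0) == 0; tanh(1) is approximately 0.76159415595, the value the
  // expected Literal in the test hard-codes.
  std::printf("%.11f\n", std::tanh(1.0));
  return 0;
}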
void BM_ParallelFusion(int num_iters) {
// Simple element-wise computation to benchmark parallel task partitioning.
tensorflow::testing::StopTiming();