diff options
author | 2018-05-15 12:12:44 -0700 | |
---|---|---|
committer | 2018-05-15 12:15:52 -0700 | |
commit | fdfaff2ed10501ead31fc1eda201031ec9c8d11e (patch) | |
tree | 82316ec19d38a8803931e61a33ab0c3a2f0df21e /tensorflow/compiler/xla/service/hlo_cse_test.cc | |
parent | 16adfe2d0004638d787e85d044178216c42d76a8 (diff) |
[XLA] Make HloCSE compare computations
This shows up when you have two otherwise identical instructions that call a
computation, like a fusion or a reduce. Previously, if the called computations
were identical in content but not the same computation, the instructions
wouldn't get CSE'd. I was a bit worried about the
compile time impact of comparing full computations, but this only happens if
everything else already compares equal. The impact on compile time of
benchmarks seems to be within the noise.
PiperOrigin-RevId: 196708782
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cse_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cse_test.cc | 32 |
1 files changed, 32 insertions, 0 deletions
```diff
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index a04b4f4dcf..9735764b69 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
@@ -469,5 +470,36 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
   EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant())));
 }
 
+TEST_F(HloCseTest, CompareComputations) {
+  auto module = tools::Parse(R"(
+    HloModule m
+
+    add_computation {
+      add_lhs = f32[] parameter(0)
+      add_rhs = f32[] parameter(1)
+      ROOT add_root = f32[] add(add_lhs, add_rhs)
+    }
+
+    add_computation2 {
+      add_lhs2 = f32[] parameter(0)
+      add_rhs2 = f32[] parameter(1)
+      ROOT add_root2 = f32[] add(add_lhs2, add_rhs2)
+    }
+
+    ENTRY entry {
+      p = f32[10]{0} parameter(0)
+      c = f32[] constant(0)
+      r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
+      r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
+      ROOT f2 = (f32[],f32[]) tuple(r1, r2)
+    })")
+                    .ValueOrDie();
+
+  HloCSE cse(/*is_layout_sensitive=*/false);
+  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->operand(0), root->operand(1));
+}
+
 }  // namespace
 }  // namespace xla
```