aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_cse_test.cc
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-05-15 12:12:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-15 12:15:52 -0700
commitfdfaff2ed10501ead31fc1eda201031ec9c8d11e (patch)
tree82316ec19d38a8803931e61a33ab0c3a2f0df21e /tensorflow/compiler/xla/service/hlo_cse_test.cc
parent16adfe2d0004638d787e85d044178216c42d76a8 (diff)
[XLA] Make HloCSE compare computations
This shows up when you have two otherwise identical instructions that call a computation, like a fusion or a reduce. Even if the called computations are identical but not the same it wouldn't get CSE'd. I was a bit worried about the compile time impact of comparing full computations, but this only happens if everything else already compares equal. The impact on compile time of benchmarks seems to be within the noise. PiperOrigin-RevId: 196708782
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cse_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index a04b4f4dcf..9735764b69 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -469,5 +470,36 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant())));
}
+TEST_F(HloCseTest, CompareComputations) {
+ auto module = tools::Parse(R"(
+ HloModule m
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ add_computation2 {
+ add_lhs2 = f32[] parameter(0)
+ add_rhs2 = f32[] parameter(1)
+ ROOT add_root2 = f32[] add(add_lhs2, add_rhs2)
+ }
+
+ ENTRY entry {
+ p = f32[10]{0} parameter(0)
+ c = f32[] constant(0)
+ r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
+ r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
+ ROOT f2 = (f32[],f32[]) tuple(r1, r2)
+ })")
+ .ValueOrDie();
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->operand(0), root->operand(1));
+}
+
} // namespace
} // namespace xla