aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc9
3 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index aa9ff89e98..27b4626094 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -374,4 +374,9 @@ HloSharding HloSharding::TransformShardedTileShape(
return HloSharding::Tile(new_tile_shape, tile_assignment());
}
+std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
+ out << sharding.ToString();
+ return out;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 38273236f9..18d406f370 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -269,6 +269,8 @@ class HloSharding {
std::vector<HloSharding> tuple_elements_;
};
+std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 07fc4687cc..9887096eb5 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -282,5 +282,14 @@ TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
EXPECT_EQ(result, expected);
}
+TEST_F(HloShardingTest, OstreamTest) {
+ HloSharding sharding =
+ HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
+ Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+ std::ostringstream oss;
+ oss << sharding;
+ EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=(0, 1, 2, 3)}");
+}
+
} // namespace
} // namespace xla