diff options
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 9 |
3 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index aa9ff89e98..27b4626094 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -374,4 +374,9 @@ HloSharding HloSharding::TransformShardedTileShape( return HloSharding::Tile(new_tile_shape, tile_assignment()); } +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { + out << sharding.ToString(); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 38273236f9..18d406f370 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -269,6 +269,8 @@ class HloSharding { std::vector<HloSharding> tuple_elements_; }; +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 07fc4687cc..9887096eb5 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -282,5 +282,14 @@ TEST_F(HloShardingTest, TransformShardedTileShapeTest) { EXPECT_EQ(result, expected); } +TEST_F(HloShardingTest, OstreamTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D<int64>({{{{0, 1}, {2, 3}}}})); + std::ostringstream oss; + oss << sharding; + EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=(0, 1, 2, 3)}"); +} + } // namespace } // namespace xla |