diff options
author | 2018-08-23 14:53:20 -0700 | |
---|---|---|
committer | 2018-08-23 14:57:41 -0700 | |
commit | a73bc982669c8eb4bec9418a94a96b64551e641b (patch) | |
tree | 663477d414aad7a982acecbec029a814cd74d9b4 /tensorflow/compiler | |
parent | c6b8f3617662ff0223eace9774089ce509867160 (diff) |
Do not crash when an empty tuple is passed into hlo_sharding.
PiperOrigin-RevId: 210005372
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 7 |
3 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index cc8ee94e22..61614f0c43 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -72,12 +72,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); - int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + int64 leaf_count = RequiredLeaves(tuple_shape); std::vector<HloSharding> flattened_list; - flattened_list.reserve(leaf_count); - for (int64 i = 0; i < leaf_count; ++i) { - flattened_list.push_back(sharding); - } + flattened_list.resize(leaf_count, sharding); return HloSharding(flattened_list); } @@ -446,7 +443,7 @@ absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const { } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { - return absl::optional<HloSharding>(); + return absl::nullopt; } } return tuple_elements_.front(); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 4c64ac60c5..be51c3f55b 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -260,9 +260,9 @@ class HloSharding { bool maximal_; bool tuple_; Array<int64> tile_assignment_; - // Only non-empty when tuple_ is true, but because empty tuples are allowed - // may also be empty even then. This is a flattened list of all the leaf - // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + // Only non-empty when tuple_ is true. If a tuple is empty then one entry is + // present for the root. This is a flattened list of all the leaf shardings in + // a tuple shape, by pre-order walk (ShapeTree iterator order). std::vector<HloSharding> tuple_elements_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 45fc300fca..2341f8ada0 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) { } } +// Tests that empty tuple is supported. +TEST_F(HloShardingTest, EmptySingleTuple) { + HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), + HloSharding::AssignDevice(0)); + EXPECT_TRUE(sharding.ExtractSingleSharding()); +} + TEST_F(HloShardingTest, NestedTuple) { // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ |