aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-08-23 14:53:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 14:57:41 -0700
commita73bc982669c8eb4bec9418a94a96b64551e641b (patch)
tree663477d414aad7a982acecbec029a814cd74d9b4 /tensorflow/compiler
parentc6b8f3617662ff0223eace9774089ce509867160 (diff)
Do not crash when an empty tuple is passed into hlo_sharding.
PiperOrigin-RevId: 210005372
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc7
3 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index cc8ee94e22..61614f0c43 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -72,12 +72,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
const HloSharding& sharding) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
CHECK(!sharding.IsTuple()) << sharding.ToString();
- int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ int64 leaf_count = RequiredLeaves(tuple_shape);
std::vector<HloSharding> flattened_list;
- flattened_list.reserve(leaf_count);
- for (int64 i = 0; i < leaf_count; ++i) {
- flattened_list.push_back(sharding);
- }
+ flattened_list.resize(leaf_count, sharding);
return HloSharding(flattened_list);
}
@@ -446,7 +443,7 @@ absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
}
for (int64 i = 1; i < tuple_elements_.size(); ++i) {
if (tuple_elements_[0] != tuple_elements_[i]) {
- return absl::optional<HloSharding>();
+ return absl::nullopt;
}
}
return tuple_elements_.front();
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 4c64ac60c5..be51c3f55b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -260,9 +260,9 @@ class HloSharding {
bool maximal_;
bool tuple_;
Array<int64> tile_assignment_;
- // Only non-empty when tuple_ is true, but because empty tuples are allowed
- // may also be empty even then. This is a flattened list of all the leaf
- // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
+ // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
+ // present for the root. This is a flattened list of all the leaf shardings in
+ // a tuple shape, by pre-order walk (ShapeTree iterator order).
std::vector<HloSharding> tuple_elements_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 45fc300fca..2341f8ada0 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) {
}
}
+// Tests that empty tuple is supported.
+TEST_F(HloShardingTest, EmptySingleTuple) {
+ HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}),
+ HloSharding::AssignDevice(0));
+ EXPECT_TRUE(sharding.ExtractSingleSharding());
+}
+
TEST_F(HloShardingTest, NestedTuple) {
// nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({