From 35287be3bb7daa0448af064f5d005a25201d6853 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 7 Jul 2018 06:54:44 -0700 Subject: Build fully connected graph which edges across called computations. Restructured sharding passes to propagate sharding on pass-through instructions which now the placer does not assign anymore (GTEs, tuples, bitcast, parameters, ...). PiperOrigin-RevId: 203591020 --- tensorflow/compiler/xla/service/hlo_sharding.cc | 21 +++++++++++++++++++++ tensorflow/compiler/xla/service/hlo_sharding.h | 9 +++++++++ 2 files changed, 30 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 268b4727bc..393944c20f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -60,6 +60,9 @@ HloSharding HloSharding::Tuple( const Shape& tuple_shape, tensorflow::gtl::ArraySlice shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + for (auto& sharding : shardings) { + CHECK(!sharding.IsTuple()) << sharding.ToString(); + } std::vector flattened_list(shardings.begin(), shardings.end()); CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) << "Flat list has " << flattened_list.size() << ", required " @@ -67,6 +70,24 @@ HloSharding HloSharding::Tuple( return HloSharding(flattened_list); } +HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(!sharding.IsTuple()) << sharding.ToString(); + int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + std::vector flattened_list; + flattened_list.reserve(leaf_count); + for (int64 i = 0; i < leaf_count; ++i) { + flattened_list.push_back(sharding); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Single(const Shape& shape, + const HloSharding& sharding) { + return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 3d14f9c89e..6f672b0f28 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -80,6 +80,15 @@ class HloSharding { static HloSharding Tuple(const Shape& tuple_shape, tensorflow::gtl::ArraySlice shardings); + // Creates a new sharding for a tuple type, with a single input sharding + // repeated on each leaf. + static HloSharding SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding); + + // If shape is an array, returns sharding, otherwise returns the tuple shaped + // sharding with all the leaf nodes having the same input sharding. + static HloSharding Single(const Shape& shape, const HloSharding& sharding); + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); -- cgit v1.2.3