aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-09 02:32:44 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:38 -0800
commitc38781f6ec9710c0102bdc9d95bf6176fd96d1ce (patch)
treef52551bc7beabfdf8007a5de9d5d8b6ec8cd625d
parent791c8bf0baf4198c5922dba08a74960ca6dac74f (diff)
When sharding a tuple, we typically want to describe the data sharding
of each individual subtensor individually. Tuples are essentially just containers - the tensors they contain should be able to be sharded differently. Tuples are hierarchically structured, but shardings were designed to not contain the sharded type (the sharded type is inferred from the output type of the instruction the sharding is applied to). Therefore, shardings for tuples contain shardings for each subtensor as a non-structured list. This list is ordered as a preorder walk of the tuple shape, and of course only the leaf nodes of the tuple shape are stored. The structure is reapplied when the sharded instruction's shape is known. PiperOrigin-RevId: 175132692
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc71
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h83
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc68
-rw-r--r--tensorflow/compiler/xla/shape_tree.h3
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc41
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc15
-rw-r--r--tensorflow/compiler/xla/xla_data.proto13
7 files changed, 278 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 0d019d22f5..bc5663513b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
@@ -38,6 +39,15 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
}
string HloSharding::ToString() const {
+ if (IsTuple()) {
+ std::vector<string> parts;
+ parts.reserve(tuple_elements_.size());
+ for (const HloSharding& element : tuple_elements_) {
+ parts.push_back(element.ToString());
+ }
+ return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
+ }
+
string result = StrCat("{", (replicated_ ? " replicated" : ""),
(maximal_ ? " maximal" : ""));
@@ -53,6 +63,11 @@ string HloSharding::ToString() const {
}
bool HloSharding::UsesDevice(int64 device) const {
+ if (IsTuple()) {
+ return std::any_of(
+ tuple_elements_.begin(), tuple_elements_.end(),
+ [&](const HloSharding& s) { return s.UsesDevice(device); });
+ }
const auto& devices = tile_assignment_;
return replicated_ ||
std::find(devices.begin(), devices.end(), device) != devices.end();
@@ -61,6 +76,7 @@ bool HloSharding::UsesDevice(int64 device) const {
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
+ CHECK(!IsTuple());
std::vector<int64> ret_index;
tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
if (d == device) {
@@ -74,6 +90,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
int64 HloSharding::DeviceForTileIndex(
tensorflow::gtl::ArraySlice<int64> index) const {
CHECK(!replicated_);
+ CHECK(!IsTuple());
if (maximal_) {
return *tile_assignment_.begin();
}
@@ -82,7 +99,7 @@ int64 HloSharding::DeviceForTileIndex(
}
std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
- CHECK(!ShapeUtil::IsTuple(tile_shape_));
+ CHECK(!IsTuple());
std::vector<int64> index = TileIndexForDevice(device);
if (maximal_) {
@@ -97,7 +114,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
}
std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
- CHECK(!ShapeUtil::IsTuple(tile_shape_));
+ CHECK(!IsTuple());
CHECK(!maximal_); // Maximal shardings do not have a valid tile shape.
std::vector<int64> index = TileIndexForDevice(device);
@@ -108,13 +125,41 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
}
StatusOr<int64> HloSharding::UniqueDevice() const {
- if (!replicated_ && maximal_) {
+ if (IsTuple()) {
+ if (tuple_elements_.empty()) {
+ return tensorflow::errors::InvalidArgument(
+ "UniqueDevice() called on empty tuple");
+ }
+ std::vector<StatusOr<int64>> results;
+ std::transform(tuple_elements_.begin(), tuple_elements_.end(),
+ std::back_inserter(results),
+ [](const HloSharding& s) { return s.UniqueDevice(); });
+ if (std::all_of(results.begin(), results.end(),
+ [&](const StatusOr<int64>& s) {
+ return s.ok() && results[0].ok() &&
+ s.ValueOrDie() == results[0].ValueOrDie();
+ })) {
+ return results[0];
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Tuple did not contain a unique device");
+ }
+ }
+ if (!replicated_ && maximal_ && !IsTuple()) {
return static_cast<int64>(*tile_assignment_.begin());
}
return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on sharding that executes on multiple devices");
}
+bool HloSharding::HasUniqueDevice() const {
+ if (IsTuple()) {
+ return UniqueDevice().status().ok();
+ } else {
+ return !IsReplicated() && IsTileMaximal();
+ }
+}
+
Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
if (replicated_) {
return Status::OK();
@@ -193,7 +238,16 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
const OpSharding& proto) {
- if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) {
+ std::vector<HloSharding> tuple_shardings;
+ tuple_shardings.reserve(proto.tuple_shardings().size());
+ for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
+ TF_ASSIGN_OR_RETURN(HloSharding sharding,
+ HloSharding::FromProto(tuple_sharding_proto));
+ tuple_shardings.push_back(sharding);
+ }
+ return HloSharding(tuple_shardings);
+ } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate();
} else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
return HloSharding(proto.tile_assignment_devices(0));
@@ -212,6 +266,15 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
OpSharding HloSharding::ToProto() const {
OpSharding result;
+
+ if (IsTuple()) {
+ for (const HloSharding& element : tuple_elements_) {
+ *result.add_tuple_shardings() = element.ToProto();
+ }
+ result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+ return result;
+ }
+
*result.mutable_tile_shape() = tile_shape_;
for (int64 dim : tile_assignment_.dimensions()) {
result.add_tile_assignment_dimensions(dim);
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index d7ada30c70..f8ef2a3d05 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -67,6 +68,18 @@ class HloSharding {
// `num_tiles` tiles.
static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);
+ // Creates a new sharding for a tuple type. The given ShapeTree must have
+ // elements for every leaf shape contained in the tuple.
+ static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
+ std::vector<HloSharding> flattened_list;
+ flattened_list.reserve(
+ std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
+ for (const auto& index_to_sharding : sub_shardings.leaves()) {
+ flattened_list.push_back(index_to_sharding.second);
+ }
+ return HloSharding(flattened_list);
+ }
+
// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);
@@ -76,47 +89,89 @@ class HloSharding {
// Validate that this sharding can be applied to a tensor with shape `shape`.
Status Validate(const Shape& shape, int64 num_devices) const;
+ // Returns true if the sharding has tuple type.
+ bool IsTuple() const { return tuple_; }
+
// Returns true if the sharding is trivial: replicate on all devices.
- bool IsReplicated() const { return replicated_; }
+ bool IsReplicated() const {
+ if (!IsTuple()) {
+ return replicated_;
+ }
+ return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
+ [](const HloSharding& s) { return s.IsReplicated(); });
+ }
// Returns true if the tile size is the same as the input size.
- bool IsTileMaximal() const { return maximal_; }
+ bool IsTileMaximal() const {
+ if (!IsTuple()) {
+ return maximal_;
+ }
+ return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
+ [](const HloSharding& s) { return s.IsTileMaximal(); });
+ }
// Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const;
// Returns the tile that should be executed on the given device.
+ // REQUIRES: !IsTuple()
std::vector<int64> TileIndexForDevice(int64 device) const;
// Returns the device that should execute the given tile.
// It is an error to call this if is_replicated() is true.
+ // REQUIRES: !IsTuple()
int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
// Given a device ID, returns the offset within the input space of the
// tile that should be executed on the given core. This returns the lower
// extent of the tile in the input space.
+ // REQUIRES: !IsTuple()
std::vector<int64> TileOffsetForDevice(int64 device) const;
// Given a device ID, returns the limit within the input space of the
// tile that should be executed on the given core. This returns the upper
// extent of the tile in the input space.
+ // REQUIRES: !IsTuple()
std::vector<int64> TileLimitForDevice(int64 device) const;
// Returns the single device this op operates on.
- // Requires !Replicated() && IsTileMaximal().
+ // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
StatusOr<int64> UniqueDevice() const;
// Returns true if this op only uses a single device.
- bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); }
+ bool HasUniqueDevice() const;
+
+ // Returns the ShapeTree containing the shardings for each element of this
+ // tuple. Only the leaf elements are populated. This creates a new ShapeTree
+ // object so is not cheap. REQUIRES: IsTuple()
+ ShapeTree<HloSharding> GetTupleShardingsAsShapeTree(
+ const Shape& tuple_shape) const {
+ ShapeTree<HloSharding> result(tuple_shape, HloSharding::Replicate());
+ CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()),
+ tuple_elements_.size());
+ auto it = tuple_elements_.begin();
+ for (auto& index_to_sharding : result.leaves()) {
+ index_to_sharding.second = *it++;
+ }
+ return result;
+ }
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
- tile_assignment_ == other.tile_assignment_;
+ tile_assignment_ == other.tile_assignment_ &&
+ tuple_elements_ == other.tuple_elements_;
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }
size_t Hash() const {
+ if (!tuple_) {
+ size_t h = 0;
+ for (const auto& element : tuple_elements_) {
+ h = tensorflow::Hash64Combine(h, element.Hash());
+ }
+ return h;
+ }
if (replicated_) {
return 0;
}
@@ -131,33 +186,47 @@ class HloSharding {
}
// Gets the tile shape.
- // It is an error to call this if IsTileMaximal() is true.
+ // REQUIRES: !IsTileMaximal() && !IsTuple()
const Shape& tile_shape() const { return tile_shape_; }
// Gets the tile assignment tensor.
- // It is an error to call this if IsReplicated() is true.
+ // REQUIRES: !IsReplicated() && !IsTuple()
const Array<int64>& tile_assignment() const { return tile_assignment_; }
private:
HloSharding()
: replicated_(true),
maximal_(true),
+ tuple_(false),
tile_shape_(),
tile_assignment_({0}) {}
explicit HloSharding(int64 device_id)
: replicated_(false),
maximal_(true),
+ tuple_(false),
tile_shape_(),
tile_assignment_({1}, device_id) {}
HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
: replicated_(false),
maximal_(false),
+ tuple_(false),
tile_shape_(tile_shape),
tile_assignment_(tile_assignment) {}
+ HloSharding(const std::vector<HloSharding>& tuple_shardings)
+ : replicated_(false),
+ maximal_(false),
+ tuple_(true),
+ tile_assignment_({0}),
+ tuple_elements_(tuple_shardings) {}
bool replicated_;
bool maximal_;
+ bool tuple_;
Shape tile_shape_;
Array<int64> tile_assignment_;
+ // Only non-empty when tuple_ is true, but because empty tuples are allowed
+ // may also be empty even then. This is a flattened list of all the leaf
+ // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
+ std::vector<HloSharding> tuple_elements_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index d0a20471a0..00ea38480e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -132,6 +132,29 @@ TEST_F(HloShardingTest, Tile) {
}
}
+TEST_F(HloShardingTest, NestedTuple) {
+ // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
+ ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
+ ShapeUtil::MakeShape(F32, {4, 6}),
+ });
+
+ OpSharding proto;
+ proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+ *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
+ HloSharding tuple_sharding =
+ HloSharding::FromProto(proto).ConsumeValueOrDie();
+
+ ShapeTree<HloSharding> shape_tree =
+ tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape);
+ EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
+ EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
+ EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
+}
+
TEST_F(HloShardingTest, Hash) {
auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
if (a.Hash() != b.Hash()) {
@@ -184,6 +207,51 @@ TEST_F(HloShardingTest, Hash) {
MakeArray({2, 2}, {0, 3, 1, 2}));
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
+
+ HloSharding default_sharding = HloSharding::Replicate();
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Replicate();
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::Replicate();
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0);
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
}
} // namespace
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 64a36471b9..a898a4d375 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -116,6 +116,7 @@ class ShapeTree {
ShapeTree(const Shape* shape, const T& init_value);
ShapeTree(const ShapeTree& other) { *this = other; }
+ ShapeTree(ShapeTree&&) = default;
ShapeTree& operator=(const ShapeTree& other) {
root_ = other.root_;
@@ -132,6 +133,8 @@ class ShapeTree {
return *this;
}
+ ShapeTree& operator=(ShapeTree&& other) = default;
+
// Returns the data element associated with the array in the shape at the
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
const T& element(const ShapeIndex& index) const;
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index f1e987cb15..df07e069a0 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -60,6 +60,7 @@ class HloParser {
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
+ bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
@@ -123,6 +124,7 @@ class HloParser {
bool ParseWindow(Window* window);
bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
bool ParseSharding(OpSharding* sharding);
+ bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
// Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
bool ParseDxD(const string& name, std::vector<int64>* result);
@@ -548,14 +550,49 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
return AddInstruction(name, instruction);
}
-// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
-// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
+// ::= '{' (single_sharding | tuple_sharding) '}'
+//
+// tuple_sharding ::= single_sharding* (',' single_sharding)*
bool HloParser::ParseSharding(OpSharding* sharding) {
+ // A single sharding starts with '{' and is not followed by '{'.
+ // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
+ // an empty tuple.
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start sharding attribute")) {
return false;
}
+ if (lexer_.GetKind() != TokKind::kLbrace &&
+ lexer_.GetKind() != TokKind::kRbrace) {
+ return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
+ }
+
+ // Tuple sharding.
+ // Allow empty tuple shardings.
+ if (lexer_.GetKind() != TokKind::kRbrace) {
+ do {
+ if (!ParseSingleSharding(sharding->add_tuple_shardings(),
+ /*lbrace_pre_lexed=*/false)) {
+ return false;
+ }
+ } while (EatIfPresent(TokKind::kComma));
+ }
+ sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+
+ return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
+}
+
+// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
+// ('devices=' ('[' dims ']')* device_list)? '}'
+// dims ::= int_list device_list ::= int_list
+bool HloParser::ParseSingleSharding(OpSharding* sharding,
+ bool lbrace_pre_lexed) {
+ if (!lbrace_pre_lexed &&
+ !ParseToken(TokKind::kLbrace,
+ "expected '{' to start sharding attribute")) {
+ return false;
+ }
+
bool maximal = false;
bool replicated = false;
std::vector<int64> devices;
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 62b4385e76..a9dc360978 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -152,7 +152,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
- ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
+ ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}
)"
@@ -182,6 +182,19 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f
)"
},
+{
+"ShardedTupleCreate",
+R"(HloModule ShardedTupleCreate_module:
+
+ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
+ %v1 = f32[] parameter(0)
+ %v2 = f32[3]{0} parameter(1)
+ %v3 = f32[2,3]{1,0} parameter(2)
+ ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
+}
+
+)"
+},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 06987e0044..7146604708 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -825,8 +825,10 @@ message OpSharding {
REPLICATED = 0;
// This sharding is maximal - one device runs the entire operation.
MAXIMAL = 1;
- // Neither of the above; tile_shape and tile_assignment are both used.
- OTHER = 2;
+ // This sharding is a tuple - only the tuple_shardings field is valid.
+ TUPLE = 2;
+ // None of the above; tile_shape and tile_assignment are both used.
+ OTHER = 3;
}
Type type = 1;
// The shape of the sharded tile.
@@ -838,6 +840,13 @@ message OpSharding {
// Flattened list of device IDs. The order of flattening is the same as used
// by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
repeated int64 tile_assignment_devices = 4;
+ // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
+ // in pre-order. The tuple shape could be nested; here we store just a
+ // flattened list of all leaves in the tuple shape. Note that the tuple shape
+ // is not stored here; shardings do not store the shapes to which they are
+ // applied, this is inferred from the instruction this sharding gets attached
+ // to.
+ repeated OpSharding tuple_shardings = 5;
}
message OpRequest {