diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.h | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 34324d2058..6f672b0f28 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -24,7 +24,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -80,6 +80,15 @@ class HloSharding { static HloSharding Tuple(const Shape& tuple_shape, tensorflow::gtl::ArraySlice<HloSharding> shardings); + // Creates a new sharding for a tuple type, with a single input sharding + // repeated on each leaf. + static HloSharding SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding); + + // If shape is an array, returns sharding, otherwise returns the tuple shaped + // sharding with all the leaf nodes having the same input sharding. + static HloSharding Single(const Shape& shape, const HloSharding& sharding); + // Create a new sharding from a protobuf OpSharding. static StatusOr<HloSharding> FromProto(const OpSharding& proto); |