aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h11
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 34324d2058..6f672b0f28 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -24,7 +24,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -80,6 +80,15 @@ class HloSharding {
static HloSharding Tuple(const Shape& tuple_shape,
tensorflow::gtl::ArraySlice<HloSharding> shardings);
+ // Creates a new sharding for a tuple type, with a single input sharding
+ // repeated on each leaf.
+ static HloSharding SingleTuple(const Shape& tuple_shape,
+ const HloSharding& sharding);
+
+ // If shape is an array, returns sharding, otherwise returns the tuple shaped
+ // sharding with all the leaf nodes having the same input sharding.
+ static HloSharding Single(const Shape& shape, const HloSharding& sharding);
+
// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);