aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_creation_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_creation_utils.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc55
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 90d2be118d..858992a326 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -174,6 +174,29 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
}
+StatusOr<HloInstruction*> MakeMapHlo(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation) {
+ CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
+ HloComputation* computation = operands.front()->parent();
+ std::vector<const Shape*> operand_shapes;
+ int64 max_operand_rank = 0;
+ for (const HloInstruction* operand : operands) {
+ CHECK_EQ(computation, operand->parent());
+ operand_shapes.push_back(&operand->shape());
+ max_operand_rank =
+ std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
+ }
+ std::vector<int64> map_dims(max_operand_rank);
+ std::iota(map_dims.begin(), map_dims.end(), 0);
+ TF_ASSIGN_OR_RETURN(
+ Shape map_shape,
+ ShapeInference::InferMapShape(
+ operand_shapes, map_computation->ComputeProgramShape(), map_dims));
+ return computation->AddInstruction(
+ HloInstruction::CreateMap(map_shape, operands, map_computation));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
@@ -251,6 +274,38 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
return MakeReshapeHlo(output_shape, operand);
}
+StatusOr<HloInstruction*> InsertDegenerateDims(
+ HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
+ CHECK(c_is_sorted(dims_to_insert));
+
+ const Shape& operand_shape = operand->shape();
+ int64 output_shape_rank =
+ operand_shape.dimensions_size() + dims_to_insert.size();
+ for (auto dim_to_insert : dims_to_insert) {
+ CHECK_LT(dim_to_insert, output_shape_rank);
+ }
+
+ std::vector<int64> output_shape_dim_bounds;
+ output_shape_dim_bounds.reserve(output_shape_rank);
+ int64 operand_dims_idx = 0;
+ int64 dims_to_insert_idx = 0;
+ for (int64 i = 0; i < output_shape_rank; ++i) {
+ if (dims_to_insert_idx < dims_to_insert.size() &&
+ i == dims_to_insert[dims_to_insert_idx]) {
+ output_shape_dim_bounds.push_back(1);
+ ++dims_to_insert_idx;
+ } else {
+ output_shape_dim_bounds.push_back(
+ operand_shape.dimensions(operand_dims_idx));
+ ++operand_dims_idx;
+ }
+ }
+
+ Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
+ output_shape_dim_bounds);
+ return MakeReshapeHlo(output_shape, operand);
+}
+
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
int64 zeros_to_prepend,
int64 zeros_to_append) {