diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_creation_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_creation_utils.cc | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 90d2be118d..858992a326 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -174,6 +174,29 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); } +StatusOr<HloInstruction*> MakeMapHlo( + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloComputation* map_computation) { + CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; + HloComputation* computation = operands.front()->parent(); + std::vector<const Shape*> operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : operands) { + CHECK_EQ(computation, operand->parent()); + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + std::vector<int64> map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + TF_ASSIGN_OR_RETURN( + Shape map_shape, + ShapeInference::InferMapShape( + operand_shapes, map_computation->ComputeProgramShape(), map_dims)); + return computation->AddInstruction( + HloInstruction::CreateMap(map_shape, operands, map_computation)); +} + StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -251,6 +274,38 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, return MakeReshapeHlo(output_shape, operand); } +StatusOr<HloInstruction*> InsertDegenerateDims( + HloInstruction* operand, ArraySlice<int64> dims_to_insert) { + CHECK(c_is_sorted(dims_to_insert)); + + const Shape& operand_shape = operand->shape(); + int64 output_shape_rank = + operand_shape.dimensions_size() + dims_to_insert.size(); + for (auto dim_to_insert : dims_to_insert) { + CHECK_LT(dim_to_insert, output_shape_rank); + } + + std::vector<int64> output_shape_dim_bounds; + output_shape_dim_bounds.reserve(output_shape_rank); + int64 operand_dims_idx = 0; + int64 dims_to_insert_idx = 0; + for (int64 i = 0; i < output_shape_rank; ++i) { + if (dims_to_insert_idx < dims_to_insert.size() && + i == dims_to_insert[dims_to_insert_idx]) { + output_shape_dim_bounds.push_back(1); + ++dims_to_insert_idx; + } else { + output_shape_dim_bounds.push_back( + operand_shape.dimensions(operand_dims_idx)); + ++operand_dims_idx; + } + } + + Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), + output_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand, int64 zeros_to_prepend, int64 zeros_to_append) { |