aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_creation_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_creation_utils.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.h16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 49b1402d68..5ff8946fb0 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -102,6 +102,12 @@ StatusOr<HloInstruction*> MakeConcatHlo(
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers);
+// Creates a Map HLO instruction and adds it to the computation containing the
+// operands. All operands must be in the same computation.
+StatusOr<HloInstruction*> MakeMapHlo(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation);
+
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
@@ -144,6 +150,16 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
StatusOr<HloInstruction*> ElideDegenerateDims(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide);
+// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
+// exactly one element), `dims_to_insert` into `operand`. The dimensions in
+// `dims_to_insert` refer to the dimensions in the result, and hence should be
+// less than the rank of the result. Also, `dims_to_insert` must be sorted.
+//
+// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
+// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
+StatusOr<HloInstruction*> InsertDegenerateDims(
+ HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert);
+
// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,