aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-28 17:08:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 17:12:30 -0700
commitde25baba51a7369bbdc5fd8051af0e9deb6a268d (patch)
tree4dd13806044b9e80128164ecaaf0d13ec2548173 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parent4dcd00066fba2bd7c504c1bc35738f804de9df67 (diff)
[XLA] Add support for algebraic simplifications involving kIota
PiperOrigin-RevId: 210634966
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc34
1 files changed, 29 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index c236453fc7..19bb4da9a6 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -30,9 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -1238,7 +1240,7 @@ namespace {
// return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
-std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
+absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
const HloInstruction* hlo,
tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
@@ -1258,11 +1260,11 @@ std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
}
if (i >= unmodified_dims.size() ||
unmodified_dims[i].first != input_dim_index) {
- return std::make_pair(false, std::vector<int64>());
+ return absl::nullopt;
}
output_dim_indices.push_back(unmodified_dims[i].second);
}
- return std::make_pair(true, output_dim_indices);
+ return output_dim_indices;
}
// Returns true if the output of "instruction" is a permutation of the
@@ -1391,6 +1393,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+ // broadcast(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateIota(
+ broadcast->shape(),
+ dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
+ }
+
// Merge two consecutive broadcasts into a single one.
if (operand->opcode() == HloOpcode::kBroadcast) {
std::vector<int64> new_dimensions;
@@ -1719,12 +1730,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
reshape, reshape->operand(0)->dimensions());
- if (opt_dims.first) {
+ if (opt_dims.has_value()) {
return ReplaceWithNewInstruction(
reshape,
HloInstruction::CreateBroadcast(
reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
- opt_dims.second));
+ *opt_dims));
+ }
+ }
+
+ // reshape(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ auto* iota = Cast<HloIotaInstruction>(operand);
+ auto opt_dims =
+ ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
+ if (opt_dims.has_value()) {
+ CHECK_EQ(opt_dims->size(), 1);
+ return ReplaceWithNewInstruction(
+ reshape,
+ HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
}
}