aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/batch_dot_simplification.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/batch_dot_simplification.cc')
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.cc12
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index 2099916509..a16b85a0a5 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -63,6 +64,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
+ new_dot->set_precision_config(batch_dot->precision_config());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));
@@ -76,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
return true;
}
-tensorflow::StringPiece BatchDotSimplification::name() const {
+absl::string_view BatchDotSimplification::name() const {
return "batch-dot-simplification";
}
@@ -84,10 +86,10 @@ StatusOr<bool> BatchDotSimplification::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> dot_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
- c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
- [](HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kDot;
- });
+ absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
+ [](HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kDot;
+ });
}
for (HloInstruction* dot_instr : dot_instrs) {
TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one,