Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc  28
1 file changed, 16 insertions(+), 12 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
index c77e3c81c9..6028950652 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -66,11 +67,12 @@ Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
@@ -101,11 +103,12 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
@@ -128,8 +131,8 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev,
computation_->AddInstruction(HloInstruction::CreateBroadcast(
inverse_stddev->shape(),
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-2))),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<float>(-2))),
{}))));
HloInstruction* variance =
computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -169,11 +172,12 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
// The cudnn libcall expects its input to be rsqrt(variance + epsilon), but
// the batchnorm HLO takes plain variance as input. Fix it up.
@@ -189,7 +193,7 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
computation_->AddInstruction(HloInstruction::CreateBroadcast(
var_plus_epsilon->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(-.5))),
+ LiteralUtil::CreateR0<float>(-.5))),
{}))));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
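
For reference, the constant-creation pattern the diff converges on is sketched below. This is a minimal sketch assembled from the hunks above, not a standalone compilable unit: it assumes the visitor's computation_ member and the batch_norm instruction are in scope, exactly as in the surrounding rewriter code.

#include "tensorflow/compiler/xla/literal.h"       // Literal type
#include "tensorflow/compiler/xla/literal_util.h"  // LiteralUtil::CreateR0

// Sketch only: mirrors the calls shown in the hunks above.
HloInstruction* epsilon =
    computation_->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0(batch_norm->epsilon())));

// In HandleBatchNormGrad, the cudnn libcall expects rsqrt(variance + epsilon),
// so the rewriter raises (variance + epsilon) to the power -0.5 via a
// broadcast of a scalar constant built the same way.
HloInstruction* neg_half = computation_->AddInstruction(
    HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-.5)));

The substance of the change is the switch from Literal::CreateR0 to LiteralUtil::CreateR0 at every constant-creation site; the added #include of literal.h presumably supplies the Literal class itself, which no longer comes in through literal_util.h.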