diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index c77e3c81c9..6028950652 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -66,11 +67,12 @@ Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector<HloInstruction*> operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -101,11 +103,12 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector<HloInstruction*> operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -128,8 +131,8 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, computation_->AddInstruction(HloInstruction::CreateBroadcast( inverse_stddev->shape(), - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(-2))), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0<float>(-2))), {})))); HloInstruction* variance = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -169,11 +172,12 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but // the batchnorm HLO takes plain variance as input. Fix it up. @@ -189,7 +193,7 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { computation_->AddInstruction(HloInstruction::CreateBroadcast( var_plus_epsilon->shape(), computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0<float>(-.5))), + LiteralUtil::CreateR0<float>(-.5))), {})))); std::vector<HloInstruction*> operands(batch_norm->operands().begin(), |