author     2017-08-17 18:04:58 -0700
committer  2017-08-17 18:08:34 -0700
commit     7359fec792e4efec1670a12332bb524a5608b215 (patch)
tree       cce4b3a2360071424ab2d794bc86de4598fe51fb /tensorflow/compiler
parent     f0da8bf56ba1b625d53b760683bc44f67e204199 (diff)
Implement batch norm inference by expanding it into smaller ops.
1. Add batch norm inference support in batchnorm_rewriter
2. Connect XLA's batch norm inference to TF's FusedBatchNorm
RELNOTES: n/a
PiperOrigin-RevId: 165655351
Diffstat (limited to 'tensorflow/compiler')
25 files changed, 606 insertions, 28 deletions
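The heart of the patch below is HandleBatchNormInference in batchnorm_rewriter.cc: a batch-norm-inference HLO op is expanded into broadcasts plus elementwise ops computing y = scale * (x - mean) / sqrt(var + epsilon) + offset, with the reciprocal square root expressed as power(var + epsilon, -0.5). A minimal NumPy sketch of that same decomposition (the helper name and array-based style are illustrative only; the actual pass emits HLO instructions, not arrays):

import numpy as np

def batch_norm_inference_reference(x, scale, offset, mean, var,
                                   epsilon, feature_index):
  """Mirrors the HLO expansion: broadcast the rank-1 per-feature inputs
  to the operand shape, then apply elementwise add/power/sub/mul/add."""
  # Broadcast each per-feature vector along feature_index, as
  # HloInstruction::CreateBroadcast(..., {feature_index}) does.
  shape = [1] * x.ndim
  shape[feature_index] = x.shape[feature_index]
  scale_b = scale.reshape(shape)
  offset_b = offset.reshape(shape)
  mean_b = mean.reshape(shape)
  var_b = var.reshape(shape)
  # The rewriter computes 1/sqrt(var + eps) as power(var + eps, -0.5).
  rsqrt = np.power(var_b + epsilon, -0.5)
  return (x - mean_b) * rsqrt * scale_b + offset_b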
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index f8e9fc9268..936fcf8b6b 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -63,6 +63,39 @@ class FusedBatchNormTest(XLATestCase):
     grad_offset = np.sum(grad_y, axis=(0, 1, 2))
     return grad_x, grad_scale, grad_offset
 
+  def testInference(self):
+    x_shape = [2, 2, 6, 2]
+    scale_shape = [2]
+    x_val = np.random.random_sample(x_shape).astype(np.float32)
+    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
+    data_format = "NHWC"
+    with self.test_session() as sess, self.test_scope():
+      # To avoid constant folding
+      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
+      scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
+      offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
+      epsilon = 0.001
+      y_ref, mean_ref, var_ref = self._reference_training(
+          x_val, scale_val, offset_val, epsilon, data_format)
+      y, mean, variance = nn.fused_batch_norm(
+          t_val,
+          scale,
+          offset,
+          mean=mean_ref,
+          variance=var_ref,
+          epsilon=epsilon,
+          data_format=data_format,
+          is_training=False)
+
+      y_val, _, _ = sess.run(
+          [y, mean,
+           variance], {t_val: x_val,
+                       scale: scale_val,
+                       offset: offset_val})
+      self.assertAllClose(y_val, y_ref, atol=1e-3)
+
   def _testLearning(self, use_gradient_checker):
     x_shape = [2, 2, 6, 2]
     scale_shape = [2]
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 3f23e459b9..9d2703bf95 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -39,28 +39,36 @@ class FusedBatchNormOp : public XlaOpKernel {
                 errors::InvalidArgument("Not supported format"));
     feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format);
-
-    // TODO(b/62843645): Implement BatchNormInference.
-    OP_REQUIRES(
-        ctx, is_training_,
-        errors::InvalidArgument("Fused batch normalization for inference is "
-                                "not supported yet on XLA backend."));
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
-        ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, feature_index_);
-
-    // In training mode, outputs the normalized value as well as the calculated
-    // mean and variance.
-    for (int i = 0; i < 3; i++) {
-      ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+    if (is_training_) {
+      xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
+          ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_,
+          feature_index_);
+
+      // In training mode, outputs the normalized value as well as the
+      // calculated mean and variance.
+      for (int i = 0; i < 3; i++) {
+        ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+      }
+      // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
+      // space 1 & 2". They are used to pass the per-batch mean and
+      // variance to the gradient. Here we maintain the same behavior by setting
+      // them to the mean and variance calculated by BatchNormTraining.
+      ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
+      ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
+    } else {
+      xla::ComputationDataHandle output = ctx->builder()->BatchNormInference(
+          ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3),
+          ctx->Input(4), epsilon_, feature_index_);
+      ctx->SetOutput(0, output);
+      // Directly send input to output as mean and variance in inference mode.
+      ctx->SetOutput(1, ctx->Input(3));
+      ctx->SetOutput(2, ctx->Input(4));
+      ctx->SetOutput(3, ctx->Input(3));
+      ctx->SetOutput(4, ctx->Input(4));
     }
-    // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
-    // space 1 & 2". They are used to pass the per-batch mean and
-    // variance to the gradient. Here we maintain the same behavior by setting
-    // them to the mean and variance calculated by BatchNormTraining.
-    ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
-    ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
   }
 
  private:
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index e6ffc4f98d..30afaed732 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -1477,9 +1477,29 @@ ComputationDataHandle ComputationBuilder::BatchNormInference(
     const ComputationDataHandle& operand, const ComputationDataHandle& scale,
     const ComputationDataHandle& offset, const ComputationDataHandle& mean,
     const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
-  // TODO(b/62843645): Implement BatchNormInference.
-  NoteError(Unimplemented("BatchNormInference is not implemented yet."));
-  return ComputationDataHandle();
+  if (!first_error_.ok() || !PrepareComputation().ok()) {
+    return ComputationDataHandle();
+  }
+  BatchNormInferenceRequest request;
+  *request.mutable_operand() = operand;
+  *request.mutable_scale() = scale;
+  *request.mutable_offset() = offset;
+  *request.mutable_mean() = mean;
+  *request.mutable_variance() = variance;
+  request.set_epsilon(epsilon);
+  request.set_feature_index(feature_index);
+
+  OpRequest op_request;
+  *op_request.mutable_batch_norm_inference_request() = request;
+  *op_request.mutable_computation() = computation_.handle();
+  AddOpMetadata(&op_request);
+
+  OpResponse response;
+
+  VLOG(2) << "making BatchNormInference request";
+
+  Status s = client_->stub()->Op(&op_request, &response);
+  return ParseOpResponse(s, &response);
 }
 
 ComputationDataHandle ComputationBuilder::BatchNormGrad(
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
index 721d99301a..41d32d0c8b 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
@@ -56,11 +56,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
   Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
 
+  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
+
   Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
 
   // Runs the visitor on a computation.
   static bool Run(HloComputation* computation, bool rewrite_training_op,
-                  bool rewrite_grad_op, bool use_fusion);
+                  bool rewrite_inference_op, bool rewrite_grad_op,
+                  bool use_fusion);
 
   // Returns whether any batch norm ops were rewritten.
   const bool changed() const { return changed_; }
@@ -70,9 +73,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
  private:
   explicit BatchNormRewriterVisitor(HloComputation* computation,
                                     bool rewrite_training_op,
+                                    bool rewrite_inference_op,
                                     bool rewrite_grad_op, bool use_fusion)
       : computation_(computation),
         rewrite_training_op_(rewrite_training_op),
+        rewrite_inference_op_(rewrite_inference_op),
         rewrite_grad_op_(rewrite_grad_op),
         use_fusion_(use_fusion) {}
 
@@ -94,6 +99,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
   HloComputation* computation_;
 
   bool rewrite_training_op_;
+  bool rewrite_inference_op_;
   bool rewrite_grad_op_;
   bool use_fusion_;
 
@@ -126,11 +132,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
 bool BatchNormRewriterVisitor::Run(HloComputation* computation,
                                    bool rewrite_training_op,
+                                   bool rewrite_inference_op,
                                    bool rewrite_grad_op, bool use_fusion) {
-  BatchNormRewriterVisitor visitor(computation,
-                                   /*rewrite_training_op=*/rewrite_training_op,
-                                   /*rewrite_grad_op=*/rewrite_grad_op,
-                                   /*use_fusion=*/use_fusion);
+  BatchNormRewriterVisitor visitor(
+      computation,
+      /*rewrite_training_op=*/rewrite_training_op,
+      /*rewrite_inference_op=*/rewrite_inference_op,
+      /*rewrite_grad_op=*/rewrite_grad_op,
+      /*use_fusion=*/use_fusion);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }
@@ -268,6 +277,82 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
   return Status::OK();
 }
 
+Status BatchNormRewriterVisitor::HandleBatchNormInference(
+    HloInstruction* batch_norm) {
+  if (!rewrite_inference_op_) {
+    return Status::OK();
+  }
+  // Expand batch norm inference into smaller HLO ops.
+  HloInstruction* operand = batch_norm->mutable_operand(0);
+  const Shape operand_shape = operand->shape();
+  int64 feature_index = batch_norm->feature_index();
+
+  HloInstruction* scale = batch_norm->mutable_operand(1);
+  HloInstruction* offset = batch_norm->mutable_operand(2);
+  HloInstruction* mean = batch_norm->mutable_operand(3);
+  HloInstruction* var = batch_norm->mutable_operand(4);
+  const Shape feature_shape = scale->shape();
+
+  auto epsilon = computation_->AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+
+  std::vector<int64> dimensions_without_feature;
+
+  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
+    if (i != feature_index) {
+      dimensions_without_feature.push_back(i);
+    }
+  }
+
+  auto scale_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
+
+  auto offset_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
+
+  auto mean_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
+
+  auto var_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
+
+  // Var[X] + epsilon.
+  auto var_add_epsilon =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
+
+  auto neg_half = computation_->AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
+
+  // 1 / Sqrt[Var[X] + epsilon].
+  auto rsqrt_var_add_epsilon =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
+
+  // X - E[X].
+  auto operand_minus_mean =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
+
+  // (X - E[X]) / Sqrt[Var[X] + epsilon].
+  auto normalized = computation_->AddInstruction(
+      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
+                                   operand_minus_mean, rsqrt_var_add_epsilon));
+
+  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
+  auto scaled_normalized =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
+
+  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
+  auto shifted_normalized = HloInstruction::CreateBinary(
+      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
+
+  TF_CHECK_OK(
+      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
+  return Status::OK();
+}
+
 Status BatchNormRewriterVisitor::HandleBatchNormGrad(
     HloInstruction* batch_norm) {
   // Use the following formulas to calculate gradients:
@@ -457,7 +542,8 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
   }
   for (auto& comp : computations) {
     if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
-                                      rewrite_grad_op_, use_fusion_)) {
+                                      rewrite_inference_op_, rewrite_grad_op_,
+                                      use_fusion_)) {
       changed = true;
     }
   }
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
index d3ffb31032..f601741d96 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
@@ -30,8 +30,10 @@ class BatchNormRewriter : public HloPassInterface {
  public:
   // When use_fusion is set, a multi-output fusion node is created.
   BatchNormRewriter(bool rewrite_training_op = false,
+                    bool rewrite_inference_op = false,
                     bool rewrite_grad_op = false, bool use_fusion = true)
       : rewrite_training_op_(rewrite_training_op),
+        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op),
         use_fusion_(use_fusion) {}
   ~BatchNormRewriter() = default;
@@ -43,6 +45,7 @@ class BatchNormRewriter : public HloPassInterface {
 
  private:
   bool rewrite_training_op_;
+  bool rewrite_inference_op_;
   bool rewrite_grad_op_;
   bool use_fusion_;
 };
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc
index cc8dffcda5..07775623e7 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc
@@ -64,6 +64,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) {
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
   BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
+                             /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
   ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
   root = computation->root_instruction();
@@ -105,6 +106,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) {
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
   BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
+                             /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
   ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
   root = computation->root_instruction();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index eca9b0f4be..8a37c8108e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -260,6 +260,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
     pass.AddPass<BatchNormRewriter>(
         /*rewrite_training_op=*/true,
+        /*rewrite_inference_op=*/true,
         /*rewrite_grad_op=*/true,
         /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index e450b31ff1..4baa56658f 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -228,6 +228,9 @@ class DfsHloVisitor {
   virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0;
 
+  virtual Status HandleBatchNormInference(
+      HloInstruction* batchNormInference) = 0;
+
   virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0;
 
   // Invoked to inform the visitor that the traversal has completed, and that
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index c447165cec..10f8ae9b04 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -54,6 +54,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
     return DefaultAction(hlo);
   }
 
+  Status HandleBatchNormInference(HloInstruction* hlo) override {
+    return DefaultAction(hlo);
+  }
+
   Status HandleBatchNormGrad(HloInstruction* hlo) override {
     return DefaultAction(hlo);
   }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 2a7486af88..cd913a4b5d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -135,6 +135,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
     // instead.
     pass.AddPass<BatchNormRewriter>(
         /*rewrite_training_op=*/true,
+        /*rewrite_inference_op=*/true,
         /*rewrite_grad_op=*/true,
         /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index d113ca2a76..9dbde0ec24 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -374,6 +374,12 @@ Status HloCostAnalysis::HandleBatchNormTraining(
   return Status::OK();
 }
 
+Status HloCostAnalysis::HandleBatchNormInference(
+    HloInstruction* batchNormInference) {
+  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
+  return Status::OK();
+}
+
 Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) {
   // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index ec48c8a0fd..6d8fdfa64b 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -89,6 +89,7 @@ class HloCostAnalysis : public DfsHloVisitor {
                       tensorflow::gtl::ArraySlice<int64> dimensions,
                       HloComputation* function_handle) override;
   Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override;
+  Status HandleBatchNormInference(HloInstruction* batchNormInference) override;
   Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override;
   Status HandleFusion(HloInstruction* fusion) override;
   Status HandleCall(HloInstruction* call) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index d1c3196366..38b1291d44 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -742,6 +742,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
     case HloOpcode::kParameter:
       return kOrange;
     case HloOpcode::kBatchNormTraining:
+    case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormGrad:
     case HloOpcode::kReduce:
     case HloOpcode::kSelectAndScatter:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 825f3f8f60..fb9dbd6421 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -407,6 +407,23 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
 }
 
 /* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateBatchNormInference(
+    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+    float epsilon, int64 feature_index) {
+  auto instruction =
+      WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
+  instruction->AppendOperand(operand);
+  instruction->AppendOperand(scale);
+  instruction->AppendOperand(offset);
+  instruction->AppendOperand(mean);
+  instruction->AppendOperand(variance);
+  instruction->epsilon_ = epsilon;
+  instruction->feature_index_ = feature_index;
+  return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
 HloInstruction::CreateBatchNormGrad(const Shape& shape,
                                     HloInstruction* operand,
                                     HloInstruction* scale, HloInstruction* mean,
                                     HloInstruction* variance,
@@ -1065,6 +1082,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
       return CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
                                      new_operands[2], epsilon(),
                                      feature_index());
+
+    case HloOpcode::kBatchNormInference:
+      CHECK_EQ(new_operands.size(), 5);
+      return CreateBatchNormInference(
+          shape, new_operands[0], new_operands[1], new_operands[2],
+          new_operands[3], new_operands[4], epsilon(), feature_index());
     case HloOpcode::kInfeed:
       CHECK_EQ(new_operands.size(), 0);
       return CreateInfeed(shape, infeed_config());
@@ -1355,6 +1378,7 @@ bool HloInstruction::IdenticalSlowPath(
              ShapeUtil::Compatible(shape(), other.shape());
 
     case HloOpcode::kBatchNormTraining:
+    case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormGrad:
       return feature_index() == other.feature_index() &&
              epsilon() == other.epsilon();
@@ -1952,6 +1976,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
       return visitor->HandleAbs(this, operands_[0]);
     case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
+    case HloOpcode::kBatchNormInference:
+      return visitor->HandleBatchNormInference(this);
     case HloOpcode::kBatchNormGrad:
       return visitor->HandleBatchNormGrad(this);
     case HloOpcode::kSign:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index f2005380d8..d246720b3c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -224,6 +224,12 @@ class HloInstruction {
       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
       HloInstruction* offset, float epsilon, int64 feature_index);
 
+  // Creates a batch-norm-inference instruction.
+  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
+      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+      float epsilon, int64 feature_index);
+
   // Creates a batch-norm-grad instruction.
   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 3888f757ad..314512d0a8 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
       return "add";
     case HloOpcode::kBatchNormTraining:
       return "batch-norm-training";
+    case HloOpcode::kBatchNormInference:
+      return "batch-norm-inference";
     case HloOpcode::kBatchNormGrad:
       return "batch-norm-grad";
     case HloOpcode::kBitcast:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 8a6376b2d1..c4d5efad90 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -31,6 +31,7 @@ enum class HloOpcode {
   kAbs,
   kAdd,
   kBatchNormTraining,
+  kBatchNormInference,
   kBatchNormGrad,
   kBitcast,
   kBroadcast,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 4333db17e7..edfcb0922d 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -78,6 +78,7 @@ namespace xla {
 
     // Expensive instructions.
     case HloOpcode::kBatchNormTraining:
+    case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormGrad:
     case HloOpcode::kCall:
     case HloOpcode::kConvolution:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ad2d5235f8..d63d33ecb0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1211,6 +1211,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
       handle_status = computation->AddBatchNormTrainingInstruction(
          arg->batch_norm_training_request());
       break;
+    case OpRequest::kBatchNormInferenceRequest:
+      handle_status = computation->AddBatchNormInferenceInstruction(
+          arg->batch_norm_inference_request());
+      break;
     case OpRequest::kBatchNormGradRequest:
       handle_status = computation->AddBatchNormGradInstruction(
           arg->batch_norm_grad_request());
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2c2b0cca5f..8eeb1cd5d2 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -885,6 +885,150 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
                                        output_shape_for_mean_and_var});
 }
 
+/* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
+    const Shape& operand_shape, const Shape& offset_shape,
+    const Shape& scale_shape, const Shape& mean_shape,
+    const Shape& variance_shape, int64 feature_index) {
+  TF_RETURN_IF_ERROR(
+      ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+      offset_shape, "offset input of batch norm inference"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+      scale_shape, "scale input of batch norm inference"));
+
+  TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) ==
+               tensorflow::Status::OK());
+
+  if (feature_index >= ShapeUtil::Rank(operand_shape)) {
+    return InvalidArgument(
+        "Expected feature_index of batch-norm-inference to be "
+        "smaller than the rank of operand_shape; "
+        "got feature_index %lld, and rank %lld",
+        feature_index, ShapeUtil::Rank(operand_shape));
+  }
+
+  if (feature_index < 0) {
+    return InvalidArgument(
+        "Expected feature_index of batch-norm-inference to "
+        "be a non-negative number, got %lld",
+        feature_index);
+  }
+
+  if (ShapeUtil::Rank(operand_shape) < 1) {
+    return InvalidArgument(
+        "Expected the rank of operand to "
+        "batch-norm-inference to be at least 1; got %lld",
+        ShapeUtil::Rank(operand_shape));
+  }
+
+  if (ShapeUtil::Rank(offset_shape) != 1) {
+    return InvalidArgument(
+        "Offset input of batch-norm-inference must have"
+        " rank 1, but has rank %lld.",
+        ShapeUtil::Rank(offset_shape));
+  }
+
+  if (ShapeUtil::Rank(scale_shape) != 1) {
+    return InvalidArgument(
+        "Scale input of batch-norm-inference must have"
+        " rank 1, but has rank %lld.",
+        ShapeUtil::Rank(scale_shape));
+  }
+
+  if (!ShapeUtil::ElementIsFloating(operand_shape)) {
+    return InvalidArgument(
+        "The operand to batch-norm-inference must have a floating point "
+        "element type, but the shape is %s",
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of offset factor is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(offset_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of scale factor is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(scale_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of mean is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(mean_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of variance is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(mean_shape.element_type()).c_str(),
+        PrimitiveType_Name(variance_shape.element_type()).c_str());
+  }
+
+  const int64 feature_count = operand_shape.dimensions(feature_index);
+  Shape output_shape_for_mean_and_var =
+      ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
+
+  if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of offset factor should be the same as feature count,"
+        "but the size of offset factor is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(offset_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of scale factor should be the same as feature count,"
+        "but the size of scale factor is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(scale_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of mean should be the same as feature count,"
+        "but the size of mean is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(mean_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of variance should be the same as feature count,"
+        "but the size of variance is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(variance_shape, 0), feature_count);
+  }
+
+  return operand_shape;
+}
+
 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
     const Shape& operand_shape, const Shape& scale_shape,
     const Shape& mean_shape, const Shape& var_shape,
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index f3f0176a43..5d55df92a9 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -71,6 +71,13 @@ class ShapeInference {
                                                      const Shape& scale_shape,
                                                      int64 feature_index);
 
+  // Infers the shape produced by InferBatchNormInference with the given
+  // operands.
+  static StatusOr<Shape> InferBatchNormInferenceShape(
+      const Shape& operand_shape, const Shape& offset_shape,
+      const Shape& scale_shape, const Shape& mean_shape,
+      const Shape& variance_shape, int64 feature_index);
+
   // Infers the shape produced by InferBatchNormGrad with the given operands.
   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
                                                  const Shape& scale_shape,
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 3b280c9727..cfa5c98f59 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -507,6 +507,53 @@ UserComputation::AddBatchNormTrainingInstruction(
   return handle;
 }
 
+StatusOr<ComputationDataHandle>
+UserComputation::AddBatchNormInferenceInstruction(
+    const BatchNormInferenceRequest& batch_norm_inference_request) {
+  tensorflow::mutex_lock lock(mutex_);
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+                      LookUpRequest(batch_norm_inference_request.operand()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
+                      LookUpRequest(batch_norm_inference_request.scale()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
+                      LookUpRequest(batch_norm_inference_request.offset()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
+                      LookUpRequest(batch_norm_inference_request.mean()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
+                      LookUpRequest(batch_norm_inference_request.variance()));
+
+  ComputationDataHandle handle = CreateComputationDataHandle();
+
+  OperationRequest& request =
+      (*session_computation_.mutable_requests())[handle.handle()];
+
+  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
+                      ShapeInference::InferBatchNormInferenceShape(
+                          operand->output_shape(), scale->output_shape(),
+                          offset->output_shape(), mean->output_shape(),
+                          variance->output_shape(),
+                          batch_norm_inference_request.feature_index()));
+
+  *request.mutable_output_shape() = inferred_shape;
+
+  *request.mutable_output_handle() = handle;
+
+  *request.mutable_request()->mutable_batch_norm_inference_request() =
+      batch_norm_inference_request;
+
+  VLOG(1) << "AddBatchNormInferenceInstruction ("
+          << GetVersionedHandleInternal() << "), data handle "
+          << handle.handle() << ": "
+          << batch_norm_inference_request.ShortDebugString();
+
+  return handle;
+}
+
 StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
     const BatchNormGradRequest& batch_norm_grad_request) {
   tensorflow::mutex_lock lock(mutex_);
@@ -1678,6 +1725,25 @@ void ConstantVisitor(const SessionComputation& session_computation,
       break;
     }
 
+    case OpRequest::kBatchNormInferenceRequest: {
+      const BatchNormInferenceRequest& batch_norm_inference_request =
+          request.request().batch_norm_inference_request();
+      ConstantVisitor(session_computation,
+                      batch_norm_inference_request.operand(), visited,
+                      is_constant);
+      ConstantVisitor(session_computation, batch_norm_inference_request.scale(),
+                      visited, is_constant);
+      ConstantVisitor(session_computation,
+                      batch_norm_inference_request.offset(), visited,
+                      is_constant);
+      ConstantVisitor(session_computation, batch_norm_inference_request.mean(),
+                      visited, is_constant);
+      ConstantVisitor(session_computation,
+                      batch_norm_inference_request.variance(), visited,
+                      is_constant);
+      break;
+    }
+
     case OpRequest::kBatchNormGradRequest: {
       const BatchNormGradRequest& batch_norm_grad_request =
          request.request().batch_norm_grad_request();
@@ -2119,6 +2185,18 @@ static void ForEachOperand(
       break;
     }
 
+    case OpRequest::kBatchNormInferenceRequest: {
+      const BatchNormInferenceRequest& batch_norm_inference_request =
+          request.request().batch_norm_inference_request();
+
+      apply(batch_norm_inference_request.operand());
+      apply(batch_norm_inference_request.scale());
+      apply(batch_norm_inference_request.offset());
+      apply(batch_norm_inference_request.mean());
+      apply(batch_norm_inference_request.variance());
+      break;
+    }
+
     case OpRequest::kBatchNormGradRequest: {
       const BatchNormGradRequest& batch_norm_grad_request =
           request.request().batch_norm_grad_request();
@@ -2647,6 +2725,28 @@ void ComputationLowerer::Visit(
       break;
     }
 
+    case OpRequest::kBatchNormInferenceRequest: {
+      const BatchNormInferenceRequest& batch_norm_inference_request =
+          request.request().batch_norm_inference_request();
+      HloInstruction* operand =
+          lookup_instruction(batch_norm_inference_request.operand());
+      HloInstruction* scale =
+          lookup_instruction(batch_norm_inference_request.scale());
+      HloInstruction* offset =
+          lookup_instruction(batch_norm_inference_request.offset());
+      HloInstruction* mean =
+          lookup_instruction(batch_norm_inference_request.mean());
+      HloInstruction* variance =
+          lookup_instruction(batch_norm_inference_request.variance());
+
+      hlo_instruction =
+          add_instruction(HloInstruction::CreateBatchNormInference(
+              request.output_shape(), operand, scale, offset, mean, variance,
+              batch_norm_inference_request.epsilon(),
+              batch_norm_inference_request.feature_index()));
+      break;
+    }
+
     case OpRequest::kBatchNormGradRequest: {
       const BatchNormGradRequest& batch_norm_grad_request =
          request.request().batch_norm_grad_request();
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 36b1d34e05..b779b1f76c 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -89,6 +89,10 @@ class UserComputation {
   StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
       const BatchNormTrainingRequest& batch_norm_training_request);
 
+  // Enqueues a batch norm inference instruction onto this user computation.
+  StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
+      const BatchNormInferenceRequest& batch_norm_inference_request);
+
   // Enqueues a batch norm grad instruction onto this user computation.
   StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
       const BatchNormGradRequest& batch_norm_grad_request);
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 34b3abb8c7..028d1251b4 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -306,6 +306,109 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) {
                              ErrorSpec(0.01, 1));
 }
 
+XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
+  float epsilon = 0.001;
+  ComputationBuilder builder(client_, TestName());
+  const std::vector<int64>& bounds = GetParam().bounds;
+  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
+  input_array.FillRandom(GetParam().random_value_var,
+                         GetParam().random_value_mean);
+
+  const int64 feature_index = GetParam().feature_index;
+  const int64 num_elements_per_feature =
+      Product(bounds) / bounds[feature_index];
+  const int64 feature_bound = bounds[feature_index];
+  std::vector<float> offset(feature_bound, 1);
+  std::vector<float> scale(feature_bound, 2);
+
+  auto input_squared =
+      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
+  std::vector<int64> reduce_dims;
+  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
+    if (i != feature_index) {
+      reduce_dims.push_back(i);
+    }
+  }
+
+  auto sum =
+      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
+                                  [](float a, float b) { return a + b; });
+
+  auto sum_squared =
+      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
+                                  [](float a, float b) { return a + b; });
+
+  std::vector<float> mean(feature_bound);
+
+  for (int64 i = 0; i < feature_bound; ++i) {
+    mean[i] = sum[i] / num_elements_per_feature;
+  }
+
+  std::vector<float> mean_square(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    mean_square[i] = mean[i] * mean[i];
+  }
+
+  std::vector<float> square_mean(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    square_mean[i] = sum_squared[i] / num_elements_per_feature;
+  }
+
+  std::vector<float> var(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    var[i] = square_mean[i] - mean_square[i];
+  }
+
+  Array4D<float> mean4D =
+      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto offset4D =
+      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
+
+  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
+                                                scale4D, offset4D, epsilon);
+
+  auto offset_literal = Literal::CreateR1<float>(offset);
+  auto scale_literal = Literal::CreateR1<float>(scale);
+  auto mean_literal = Literal::CreateR1<float>(mean);
+  auto var_literal = Literal::CreateR1<float>(var);
+  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+
+  auto input_activations =
+      builder.Parameter(0, input_literal->shape(), "input");
+  auto scale_activations =
+      builder.Parameter(1, scale_literal->shape(), "offset");
+  auto offset_activations =
+      builder.Parameter(2, offset_literal->shape(), "scale");
+  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
+  auto variance_activations =
+      builder.Parameter(4, var_literal->shape(), "variance");
+
+  Array4D<float> expected = normalized;
+
+  std::unique_ptr<GlobalData> input_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> scale_data =
+      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> offset_data =
+      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> mean_data =
+      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> variance_data =
+      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+
+  builder.BatchNormInference(input_activations, scale_activations,
+                             offset_activations, mean_activations,
+                             variance_activations, epsilon, feature_index);
+
+  ComputeAndCompareR4<float>(
+      &builder, expected,
+      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
+       variance_data.get()},
+      ErrorSpec(0.01, 1));
+}
+
 XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 38e6675ab7..185ca7e681 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -491,6 +491,16 @@ message BatchNormTrainingRequest {
   int64 feature_index = 5;
 }
 
+message BatchNormInferenceRequest {
+  ComputationDataHandle operand = 1;
+  ComputationDataHandle scale = 2;
+  ComputationDataHandle offset = 3;
+  ComputationDataHandle mean = 4;
+  ComputationDataHandle variance = 5;
+  float epsilon = 6;
+  int64 feature_index = 7;
+}
+
 message BatchNormGradRequest {
   ComputationDataHandle operand = 1;
   ComputationDataHandle scale = 2;
@@ -813,7 +823,8 @@ message OpRequest {
     OutfeedRequest outfeed_request = 32;
     BatchNormTrainingRequest batch_norm_training_request = 35;
     BatchNormGradRequest batch_norm_grad_request = 37;
-    // Next: 38
+    BatchNormInferenceRequest batch_norm_inference_request = 38;
+    // Next: 39
   }
 }
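For reference, the validation in the new InferBatchNormInferenceShape above boils down to a handful of rules: the operand must have floating-point element type and rank at least 1, feature_index must index a dimension of the operand, and scale, offset, mean, and variance must each be rank 1, sized to the feature dimension, with the operand's element type; the inferred result shape is simply the operand's shape. A condensed sketch of those rules (Python, with hypothetical (element_type, dims) tuples standing in for xla::Shape protos):

def infer_batch_norm_inference_shape(operand, scale, offset, mean, variance,
                                     feature_index):
  """Condensed restatement of the checks in shape_inference.cc."""
  element_type, dims = operand
  if element_type not in ("F16", "F32", "F64"):  # must be floating point
    raise ValueError("operand must have a floating-point element type")
  if len(dims) < 1:
    raise ValueError("operand must have rank >= 1")
  if not 0 <= feature_index < len(dims):
    raise ValueError("feature_index must be in [0, rank)")
  feature_count = dims[feature_index]
  for name, (ty, s) in (("scale", scale), ("offset", offset),
                        ("mean", mean), ("variance", variance)):
    if ty != element_type:
      raise ValueError(name + " must match the operand's element type")
    if len(s) != 1 or s[0] != feature_count:
      raise ValueError(name + " must be rank 1 with size %d" % feature_count)
  # batch-norm-inference produces an array of the operand's shape.
  return operand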