author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-08-17 18:04:58 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-08-17 18:08:34 -0700
commit    7359fec792e4efec1670a12332bb524a5608b215 (patch)
tree      cce4b3a2360071424ab2d794bc86de4598fe51fb /tensorflow/compiler
parent    f0da8bf56ba1b625d53b760683bc44f67e204199 (diff)
Implement BatchNorm inference by expanding it into smaller ops.

1. Add batch norm inference support in batchnorm_rewriter
2. Connect XLA's batch norm inference to TF's FusedBatchNorm

RELNOTES: n/a
PiperOrigin-RevId: 165655351
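For reference, batch norm inference normalizes each feature with precomputed statistics instead of batch statistics. A minimal NumPy sketch of the computation this change expands into smaller ops (illustrative only; the function name and signature are not part of the XLA API):

    import numpy as np

    def batch_norm_inference(x, scale, offset, mean, var, epsilon, feature_index):
        # Broadcast each per-feature vector along every non-feature dimension.
        bshape = [1] * x.ndim
        bshape[feature_index] = x.shape[feature_index]
        scale, offset, mean, var = (a.reshape(bshape)
                                    for a in (scale, offset, mean, var))
        return (x - mean) / np.sqrt(var + epsilon) * scale + offset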
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/tests/fused_batchnorm_test.py               33
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc             44
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc           26
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_rewriter.cc           98
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_rewriter.h             3
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc       2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc              1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h                3
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h   4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc              1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc             6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h              1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc              1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc              26
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h                6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.cc                    2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h                     1
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc            1
-rw-r--r--  tensorflow/compiler/xla/service/service.cc                       4
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc             144
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h                7
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc            100
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.h               4
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc      103
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto                          13
25 files changed, 606 insertions, 28 deletions
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index f8e9fc9268..936fcf8b6b 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -63,6 +63,39 @@ class FusedBatchNormTest(XLATestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
+ def testInference(self):
+ x_shape = [2, 2, 6, 2]
+ scale_shape = [2]
+ x_val = np.random.random_sample(x_shape).astype(np.float32)
+ scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+ offset_val = np.random.random_sample(scale_shape).astype(np.float32)
+ data_format = "NHWC"
+ with self.test_session() as sess, self.test_scope():
+      # Use placeholders instead of constants to avoid constant folding.
+ t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
+ scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
+ offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
+ epsilon = 0.001
+ y_ref, mean_ref, var_ref = self._reference_training(
+ x_val, scale_val, offset_val, epsilon, data_format)
+ y, mean, variance = nn.fused_batch_norm(
+ t_val,
+ scale,
+ offset,
+ mean=mean_ref,
+ variance=var_ref,
+ epsilon=epsilon,
+ data_format=data_format,
+ is_training=False)
+
+ y_val, _, _ = sess.run(
+ [y, mean,
+ variance], {t_val: x_val,
+ scale: scale_val,
+ offset: offset_val})
+ self.assertAllClose(y_val, y_ref, atol=1e-3)
+
def _testLearning(self, use_gradient_checker):
x_shape = [2, 2, 6, 2]
scale_shape = [2]
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 3f23e459b9..9d2703bf95 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -39,28 +39,36 @@ class FusedBatchNormOp : public XlaOpKernel {
errors::InvalidArgument("Not supported format"));
feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format);
}
- // TODO(b/62843645): Implement BatchNormInference.
- OP_REQUIRES(
- ctx, is_training_,
- errors::InvalidArgument("Fused batch normalization for inference is "
- "not supported yet on XLA backend."));
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
- ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, feature_index_);
-
- // In training mode, outputs the normalized value as well as the calculated
- // mean and variance.
- for (int i = 0; i < 3; i++) {
- ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+ if (is_training_) {
+ xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
+ ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_,
+ feature_index_);
+
+ // In training mode, outputs the normalized value as well as the
+ // calculated mean and variance.
+ for (int i = 0; i < 3; i++) {
+ ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+ }
+      // Outputs 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
+ // space 1 & 2". They are used to pass the per-batch mean and
+ // variance to the gradient. Here we maintain the same behavior by setting
+ // them to the mean and variance calculated by BatchNormTraining.
+ ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
+ ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
+ } else {
+ xla::ComputationDataHandle output = ctx->builder()->BatchNormInference(
+ ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3),
+ ctx->Input(4), epsilon_, feature_index_);
+ ctx->SetOutput(0, output);
+      // In inference mode, pass the input mean and variance through to the outputs.
+ ctx->SetOutput(1, ctx->Input(3));
+ ctx->SetOutput(2, ctx->Input(4));
+ ctx->SetOutput(3, ctx->Input(3));
+ ctx->SetOutput(4, ctx->Input(4));
}
- // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
- // space 1 & 2". They are used to pass the per-batch mean and
- // variance to the gradient. Here we maintain the same behavior by setting
- // them to the mean and variance calculated by BatchNormTraining.
- ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
- ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
}
private:
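The inference branch above yields five outputs: the normalized activations, then the supplied mean and variance forwarded through the four statistics slots. A hedged NumPy sketch of that output contract, assuming NHWC layout and rank-1 statistics vectors (the helper name is illustrative, not a TensorFlow API):

    import numpy as np

    def fused_batch_norm_inference_outputs(x, scale, offset, mean, var,
                                           epsilon=0.001):
        # NHWC: rank-1 per-feature vectors broadcast along the last dimension.
        y = (x - mean) / np.sqrt(var + epsilon) * scale + offset
        # Outputs 1-4 pass the supplied statistics straight through, matching
        # the reserved-space convention of the training branch.
        return y, mean, var, mean, var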
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index e6ffc4f98d..30afaed732 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -1477,9 +1477,29 @@ ComputationDataHandle ComputationBuilder::BatchNormInference(
const ComputationDataHandle& operand, const ComputationDataHandle& scale,
const ComputationDataHandle& offset, const ComputationDataHandle& mean,
const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
- // TODO(b/62843645): Implement BatchNormInference.
- NoteError(Unimplemented("BatchNormInference is not implemented yet."));
- return ComputationDataHandle();
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+ BatchNormInferenceRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_scale() = scale;
+ *request.mutable_offset() = offset;
+ *request.mutable_mean() = mean;
+ *request.mutable_variance() = variance;
+ request.set_epsilon(epsilon);
+ request.set_feature_index(feature_index);
+
+ OpRequest op_request;
+ *op_request.mutable_batch_norm_inference_request() = request;
+ *op_request.mutable_computation() = computation_.handle();
+ AddOpMetadata(&op_request);
+
+ OpResponse response;
+
+ VLOG(2) << "making BatchNormInference request";
+
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::BatchNormGrad(
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
index 721d99301a..41d32d0c8b 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
@@ -56,11 +56,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
+ Status HandleBatchNormInference(HloInstruction* batch_norm) override;
+
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
// Runs the visitor on a computation.
static bool Run(HloComputation* computation, bool rewrite_training_op,
- bool rewrite_grad_op, bool use_fusion);
+ bool rewrite_inference_op, bool rewrite_grad_op,
+ bool use_fusion);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
@@ -70,9 +73,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
private:
explicit BatchNormRewriterVisitor(HloComputation* computation,
bool rewrite_training_op,
+ bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion)
: computation_(computation),
rewrite_training_op_(rewrite_training_op),
+ rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
@@ -94,6 +99,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
HloComputation* computation_;
bool rewrite_training_op_;
+ bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
@@ -126,11 +132,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
bool BatchNormRewriterVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
+ bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion) {
- BatchNormRewriterVisitor visitor(computation,
- /*rewrite_training_op=*/rewrite_training_op,
- /*rewrite_grad_op=*/rewrite_grad_op,
- /*use_fusion=*/use_fusion);
+ BatchNormRewriterVisitor visitor(
+ computation,
+ /*rewrite_training_op=*/rewrite_training_op,
+ /*rewrite_inference_op=*/rewrite_inference_op,
+ /*rewrite_grad_op=*/rewrite_grad_op,
+ /*use_fusion=*/use_fusion);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@@ -268,6 +277,82 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
return Status::OK();
}
+Status BatchNormRewriterVisitor::HandleBatchNormInference(
+ HloInstruction* batch_norm) {
+ if (!rewrite_inference_op_) {
+ return Status::OK();
+ }
+ // Expand batch norm inference into smaller HLO ops.
+ HloInstruction* operand = batch_norm->mutable_operand(0);
+ const Shape operand_shape = operand->shape();
+ int64 feature_index = batch_norm->feature_index();
+
+ HloInstruction* scale = batch_norm->mutable_operand(1);
+ HloInstruction* offset = batch_norm->mutable_operand(2);
+ HloInstruction* mean = batch_norm->mutable_operand(3);
+ HloInstruction* var = batch_norm->mutable_operand(4);
+ const Shape feature_shape = scale->shape();
+
+ auto epsilon = computation_->AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+
+ std::vector<int64> dimensions_without_feature;
+
+ for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
+ if (i != feature_index) {
+ dimensions_without_feature.push_back(i);
+ }
+ }
+
+ auto scale_broadcasted = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
+
+ auto offset_broadcasted = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
+
+ auto mean_broadcasted = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
+
+ auto var_broadcasted = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
+
+ // Var[X] + epsilon.
+ auto var_add_epsilon =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
+
+ auto neg_half = computation_->AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
+
+ // 1 / Sqrt[Var[X] + epsilon].
+ auto rsqrt_var_add_epsilon =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
+
+ // X - E[X].
+ auto operand_minus_mean =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
+
+ // (X - E[X]) / Sqrt[Var[X] + epsilon].
+ auto normalized = computation_->AddInstruction(
+ HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
+ operand_minus_mean, rsqrt_var_add_epsilon));
+
+ // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
+ auto scaled_normalized =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
+
+ // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
+ auto shifted_normalized = HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
+
+ TF_CHECK_OK(
+ ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
+ return Status::OK();
+}
+
Status BatchNormRewriterVisitor::HandleBatchNormGrad(
HloInstruction* batch_norm) {
// Use the following formulas to calculate gradients:
@@ -457,7 +542,8 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
}
for (auto& comp : computations) {
if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
- rewrite_grad_op_, use_fusion_)) {
+ rewrite_inference_op_, rewrite_grad_op_,
+ use_fusion_)) {
changed = true;
}
}
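Mapped back to NumPy, the HLO sequence emitted by HandleBatchNormInference above is equivalent to the following sketch (assumptions: x has arbitrary rank, the remaining operands are rank-1 over the feature dimension; the function is illustrative only):

    import numpy as np

    def expanded_batch_norm_inference(x, scale, offset, mean, var, epsilon,
                                      feature_index):
        # kBroadcast: stretch each per-feature vector to the operand shape.
        bshape = [1] * x.ndim
        bshape[feature_index] = x.shape[feature_index]
        scale_b, offset_b, mean_b, var_b = (a.reshape(bshape)
                                            for a in (scale, offset, mean, var))
        var_add_epsilon = var_b + epsilon        # kAdd
        rsqrt = var_add_epsilon ** -0.5          # kPower with exponent -0.5
        normalized = (x - mean_b) * rsqrt        # kSubtract, then kMultiply
        return normalized * scale_b + offset_b   # kMultiply, then kAdd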
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
index d3ffb31032..f601741d96 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
@@ -30,8 +30,10 @@ class BatchNormRewriter : public HloPassInterface {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormRewriter(bool rewrite_training_op = false,
+ bool rewrite_inference_op = false,
bool rewrite_grad_op = false, bool use_fusion = true)
: rewrite_training_op_(rewrite_training_op),
+ rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
~BatchNormRewriter() = default;
@@ -43,6 +45,7 @@ class BatchNormRewriter : public HloPassInterface {
private:
bool rewrite_training_op_;
+ bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
};
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc
index cc8dffcda5..07775623e7 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc
@@ -64,6 +64,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) {
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
+ /*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
@@ -105,6 +106,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) {
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
+ /*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index eca9b0f4be..8a37c8108e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -260,6 +260,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
pass.AddPass<BatchNormRewriter>(
/*rewrite_training_op=*/true,
+ /*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
pass.AddPass<AlgebraicSimplifier>(
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index e450b31ff1..4baa56658f 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -228,6 +228,9 @@ class DfsHloVisitor {
virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0;
+ virtual Status HandleBatchNormInference(
+ HloInstruction* batchNormInference) = 0;
+
virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0;
// Invoked to inform the visitor that the traversal has completed, and that
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index c447165cec..10f8ae9b04 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -54,6 +54,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
return DefaultAction(hlo);
}
+ Status HandleBatchNormInference(HloInstruction* hlo) override {
+ return DefaultAction(hlo);
+ }
+
Status HandleBatchNormGrad(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 2a7486af88..cd913a4b5d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -135,6 +135,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
// instead.
pass.AddPass<BatchNormRewriter>(
/*rewrite_training_op=*/true,
+ /*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
pass.AddPass<AlgebraicSimplifier>(
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index d113ca2a76..9dbde0ec24 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -374,6 +374,12 @@ Status HloCostAnalysis::HandleBatchNormTraining(
return Status::OK();
}
+Status HloCostAnalysis::HandleBatchNormInference(
+ HloInstruction* batchNormInference) {
+ // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) {
// TODO(b/62294698): Implement cost analysis for batch-norm-grad.
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index ec48c8a0fd..6d8fdfa64b 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -89,6 +89,7 @@ class HloCostAnalysis : public DfsHloVisitor {
tensorflow::gtl::ArraySlice<int64> dimensions,
HloComputation* function_handle) override;
Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override;
+ Status HandleBatchNormInference(HloInstruction* batchNormInference) override;
Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index d1c3196366..38b1291d44 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -742,6 +742,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kParameter:
return kOrange;
case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kReduce:
case HloOpcode::kSelectAndScatter:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 825f3f8f60..fb9dbd6421 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -407,6 +407,23 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateBatchNormInference(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+ float epsilon, int64 feature_index) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(scale);
+ instruction->AppendOperand(offset);
+ instruction->AppendOperand(mean);
+ instruction->AppendOperand(variance);
+ instruction->epsilon_ = epsilon;
+ instruction->feature_index_ = feature_index;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* scale, HloInstruction* mean,
HloInstruction* variance,
@@ -1065,6 +1082,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
new_operands[2], epsilon(),
feature_index());
+
+ case HloOpcode::kBatchNormInference:
+ CHECK_EQ(new_operands.size(), 5);
+ return CreateBatchNormInference(
+ shape, new_operands[0], new_operands[1], new_operands[2],
+ new_operands[3], new_operands[4], epsilon(), feature_index());
case HloOpcode::kInfeed:
CHECK_EQ(new_operands.size(), 0);
return CreateInfeed(shape, infeed_config());
@@ -1355,6 +1378,7 @@ bool HloInstruction::IdenticalSlowPath(
ShapeUtil::Compatible(shape(), other.shape());
case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
return feature_index() == other.feature_index() &&
epsilon() == other.epsilon();
@@ -1952,6 +1976,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleAbs(this, operands_[0]);
case HloOpcode::kBatchNormTraining:
return visitor->HandleBatchNormTraining(this);
+ case HloOpcode::kBatchNormInference:
+ return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this);
case HloOpcode::kSign:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index f2005380d8..d246720b3c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -224,6 +224,12 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, float epsilon, int64 feature_index);
+ // Creates a batch-norm-inference instruction.
+ static std::unique_ptr<HloInstruction> CreateBatchNormInference(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+ float epsilon, int64 feature_index);
+
// Creates a batch-norm-grad instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 3888f757ad..314512d0a8 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "add";
case HloOpcode::kBatchNormTraining:
return "batch-norm-training";
+ case HloOpcode::kBatchNormInference:
+ return "batch-norm-inference";
case HloOpcode::kBatchNormGrad:
return "batch-norm-grad";
case HloOpcode::kBitcast:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 8a6376b2d1..c4d5efad90 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -31,6 +31,7 @@ enum class HloOpcode {
kAbs,
kAdd,
kBatchNormTraining,
+ kBatchNormInference,
kBatchNormGrad,
kBitcast,
kBroadcast,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 4333db17e7..edfcb0922d 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -78,6 +78,7 @@ namespace xla {
// Expensive instructions.
case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kCall:
case HloOpcode::kConvolution:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ad2d5235f8..d63d33ecb0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1211,6 +1211,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddBatchNormTrainingInstruction(
arg->batch_norm_training_request());
break;
+ case OpRequest::kBatchNormInferenceRequest:
+ handle_status = computation->AddBatchNormInferenceInstruction(
+ arg->batch_norm_inference_request());
+ break;
case OpRequest::kBatchNormGradRequest:
handle_status = computation->AddBatchNormGradInstruction(
arg->batch_norm_grad_request());
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2c2b0cca5f..8eeb1cd5d2 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -885,6 +885,150 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
output_shape_for_mean_and_var});
}
+/* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
+ const Shape& operand_shape, const Shape& offset_shape,
+ const Shape& scale_shape, const Shape& mean_shape,
+ const Shape& variance_shape, int64 feature_index) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+ offset_shape, "offset input of batch norm inference"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+ scale_shape, "scale input of batch norm inference"));
+
+ TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) ==
+ tensorflow::Status::OK());
+ TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) ==
+ tensorflow::Status::OK());
+ TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) ==
+ tensorflow::Status::OK());
+ TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) ==
+ tensorflow::Status::OK());
+ TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) ==
+ tensorflow::Status::OK());
+
+ if (feature_index >= ShapeUtil::Rank(operand_shape)) {
+ return InvalidArgument(
+ "Expected feature_index of batch-norm-inference to be "
+ "smaller than the rank of operand_shape; "
+ "got feature_index %lld, and rank %lld",
+ feature_index, ShapeUtil::Rank(operand_shape));
+ }
+
+ if (feature_index < 0) {
+ return InvalidArgument(
+ "Expected feature_index of batch-norm-inference to "
+ "be a non-negative number, got %lld",
+ feature_index);
+ }
+
+ if (ShapeUtil::Rank(operand_shape) < 1) {
+ return InvalidArgument(
+ "Expected the rank of operand to "
+ "batch-norm-inference to be at least 1; got %lld",
+ ShapeUtil::Rank(operand_shape));
+ }
+
+ if (ShapeUtil::Rank(offset_shape) != 1) {
+ return InvalidArgument(
+ "Offset input of batch-norm-inference must have"
+ " rank 1, but has rank %lld.",
+ ShapeUtil::Rank(offset_shape));
+ }
+
+ if (ShapeUtil::Rank(scale_shape) != 1) {
+ return InvalidArgument(
+ "Scale input of batch-norm-inference must have"
+ " rank 1, but has rank %lld.",
+ ShapeUtil::Rank(scale_shape));
+ }
+
+ if (!ShapeUtil::ElementIsFloating(operand_shape)) {
+ return InvalidArgument(
+ "The operand to batch-norm-inference must have a floating point "
+ "element type, but the shape is %s",
+ PrimitiveType_Name(operand_shape.element_type()).c_str());
+ }
+
+ if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+ return InvalidArgument(
+ "The inputs should have the same element type for "
+ "batch-norm-inference, "
+ "but the shape of offset factor is %s "
+ "and the shape of operand is %s",
+ PrimitiveType_Name(offset_shape.element_type()).c_str(),
+ PrimitiveType_Name(operand_shape.element_type()).c_str());
+ }
+
+ if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+ return InvalidArgument(
+ "The inputs should have the same element type for "
+ "batch-norm-inference, "
+ "but the shape of scale factor is %s "
+ "and the shape of operand is %s",
+ PrimitiveType_Name(scale_shape.element_type()).c_str(),
+ PrimitiveType_Name(operand_shape.element_type()).c_str());
+ }
+
+ if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+ return InvalidArgument(
+ "The inputs should have the same element type for "
+ "batch-norm-inference, "
+ "but the shape of mean is %s "
+ "and the shape of operand is %s",
+ PrimitiveType_Name(mean_shape.element_type()).c_str(),
+ PrimitiveType_Name(operand_shape.element_type()).c_str());
+ }
+
+ if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
+ return InvalidArgument(
+ "The inputs should have the same element type for "
+ "batch-norm-inference, "
+ "but the shape of variance is %s "
+ "and the shape of operand is %s",
+ PrimitiveType_Name(mean_shape.element_type()).c_str(),
+ PrimitiveType_Name(variance_shape.element_type()).c_str());
+ }
+
+ const int64 feature_count = operand_shape.dimensions(feature_index);
+ Shape output_shape_for_mean_and_var =
+ ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
+
+ if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
+ return InvalidArgument(
+ "The size of offset factor should be the same as feature count,"
+ "but the size of offset factor is %lld "
+ "and the feature count is %lld",
+ ShapeUtil::GetDimension(offset_shape, 0), feature_count);
+ }
+
+ if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
+ return InvalidArgument(
+ "The size of scale factor should be the same as feature count,"
+ "but the size of scale factor is %lld "
+ "and the feature count is %lld",
+ ShapeUtil::GetDimension(scale_shape, 0), feature_count);
+ }
+
+ if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
+ return InvalidArgument(
+ "The size of mean should be the same as feature count,"
+ "but the size of mean is %lld "
+ "and the feature count is %lld",
+ ShapeUtil::GetDimension(mean_shape, 0), feature_count);
+ }
+
+ if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
+ return InvalidArgument(
+ "The size of variance should be the same as feature count,"
+ "but the size of variance is %lld "
+ "and the feature count is %lld",
+ ShapeUtil::GetDimension(variance_shape, 0), feature_count);
+ }
+
+ return operand_shape;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,
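Stripped of error-message detail, InferBatchNormInferenceShape above enforces a short list of rules: feature_index must name a dimension of the operand, each statistics input must be a rank-1 vector whose length equals the feature dimension, element types must all be the operand's floating-point type, and the output shape is the operand shape itself. A hedged Python sketch of those rules (shapes modeled as plain tuples, so the element-type checks are elided):

    def infer_batch_norm_inference_shape(operand_shape, offset_shape,
                                         scale_shape, mean_shape,
                                         variance_shape, feature_index):
        rank = len(operand_shape)
        assert rank >= 1, "operand must have rank >= 1"
        assert 0 <= feature_index < rank, "feature_index out of range"
        feature_count = operand_shape[feature_index]
        for name, s in (("offset", offset_shape), ("scale", scale_shape),
                        ("mean", mean_shape), ("variance", variance_shape)):
            assert len(s) == 1, name + " must have rank 1"
            assert s[0] == feature_count, name + " must match feature count"
        return operand_shape  # the output shape equals the operand shape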
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index f3f0176a43..5d55df92a9 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -71,6 +71,13 @@ class ShapeInference {
const Shape& scale_shape,
int64 feature_index);
+ // Infers the shape produced by InferBatchNormInference with the given
+ // operands.
+ static StatusOr<Shape> InferBatchNormInferenceShape(
+ const Shape& operand_shape, const Shape& offset_shape,
+ const Shape& scale_shape, const Shape& mean_shape,
+ const Shape& variance_shape, int64 feature_index);
+
// Infers the shape produced by InferBatchNormGrad with the given operands.
static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
const Shape& scale_shape,
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 3b280c9727..cfa5c98f59 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -507,6 +507,53 @@ UserComputation::AddBatchNormTrainingInstruction(
return handle;
}
+StatusOr<ComputationDataHandle>
+UserComputation::AddBatchNormInferenceInstruction(
+ const BatchNormInferenceRequest& batch_norm_inference_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookUpRequest(batch_norm_inference_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
+ LookUpRequest(batch_norm_inference_request.scale()));
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
+ LookUpRequest(batch_norm_inference_request.offset()));
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
+ LookUpRequest(batch_norm_inference_request.mean()));
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
+ LookUpRequest(batch_norm_inference_request.variance()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+
+ TF_ASSIGN_OR_RETURN(Shape inferred_shape,
+ ShapeInference::InferBatchNormInferenceShape(
+ operand->output_shape(), scale->output_shape(),
+ offset->output_shape(), mean->output_shape(),
+ variance->output_shape(),
+ batch_norm_inference_request.feature_index()));
+
+ *request.mutable_output_shape() = inferred_shape;
+
+ *request.mutable_output_handle() = handle;
+
+ *request.mutable_request()->mutable_batch_norm_inference_request() =
+ batch_norm_inference_request;
+
+ VLOG(1) << "AddBatchNormInferenceInstruction ("
+ << GetVersionedHandleInternal() << "), data handle "
+ << handle.handle() << ": "
+ << batch_norm_inference_request.ShortDebugString();
+
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -1678,6 +1725,25 @@ void ConstantVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kBatchNormInferenceRequest: {
+ const BatchNormInferenceRequest& batch_norm_inference_request =
+ request.request().batch_norm_inference_request();
+ ConstantVisitor(session_computation,
+ batch_norm_inference_request.operand(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, batch_norm_inference_request.scale(),
+ visited, is_constant);
+ ConstantVisitor(session_computation,
+ batch_norm_inference_request.offset(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, batch_norm_inference_request.mean(),
+ visited, is_constant);
+ ConstantVisitor(session_computation,
+ batch_norm_inference_request.variance(), visited,
+ is_constant);
+ break;
+ }
+
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@@ -2119,6 +2185,18 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kBatchNormInferenceRequest: {
+ const BatchNormInferenceRequest& batch_norm_inference_request =
+ request.request().batch_norm_inference_request();
+
+ apply(batch_norm_inference_request.operand());
+ apply(batch_norm_inference_request.scale());
+ apply(batch_norm_inference_request.offset());
+ apply(batch_norm_inference_request.mean());
+ apply(batch_norm_inference_request.variance());
+ break;
+ }
+
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@@ -2647,6 +2725,28 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kBatchNormInferenceRequest: {
+ const BatchNormInferenceRequest& batch_norm_inference_request =
+ request.request().batch_norm_inference_request();
+ HloInstruction* operand =
+ lookup_instruction(batch_norm_inference_request.operand());
+ HloInstruction* scale =
+ lookup_instruction(batch_norm_inference_request.scale());
+ HloInstruction* offset =
+ lookup_instruction(batch_norm_inference_request.offset());
+ HloInstruction* mean =
+ lookup_instruction(batch_norm_inference_request.mean());
+ HloInstruction* variance =
+ lookup_instruction(batch_norm_inference_request.variance());
+
+ hlo_instruction =
+ add_instruction(HloInstruction::CreateBatchNormInference(
+ request.output_shape(), operand, scale, offset, mean, variance,
+ batch_norm_inference_request.epsilon(),
+ batch_norm_inference_request.feature_index()));
+ break;
+ }
+
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 36b1d34e05..b779b1f76c 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -89,6 +89,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
const BatchNormTrainingRequest& batch_norm_training_request);
+ // Enqueues a batch norm inference instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
+ const BatchNormInferenceRequest& batch_norm_inference_request);
+
// Enqueues a batch norm grad instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request);
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 34b3abb8c7..028d1251b4 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -306,6 +306,109 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) {
ErrorSpec(0.01, 1));
}
+XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
+ float epsilon = 0.001;
+ ComputationBuilder builder(client_, TestName());
+ const std::vector<int64>& bounds = GetParam().bounds;
+ Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input_array.FillRandom(GetParam().random_value_var,
+ GetParam().random_value_mean);
+
+ const int64 feature_index = GetParam().feature_index;
+ const int64 num_elements_per_feature =
+ Product(bounds) / bounds[feature_index];
+ const int64 feature_bound = bounds[feature_index];
+ std::vector<float> offset(feature_bound, 1);
+ std::vector<float> scale(feature_bound, 2);
+
+ auto input_squared =
+ ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
+ std::vector<int64> reduce_dims;
+ for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
+ if (i != feature_index) {
+ reduce_dims.push_back(i);
+ }
+ }
+
+ auto sum =
+ ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ auto sum_squared =
+ ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ std::vector<float> mean(feature_bound);
+
+ for (int64 i = 0; i < feature_bound; ++i) {
+ mean[i] = sum[i] / num_elements_per_feature;
+ }
+
+ std::vector<float> mean_square(feature_bound);
+ for (int64 i = 0; i < feature_bound; ++i) {
+ mean_square[i] = mean[i] * mean[i];
+ }
+
+ std::vector<float> square_mean(feature_bound);
+ for (int64 i = 0; i < feature_bound; ++i) {
+ square_mean[i] = sum_squared[i] / num_elements_per_feature;
+ }
+
+ std::vector<float> var(feature_bound);
+ for (int64 i = 0; i < feature_bound; ++i) {
+ var[i] = square_mean[i] - mean_square[i];
+ }
+
+ Array4D<float> mean4D =
+ *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
+ auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+ auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+ auto offset4D =
+ *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
+
+ auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
+ scale4D, offset4D, epsilon);
+
+ auto offset_literal = Literal::CreateR1<float>(offset);
+ auto scale_literal = Literal::CreateR1<float>(scale);
+ auto mean_literal = Literal::CreateR1<float>(mean);
+ auto var_literal = Literal::CreateR1<float>(var);
+ auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+
+ auto input_activations =
+ builder.Parameter(0, input_literal->shape(), "input");
+  auto scale_activations =
+      builder.Parameter(1, scale_literal->shape(), "scale");
+  auto offset_activations =
+      builder.Parameter(2, offset_literal->shape(), "offset");
+ auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
+ auto variance_activations =
+ builder.Parameter(4, var_literal->shape(), "variance");
+
+ Array4D<float> expected = normalized;
+
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> scale_data =
+ client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> offset_data =
+ client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> mean_data =
+ client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> variance_data =
+ client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+
+ builder.BatchNormInference(input_activations, scale_activations,
+ offset_activations, mean_activations,
+ variance_activations, epsilon, feature_index);
+
+ ComputeAndCompareR4<float>(
+ &builder, expected,
+ {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
+ variance_data.get()},
+ ErrorSpec(0.01, 1));
+}
+
XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
float epsilon = 0.001;
ComputationBuilder builder(client_, TestName());
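The reference statistics assembled by hand in the inferencing test above are the usual moments, mean = E[X] and var = E[X^2] - E[X]^2, reduced over every non-feature dimension. A compact NumPy equivalent (illustrative only):

    import numpy as np

    def reference_moments(x, feature_index):
        # Reduce over all dimensions except the feature dimension.
        reduce_dims = tuple(i for i in range(x.ndim) if i != feature_index)
        mean = x.mean(axis=reduce_dims)
        square_mean = (x * x).mean(axis=reduce_dims)
        return mean, square_mean - mean * mean  # var = E[X^2] - E[X]^2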
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 38e6675ab7..185ca7e681 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -491,6 +491,16 @@ message BatchNormTrainingRequest {
int64 feature_index = 5;
}
+message BatchNormInferenceRequest {
+ ComputationDataHandle operand = 1;
+ ComputationDataHandle scale = 2;
+ ComputationDataHandle offset = 3;
+ ComputationDataHandle mean = 4;
+ ComputationDataHandle variance = 5;
+ float epsilon = 6;
+ int64 feature_index = 7;
+}
+
message BatchNormGradRequest {
ComputationDataHandle operand = 1;
ComputationDataHandle scale = 2;
@@ -813,7 +823,8 @@ message OpRequest {
OutfeedRequest outfeed_request = 32;
BatchNormTrainingRequest batch_norm_training_request = 35;
BatchNormGradRequest batch_norm_grad_request = 37;
- // Next: 38
+ BatchNormInferenceRequest batch_norm_inference_request = 38;
+ // Next: 39
}
}