author    HyoukJoong Lee <hyouklee@google.com>  2018-07-09 07:41:48 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-09 07:45:38 -0700
commit    501c2851492bf30ba72f65516eb892e46800122c (patch)
tree      2e1c3b236f8568a4635c99661e9ae4758f2295ef
parent    4f2b576bf388db470fea166c6cc68fece455cdd9 (diff)
Fix sharding for instructions created in BatchNormExpander
PiperOrigin-RevId: 203763201
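
In short: when a batch-norm op carried a sharding, the helper instructions the expander created (those whose shape does not match the operand shape) were unconditionally marked as replicated, dropping any single-device assignment. Each handler now derives a fallback sharding from the op's unique device when one exists; the pattern, as it appears in each handler below, is:

    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();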
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                       |  1
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc       | 39
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander_test.cc  | 29
3 files changed, 60 insertions(+), 9 deletions(-)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 73680e8d59..6e3431df52 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1239,6 +1239,7 @@ tf_cc_test(
":batchnorm_expander",
":hlo",
":hlo_matchers",
+ ":hlo_parser",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index aed5832eee..c4cd60c120 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -42,6 +43,8 @@ namespace xla {
namespace {
+using tensorflow::gtl::optional;
+
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
@@ -289,16 +292,22 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
+ const HloSharding& sharding = batch_norm->sharding();
HloSharding operand_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
inst->set_sharding(operand_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
return Status::OK();
@@ -389,14 +398,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
- inst->set_sharding(batch_norm->sharding());
+ inst->set_sharding(sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- shifted_normalized->set_sharding(batch_norm->sharding());
+ shifted_normalized->set_sharding(sharding);
}
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
@@ -563,19 +578,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto tuple =
HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
HloSharding activation_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ auto unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
inst->set_sharding(activation_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
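
The fallback-sharding computation above now appears verbatim in HandleBatchNormTraining, HandleBatchNormInference, and HandleBatchNormGrad. A hypothetical follow-up cleanup, not part of this commit, could hoist it into a small helper:

    // Hypothetical helper (sketch only, not in this commit): the sharding to
    // apply to expander-created instructions whose shape differs from the
    // operand/activation shape.
    HloSharding DefaultShardingFor(const HloInstruction* batch_norm) {
      optional<int64> unique_device = batch_norm->sharding_unique_device();
      return unique_device.has_value()
                 ? HloSharding::AssignDevice(unique_device.value())
                 : HloSharding::Replicate();
    }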
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index 9096792237..32f785a70a 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -114,5 +115,33 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}
+TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) {
+ const char* module_str = R"(
+HloModule module
+ENTRY entry {
+ %param.0 = f32[8,4] parameter(0)
+ %param.1 = f32[4] parameter(1)
+ %param.2 = f32[4] parameter(2)
+ ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
+ batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
+ epsilon=0.001, feature_index=1, sharding={maximal device=1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ BatchNormExpander rewriter(/*rewrite_training_op=*/true,
+ /*rewrite_inference_op=*/true,
+ /*rewrite_grad_op=*/true);
+ ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+
+ for (auto* instruction : module->entry_computation()->instructions()) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ continue;
+ }
+ ASSERT_TRUE(instruction->has_sharding());
+ TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
+ EXPECT_EQ(device, 1);
+ }
+}
+
} // namespace
} // namespace xla
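
The new test pins the invariant down end to end: after expansion of a batch-norm op annotated with {maximal device=1}, every non-parameter instruction must carry a sharding whose unique device is 1. Reduced to a minimal sketch using only the APIs visible in this patch, the per-instruction assertion amounts to:

    HloSharding sharding = HloSharding::AssignDevice(1);  // {maximal device=1}
    TF_ASSERT_OK_AND_ASSIGN(int device, sharding.UniqueDevice());
    EXPECT_EQ(device, 1);  // a maximal sharding reports its single device

Parameters are skipped in the loop because the expander only assigns sharding to the instructions it creates.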