Diffstat (limited to 'tensorflow/compiler/xla/service/batchnorm_expander.cc')
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc  56
1 file changed, 39 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec13fadbc7..c4cd60c120 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -41,6 +43,8 @@ namespace xla {
namespace {
+using tensorflow::gtl::optional;
+
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
@@ -97,7 +101,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
add_instruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
add_instruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(-0.5f))))),
+ LiteralUtil::CreateR0<float>(-0.5f))))),
{}));
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
operand, exponent);
@@ -113,7 +117,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
add_instruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
add_instruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(1.0 / element_count))))),
+ LiteralUtil::CreateR0<float>(1.0 / element_count))))),
{}));
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply,
operand, elem_count_recip);
@@ -200,11 +204,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
HloInstruction* offset = batch_norm->mutable_operand(2);
const Shape feature_shape = scale->shape();
- auto zero_literal = Literal::CreateR0(0.0f);
+ auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
@@ -288,16 +292,22 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
+ const HloSharding& sharding = batch_norm->sharding();
HloSharding operand_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
inst->set_sharding(operand_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
return Status::OK();
@@ -320,7 +330,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
HloInstruction* var = batch_norm->mutable_operand(4);
const Shape feature_shape = scale->shape();
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
@@ -388,14 +398,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
- inst->set_sharding(batch_norm->sharding());
+ inst->set_sharding(sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- shifted_normalized->set_sharding(batch_norm->sharding());
+ shifted_normalized->set_sharding(sharding);
}
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
@@ -447,11 +463,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 feature_count = activation_shape.dimensions(feature_index);
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
- auto zero_literal = Literal::CreateR0(0.0f);
+ auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
@@ -542,7 +558,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add));
auto elements_per_feature_literal =
- Literal::CreateR0<float>(elements_per_feature_int64);
+ LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = add(
@@ -562,19 +578,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto tuple =
HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
HloSharding activation_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ auto unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
inst->set_sharding(activation_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
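
For context, the three sharding hunks above all introduce the same pattern: instead of unconditionally replicating the helper instructions added by the expansion, the pass now pins them to the batch-norm op's unique device when its sharding names one, and only falls back to replication otherwise. The snippet below is a minimal, standalone sketch of that selection logic only; `Sharding` and `DefaultShardingFor` are hypothetical stand-ins for `xla::HloSharding` and the call sites that use `batch_norm->sharding_unique_device()`, not the actual XLA code.

// sharding_default_sketch.cc -- illustration of the device-vs-replicate choice.
// Stand-in types only; the real pass uses xla::HloSharding::AssignDevice and
// xla::HloSharding::Replicate with tensorflow::gtl::optional<int64>.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-in for xla::HloSharding, reduced to what the example needs.
struct Sharding {
  std::string repr;
  static Sharding AssignDevice(int64_t device) {
    return {"device=" + std::to_string(device)};
  }
  static Sharding Replicate() { return {"replicated"}; }
};

// Mirrors the pattern in the diff: expanded instructions that are not
// operand-shaped get pinned to the op's unique device if one exists,
// otherwise they stay replicated.
Sharding DefaultShardingFor(std::optional<int64_t> unique_device) {
  return unique_device.has_value() ? Sharding::AssignDevice(*unique_device)
                                   : Sharding::Replicate();
}

int main() {
  std::cout << DefaultShardingFor(std::nullopt).repr << "\n";  // prints "replicated"
  std::cout << DefaultShardingFor(3).repr << "\n";             // prints "device=3"
}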