Diffstat (limited to 'tensorflow/compiler/xla/service/batchnorm_expander.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/batchnorm_expander.cc | 56
1 file changed, 39 insertions(+), 17 deletions(-)
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec13fadbc7..c4cd60c120 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/optional.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -41,6 +43,8 @@ namespace xla {
 
 namespace {
 
+using tensorflow::gtl::optional;
+
 // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
 // operations into smaller operations.
 class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
@@ -97,7 +101,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
         add_instruction(HloInstruction::CreateConvert(
             ShapeUtil::MakeShape(operand->shape().element_type(), {}),
             add_instruction(HloInstruction::CreateConstant(
-                Literal::CreateR0<float>(-0.5f))))),
+                LiteralUtil::CreateR0<float>(-0.5f))))),
         {}));
     return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
                                         operand, exponent);
@@ -113,7 +117,7 @@
         add_instruction(HloInstruction::CreateConvert(
             ShapeUtil::MakeShape(operand->shape().element_type(), {}),
             add_instruction(HloInstruction::CreateConstant(
-                Literal::CreateR0<float>(1.0 / element_count))))),
+                LiteralUtil::CreateR0<float>(1.0 / element_count))))),
         {}));
     return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply,
                                         operand, elem_count_recip);
@@ -200,11 +204,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
   HloInstruction* offset = batch_norm->mutable_operand(2);
   const Shape feature_shape = scale->shape();
 
-  auto zero_literal = Literal::CreateR0(0.0f);
+  auto zero_literal = LiteralUtil::CreateR0(0.0f);
   TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
   auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
 
-  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
   TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
   auto epsilon = add(HloInstruction::CreateBroadcast(
       operand_shape,
@@ -288,16 +292,22 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
     int64 instruction_count_after = computation_->instruction_count();
     CHECK_EQ(instruction_count_after,
              instruction_count_before + added_instructions.size());
+    const HloSharding& sharding = batch_norm->sharding();
     HloSharding operand_sharding =
-        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+    optional<int64> unique_device = batch_norm->sharding_unique_device();
+    HloSharding default_sharding =
+        unique_device.has_value()
+            ? HloSharding::AssignDevice(unique_device.value())
+            : HloSharding::Replicate();
     for (HloInstruction* inst : added_instructions) {
       if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
         inst->set_sharding(operand_sharding);
       } else {
-        inst->set_sharding(HloSharding::Replicate());
+        inst->set_sharding(default_sharding);
       }
     }
-    tuple->set_sharding(batch_norm->sharding());
+    tuple->set_sharding(sharding);
   }
   TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
   return Status::OK();
@@ -320,7 +330,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
   HloInstruction* var = batch_norm->mutable_operand(4);
   const Shape feature_shape = scale->shape();
 
-  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
   TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
   auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
       operand_shape,
@@ -388,14 +398,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
   CHECK_EQ(instruction_count_after,
            instruction_count_before + added_instructions.size());
   if (batch_norm->has_sharding()) {
+    const HloSharding& sharding = batch_norm->sharding();
+    optional<int64> unique_device = batch_norm->sharding_unique_device();
+    HloSharding default_sharding =
+        unique_device.has_value()
+            ? HloSharding::AssignDevice(unique_device.value())
+            : HloSharding::Replicate();
     for (HloInstruction* inst : added_instructions) {
       if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
-        inst->set_sharding(batch_norm->sharding());
+        inst->set_sharding(sharding);
       } else {
-        inst->set_sharding(HloSharding::Replicate());
+        inst->set_sharding(default_sharding);
       }
     }
-    shifted_normalized->set_sharding(batch_norm->sharding());
+    shifted_normalized->set_sharding(sharding);
   }
   TF_CHECK_OK(
       ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
@@ -447,11 +463,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
   const int64 feature_count = activation_shape.dimensions(feature_index);
   const int64 elements_per_feature_int64 = size_in_elements / feature_count;
 
-  auto zero_literal = Literal::CreateR0(0.0f);
+  auto zero_literal = LiteralUtil::CreateR0(0.0f);
   TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
   auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
 
-  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
   TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
   auto epsilon_scalar =
       add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
@@ -542,7 +558,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
       Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add));
 
   auto elements_per_feature_literal =
-      Literal::CreateR0<float>(elements_per_feature_int64);
+      LiteralUtil::CreateR0<float>(elements_per_feature_int64);
   TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
                       elements_per_feature_literal->Convert(ptype));
   auto elements_per_feature = add(
@@ -562,19 +578,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
   auto tuple =
       HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
   if (batch_norm->has_sharding()) {
+    const HloSharding& sharding = batch_norm->sharding();
     int64 instruction_count_after = computation_->instruction_count();
     CHECK_EQ(instruction_count_after,
              instruction_count_before + added_instructions.size());
     HloSharding activation_sharding =
-        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+    auto unique_device = batch_norm->sharding_unique_device();
+    HloSharding default_sharding =
+        unique_device.has_value()
+            ? HloSharding::AssignDevice(unique_device.value())
+            : HloSharding::Replicate();
     for (HloInstruction* inst : added_instructions) {
       if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
         inst->set_sharding(activation_sharding);
       } else {
-        inst->set_sharding(HloSharding::Replicate());
+        inst->set_sharding(default_sharding);
      }
     }
-    tuple->set_sharding(batch_norm->sharding());
+    tuple->set_sharding(sharding);
   }
   TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
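
One half of the patch is mechanical: each scalar-literal factory call moves from a member of Literal to a static on LiteralUtil, with the new tensorflow/compiler/xla/literal.h header carrying the Literal class itself. A minimal sketch under that post-patch API (MakeZeroLiteral is a hypothetical wrapper, not part of the change):

// Sketch only, assuming the post-patch split of literal.h / literal_util.h.
#include <memory>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

std::unique_ptr<xla::Literal> MakeZeroLiteral() {
  // Was: Literal::CreateR0<float>(0.0f); the factory is now on LiteralUtil,
  // and callers still Convert() the result to the desired primitive type.
  return xla::LiteralUtil::CreateR0<float>(0.0f);
}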
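
The other half is the sharding change, which repeats one ternary in HandleBatchNormTraining, HandleBatchNormInference, and HandleBatchNormGrad. A minimal sketch of that pattern, factored into a hypothetical helper (the patch itself inlines the logic at each of the three call sites; the helper name is not part of the change):

// Sketch only: mirrors the default-sharding choice the patch applies to
// expanded instructions whose shape differs from the operand/activation shape.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Hypothetical helper; the patch writes this inline three times.
HloSharding DefaultShardingFor(const HloInstruction* batch_norm) {
  // If the original batch-norm op is assigned to exactly one device, put the
  // expansion's helper instructions on that device too; otherwise replicate.
  tensorflow::gtl::optional<int64> unique_device =
      batch_norm->sharding_unique_device();
  return unique_device.has_value()
             ? HloSharding::AssignDevice(unique_device.value())
             : HloSharding::Replicate();
}

}  // namespace xla

Before this change, every expanded instruction whose shape did not match the operand (or activation) shape received HloSharding::Replicate() unconditionally, even when the original op carried a single-device sharding.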