diff options
author | 2018-04-11 18:38:38 -0700 | |
---|---|---|
committer | 2018-04-11 18:42:28 -0700 | |
commit | 4e29ebd67cd4409cbdfa6510b06acd780166aa9d (patch) | |
tree | c0047ef1ef2276910d60fee24fa261da54041e33 | |
parent | ff6c11008213424b7a1dd77346f996be693b004a (diff) |
[XLA] Redesign: test sharding.
Also set the sharding to the instruction when created from proto.
PiperOrigin-RevId: 192543024
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.h | 31 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 6 |
2 files changed, 37 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 24e0be2ac1..e583b4fe48 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -959,6 +959,37 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) { return ConstantFromArray(values); } +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +// +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + tensorflow::gtl::optional<OpSharding> sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + tensorflow::gtl::optional<OpSharding> prev_sharding_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a986bbd511..5d2d7a9727 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -159,6 +159,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->fft_length_.push_back(fft_len); } + if (proto.has_sharding()) { + TF_ASSIGN_OR_RETURN(const auto& sharding, + 
HloSharding::FromProto(proto.sharding())); + instruction->set_sharding(sharding); + } + if (proto.has_gather_dimension_numbers()) { instruction->gather_dimension_numbers_ = MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());