aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-11 18:38:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 18:42:28 -0700
commit4e29ebd67cd4409cbdfa6510b06acd780166aa9d (patch)
treec0047ef1ef2276910d60fee24fa261da54041e33
parentff6c11008213424b7a1dd77346f996be693b004a (diff)
[XLA] Redesign: test sharding.
Also set the sharding to the instruction when created from proto. PiperOrigin-RevId: 192543024
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h31
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc6
2 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 24e0be2ac1..e583b4fe48 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -959,6 +959,37 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
return ConstantFromArray(values);
}
+// RAII-style object: sets the current sharding assignment in builder on
+// construction, and sets back to the previous assignment on destruction.
+//
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+class XlaScopedShardingAssignment {
+ public:
+ XlaScopedShardingAssignment(xla::XlaBuilder* builder,
+ tensorflow::gtl::optional<OpSharding> sharding)
+ : builder_(builder), prev_sharding_(builder->sharding()) {
+ SetSharding(sharding);
+ }
+
+ XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
+ XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
+ delete;
+
+ ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
+
+ private:
+ void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
+ if (sharding.has_value()) {
+ builder_->SetSharding(sharding.value());
+ } else {
+ builder_->ClearSharding();
+ }
+ }
+
+ xla::XlaBuilder* const builder_;
+ tensorflow::gtl::optional<OpSharding> prev_sharding_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index a986bbd511..5d2d7a9727 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -159,6 +159,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->fft_length_.push_back(fft_len);
}
+ if (proto.has_sharding()) {
+ TF_ASSIGN_OR_RETURN(const auto& sharding,
+ HloSharding::FromProto(proto.sharding()));
+ instruction->set_sharding(sharding);
+ }
+
if (proto.has_gather_dimension_numbers()) {
instruction->gather_dimension_numbers_ =
MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());