diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-26 17:39:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-26 17:42:55 -0700 |
commit | 931f6d553172ddfc9ec4a7a94ea2c6233bf33cb0 (patch) | |
tree | 1690c1ef469bb6843c179bcb4a0d4d010902c3cf | |
parent | eda7aa3f7e763734f5f3550bed8b044a384b2ce8 (diff) |
[XLA] Redesign: handle metadata and sharding.
- Add a xla.OpSharding field to the HloInstructionProto.
- Metatdata handling is tested.
PiperOrigin-RevId: 190553731
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.h | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_metadata_test.cc | 9 |
5 files changed, 45 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index bf91efcfd6..1b90b45bfb 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -896,8 +896,13 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction( << "Do not add XlaOp from builder " << operand.builder_->name() << " to builder " << this->name(); instr.add_operand_ids(operand.handle()); - // TODO(b/74197823): Set metadata and sharding. } + + *instr.mutable_metadata() = metadata_; + if (sharding_) { + *instr.mutable_sharding() = *sharding_; + } + instructions_.push_back(instr); XlaOp op(handle, this); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 22cf094512..cc33356cc1 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -85,6 +85,29 @@ class XlaBuilder { // Returns the computation name. const string& name() const { return name_; } + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the Computation Builder. All subsequent + // instructions generated via this Computation Builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional<OpSharding>& sharding() const { + return sharding_; + } + // Sets the builder to a mode where it will die immediately when an error is // encountered, rather than producing it in a deferred fashion when Build() is // called (which is the default). @@ -776,6 +799,15 @@ class XlaBuilder { // The unique parameter numbers. tensorflow::gtl::FlatSet<int64> parameter_numbers_; + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + tensorflow::gtl::optional<OpSharding> sharding_; + // Mode bit that indicates whether to die when a first error is encountered. bool die_immediately_on_error_ = false; }; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 406feadfd4..0b446c6547 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -141,6 +141,8 @@ message HloInstructionProto { repeated int64 operand_ids = 36; repeated int64 control_predecessor_ids = 37; repeated int64 called_computation_ids = 38; + + xla.OpSharding sharding = 40; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 3705d6c271..5ab25f2264 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1810,9 +1810,8 @@ tf_cc_test( deps = [ ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index eded2077fc..cf971dd61b 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" @@ -30,7 +29,7 @@ class HloMetadataTest : public LocalClientTestBase { metadata_.set_op_name("my_sum_op"); } - void BuildAddComputation(ComputationBuilder* builder) { + void BuildAddComputation(XlaBuilder* builder) { auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder->Add(x, y); @@ -40,7 +39,7 @@ class HloMetadataTest : public LocalClientTestBase { }; TEST_F(HloMetadataTest, MetadataPropagation) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); BuildAddComputation(&builder); builder.ClearOpMetadata(); @@ -61,7 +60,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) { } TEST_F(HloMetadataTest, MetadataClearing) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); // Some other pretend computation here. builder.ClearOpMetadata(); |