aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-26 17:39:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 17:42:55 -0700
commit931f6d553172ddfc9ec4a7a94ea2c6233bf33cb0 (patch)
tree1690c1ef469bb6843c179bcb4a0d4d010902c3cf
parenteda7aa3f7e763734f5f3550bed8b044a384b2ce8 (diff)
[XLA] Redesign: handle metadata and sharding.
- Add a xla.OpSharding field to the HloInstructionProto. - Metatdata handling is tested. PiperOrigin-RevId: 190553731
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc7
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h32
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto2
-rw-r--r--tensorflow/compiler/xla/tests/BUILD3
-rw-r--r--tensorflow/compiler/xla/tests/hlo_metadata_test.cc9
5 files changed, 45 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index bf91efcfd6..1b90b45bfb 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -896,8 +896,13 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
<< "Do not add XlaOp from builder " << operand.builder_->name()
<< " to builder " << this->name();
instr.add_operand_ids(operand.handle());
- // TODO(b/74197823): Set metadata and sharding.
}
+
+ *instr.mutable_metadata() = metadata_;
+ if (sharding_) {
+ *instr.mutable_sharding() = *sharding_;
+ }
+
instructions_.push_back(instr);
XlaOp op(handle, this);
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 22cf094512..cc33356cc1 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -85,6 +85,29 @@ class XlaBuilder {
// Returns the computation name.
const string& name() const { return name_; }
+ // Sets OpMetadata that will be added to all instructions until cleared.
+ //
+ // OpMetadata is often applied to a series of XLA HLO instructions. As a
+ // result, OpMetadata is set on the Computation Builder. All subsequent
+ // instructions generated via this Computation Builder will have the same
+ // OpMetadata attached until a call to ClearOpMetadata.
+ void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
+
+ // Clears the HloMetadata state.
+ void ClearOpMetadata() { metadata_.Clear(); }
+
+ // Sets an OpSharding that will be attached to all instructions until cleared.
+ void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
+
+ // Clears the sharding. Ops will be sharded according to the default placement
+ // policy.
+ void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
+
+ // Returns the OpSharding that will be attached to all instructions.
+ const tensorflow::gtl::optional<OpSharding>& sharding() const {
+ return sharding_;
+ }
+
// Sets the builder to a mode where it will die immediately when an error is
// encountered, rather than producing it in a deferred fashion when Build() is
// called (which is the default).
@@ -776,6 +799,15 @@ class XlaBuilder {
// The unique parameter numbers.
tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+ // The metadata to attach to each op. This is structured as a "modal"-like
+ // operation, in order to simplify client code (and not sprinkle this metadata
+ // throughout the TensorFlow op kernel implementations).
+ OpMetadata metadata_;
+
+ // Sharding for this operator. This is structured as a "model"-like operation,
+ // in order to simplify client code, similar to metadata_.
+ tensorflow::gtl::optional<OpSharding> sharding_;
+
// Mode bit that indicates whether to die when a first error is encountered.
bool die_immediately_on_error_ = false;
};
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 406feadfd4..0b446c6547 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -141,6 +141,8 @@ message HloInstructionProto {
repeated int64 operand_ids = 36;
repeated int64 control_predecessor_ids = 37;
repeated int64 called_computation_ids = 38;
+
+ xla.OpSharding sharding = 40;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 3705d6c271..5ab25f2264 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1810,9 +1810,8 @@ tf_cc_test(
deps = [
":local_client_test_base",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/service:computation_tracker",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
index eded2077fc..cf971dd61b 100644
--- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
+++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
@@ -30,7 +29,7 @@ class HloMetadataTest : public LocalClientTestBase {
metadata_.set_op_name("my_sum_op");
}
- void BuildAddComputation(ComputationBuilder* builder) {
+ void BuildAddComputation(XlaBuilder* builder) {
auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder->Add(x, y);
@@ -40,7 +39,7 @@ class HloMetadataTest : public LocalClientTestBase {
};
TEST_F(HloMetadataTest, MetadataPropagation) {
- ComputationBuilder builder(local_client_, "add");
+ XlaBuilder builder("add");
builder.SetOpMetadata(metadata_);
BuildAddComputation(&builder);
builder.ClearOpMetadata();
@@ -61,7 +60,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) {
}
TEST_F(HloMetadataTest, MetadataClearing) {
- ComputationBuilder builder(local_client_, "add");
+ XlaBuilder builder("add");
builder.SetOpMetadata(metadata_);
// Some other pretend computation here.
builder.ClearOpMetadata();