-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc       | 84
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h        | 22
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder_test.cc  | 52
3 files changed, 103 insertions, 55 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 1cb61f77fb..073d66bcd2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -45,21 +45,6 @@ int64 GetUniqueId() {
return id;
}
-// Returns true if an instruction with the given opcode can be the root of the
-// computation.
-bool CanBeRoot(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kAfterAll:
- case HloOpcode::kSend:
- case HloOpcode::kSendDone:
- case HloOpcode::kOutfeed:
- case HloOpcode::kTrace:
- return false;
- default:
- return true;
- }
-}
-
} // namespace
XlaOp operator-(const XlaOp& x) { return Neg(x); }
@@ -142,28 +127,13 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
return ReportErrorOrReturn(op_creator());
}
-StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_);
-
- TF_RET_CHECK(root_id != nullptr);
+ TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
ProgramShape program_shape;
- // Not all instructions can be roots. Walk backwards from the last added
- // instruction until a valid root is found.
- int64 index = instructions_.size() - 1;
- for (; index >= 0; index--) {
- TF_ASSIGN_OR_RETURN(HloOpcode opcode,
- StringToHloOpcode(instructions_[index].opcode()));
- if (CanBeRoot(opcode)) {
- break;
- }
- }
- if (index < 0) {
- return FailedPrecondition("no root instruction was found");
- }
- *root_id = instructions_[index].id();
- *program_shape.mutable_result() = instructions_[index].shape();
+ *program_shape.mutable_result() = instructions_[root_id].shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@@ -188,8 +158,15 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
}
StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
- int64 root;
- return GetProgramShape(&root);
+ TF_RET_CHECK(!instructions_.empty());
+ return GetProgramShape(instructions_.back().id());
+}
+
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
+ if (root.builder_ != this) {
+ return InvalidArgument("Given root operation is not in this computation.");
+ }
+ return GetProgramShape(root.handle());
}
void XlaBuilder::IsConstantVisitor(const int64 op_handle,
@@ -257,17 +234,29 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
return AppendStatus(first_error_, backtrace);
}
+ return Build(instructions_.back().id());
+}
+
+StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root) {
+ if (root.builder_ != this) {
+ return InvalidArgument("Given root operation is not in this computation.");
+ }
+ return Build(root.handle());
+}
+
+StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
+ if (!first_error_.ok()) {
+ string backtrace;
+ first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
+ return AppendStatus(first_error_, backtrace);
+ }
HloComputationProto entry;
entry.set_id(GetUniqueId()); // Give the computation a global unique id.
entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
- {
- int64 root_id;
- TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
- GetProgramShape(&root_id));
- entry.set_root_id(root_id);
- }
+ TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
+ entry.set_root_id(root_id);
for (auto& instruction : instructions_) {
// Ensures that the instruction names are unique among the whole graph.
@@ -1099,11 +1088,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
sharding_builder::AssignDevice(0);
XlaScopedShardingAssignment scoped_sharding(this,
infeed_instruction_sharding);
- TF_ASSIGN_OR_RETURN(infeed,
- AddInstruction(std::move(instr), HloOpcode::kInfeed));
+ TF_ASSIGN_OR_RETURN(
+ infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
} else {
- TF_ASSIGN_OR_RETURN(infeed,
- AddInstruction(std::move(instr), HloOpcode::kInfeed));
+ TF_ASSIGN_OR_RETURN(
+ infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
}
// The infeed instruction produces a tuple of the infeed data and a token
@@ -2163,11 +2152,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
LookUpInstruction(root_op));
- TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
- if (!CanBeRoot(opcode)) {
- return InvalidArgument("the operand with opcode %s cannot be root",
- root->opcode().c_str());
- }
HloComputationProto entry;
entry.set_id(GetUniqueId()); // Give the computation a global unique id.
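
For illustration, a minimal usage sketch of the new Build(XlaOp) overload. This is hypothetical caller code, not part of the diff; it assumes the usual XLA client headers and reuses the ConstantR0/Add helpers exercised in the tests below:

    XlaBuilder b("example");
    XlaOp x = ConstantR0<float>(&b, 1.0);
    Add(x, ConstantR0<float>(&b, 2.0));  // last enqueued op
    // Build() would root the computation at the Add; Build(x) roots it at the
    // constant, while all enqueued ops are still moved into the computation.
    StatusOr<XlaComputation> computation = b.Build(/*root=*/x);
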
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 8726cc6f93..3c5f8c8d53 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -195,9 +195,14 @@ class XlaBuilder {
// Builds the computation with the requested operations, or returns a non-ok
// status. Note that all ops that have been enqueued will be moved to the
- // computation being returned.
+ // computation being returned. The root of the computation will be the last
+ // added operation.
StatusOr<XlaComputation> Build();
+ // Overload of Build which specifies a particular root instruction for the
+ // computation.
+ StatusOr<XlaComputation> Build(XlaOp root);
+
// Builds the computation with the requested operations, or notes an error in
// the parent XlaBuilder and returns an empty computation if building failed.
// This function is intended to be used where the returned XlaComputation is
@@ -225,9 +230,14 @@ class XlaBuilder {
// Returns the shape of the given op.
StatusOr<Shape> GetShape(const XlaOp& op) const;
- // Returns the (inferred) result for the current computation's shape.
+ // Returns the (inferred) result for the current computation's shape. This
+ // assumes the root instruction is the last added instruction.
StatusOr<ProgramShape> GetProgramShape() const;
+ // Returns the (inferred) result for the current computation's shape using the
+ // given operation as the root.
+ StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;
+
// Reports an error to the builder, by
// * storing it internally and capturing a backtrace if it's the first error
// (this deferred value will be produced on the call to
@@ -255,6 +265,9 @@ class XlaBuilder {
StatusOr<bool> IsConstant(const XlaOp& operand) const;
private:
+ // Build helper which takes the id of the root operation.
+ StatusOr<XlaComputation> Build(int64 root_id);
+
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(int64 parameter_number, const Shape& shape,
@@ -969,9 +982,8 @@ class XlaBuilder {
// shape.
StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
- // Returns the (inferred) result for the program shape for the current
- // computation and fills the root_id in the pointer.
- StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
+ // Returns the (inferred) program shape using the given root id.
+ StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
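
Likewise, a short sketch of querying the program shape at a chosen root via the new public GetProgramShape(XlaOp) overload. Again hypothetical caller code under the same assumptions; Parameter, ShapeUtil::MakeShape, and Neg are the helpers already visible in this diff:

    XlaBuilder b("shape_example");
    XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "p");
    Neg(p);  // enqueue another op so `p` is no longer the last instruction
    // With `p` as the root, the inferred program shape is (f32[2]) -> f32[2].
    StatusOr<ProgramShape> shape = b.GetProgramShape(/*root=*/p);
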
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 28a207b137..afe5be29f0 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -46,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test {
return HloModule::CreateFromProto(proto, config);
}
+ // Overload which explicitly specifies the root instruction.
+ StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
+ XlaOp root) {
+ TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root));
+ const HloModuleProto& proto = computation.proto();
+ TF_ASSIGN_OR_RETURN(const auto& config,
+ HloModule::CreateModuleConfigFromProto(
+ proto, legacy_flags::GetDebugOptionsFromFlags()));
+ return HloModule::CreateFromProto(proto, config);
+ }
+
// Returns the name of the test currently being run.
string TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
@@ -320,5 +332,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
}
+TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
+ XlaBuilder b(TestName());
+ XlaOp constant = ConstantR0<float>(&b, 1.0);
+ Add(constant, ConstantR0<float>(&b, 2.0));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Constant());
+}
+
+TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
+ // Specifying a particular root in Build should still include all entry
+ // parameters.
+ XlaBuilder b(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
+ XlaOp x = Parameter(&b, 0, shape, "x");
+ XlaOp y = Parameter(&b, 1, shape, "y");
+ XlaOp z = Parameter(&b, 2, shape, "z");
+ Add(x, Sub(y, z));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Parameter());
+ EXPECT_EQ(module->entry_computation()->num_parameters(), 3);
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 5);
+}
+
+TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
+ XlaBuilder b(TestName());
+ XlaBuilder other_b(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
+
+ Parameter(&b, 0, shape, "param");
+ XlaOp other_param = Parameter(&other_b, 0, shape, "other_param");
+
+ Status status = b.Build(other_param).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("root operation is not in this computation"));
+}
+
} // namespace
} // namespace xla