about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
author	A. Unique TensorFlower <gardener@tensorflow.org>	2018-10-03 23:03:11 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-10-03 23:07:04 -0700
commit	e57874169fca3cfdd15cf0dda3717a6374a7dcb9 (patch)
tree	8c84491ba2200c19d3a6291c26dea6f196ff33c4 /tensorflow/compiler/xla
parent	6795491bcc0c276e27be6a9e1a4a14c019c2ba37 (diff)
[XLA] Update Tf2Xla bridge to use Scatter HLO.
PiperOrigin-RevId: 215687800
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--	tensorflow/compiler/xla/client/xla_builder.cc	|  3
-rw-r--r--	tensorflow/compiler/xla/service/hlo_module.cc	|  3
-rw-r--r--	tensorflow/compiler/xla/service/inliner.cc	| 32
-rw-r--r--	tensorflow/compiler/xla/service/inliner_test.cc	| 30
4 files changed, 54 insertions(+), 14 deletions(-)
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e0ec91dba1..d196252db1 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
case HloOpcode::kWhile:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
+ case HloOpcode::kScatter:
+ // TODO(b/32495713): We aren't checking the embedded computation in
+ // Scatter.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter:
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 7527e35c95..93e04eb3db 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -146,7 +146,8 @@ void HloModule::ReplaceComputations(
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow: {
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter: {
HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
replacements, instruction->to_apply(), nullptr);
if (new_arg != nullptr) {
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc
index 5fd779ebf9..50c408f5bb 100644
--- a/tensorflow/compiler/xla/service/inliner.cc
+++ b/tensorflow/compiler/xla/service/inliner.cc
@@ -71,26 +71,23 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
// profitability model for inlining is defined.
if (hlo_query::AllOperandsAreParameters(root)) {
if (root.opcode() == HloOpcode::kFusion ||
- root.opcode() == HloOpcode::kParameter ||
root.opcode() == HloOpcode::kTrace) {
// Cloning not supported for these instructions.
return Status::OK();
}
VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
<< root.ToShortString();
- // If the input is a constant then the shape of the constant could be
- // different than the map shape. Hence, a broadcast is needed, else the
- // cloned operand with new shape and operands work.
- if (root.opcode() != HloOpcode::kConstant) {
- std::vector<HloInstruction*> params;
- for (int64 o = 0; o < root.operands().size(); o++) {
- params.push_back(map->operands()[root.operand(o)->parameter_number()]);
- }
- HloInstruction* placed_instruction = computation_->AddInstruction(
- root.CloneWithNewOperands(map->shape(), params));
+ if (root.opcode() == HloOpcode::kParameter) {
+ // If the root is a parameter, then use the corresponding operand as the
+ // result of the computation.
TF_RETURN_IF_ERROR(
- computation_->ReplaceInstruction(map, placed_instruction));
- } else {
+ map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
+ TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
+ } else if (root.opcode() == HloOpcode::kConstant) {
+ // If the input is a constant then the shape of the constant could be
+ // different than the map shape. Hence, a broadcast is needed, else the
+ // cloned operand with new shape and operands work.
+ //
// The constant is in an embedded computation and needs to be recreated
// as part of the computation that the broadcast is inserted into.
HloInstruction* constant = computation_->AddInstruction(root.Clone());
@@ -98,6 +95,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
HloInstruction::CreateBroadcast(map->shape(), constant, {}));
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(map, placed_instruction));
+ } else {
+ std::vector<HloInstruction*> params;
+ for (int64 o = 0; o < root.operands().size(); o++) {
+ params.push_back(map->operands()[root.operand(o)->parameter_number()]);
+ }
+ HloInstruction* placed_instruction = computation_->AddInstruction(
+ root.CloneWithNewOperands(map->shape(), params));
+ TF_RETURN_IF_ERROR(
+ computation_->ReplaceInstruction(map, placed_instruction));
}
changed_ = true;
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 7e967f035c..98e0f2cfd7 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
+TEST_F(InlinerTest, MapParameter) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto param_builder = HloComputation::Builder(TestName());
+ param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
+ param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
+ auto param_f32 = param_builder.Build();
+
+ auto builder = HloComputation::Builder("MapParamFunction");
+ auto lhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+ auto rhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));
+
+ auto computation = builder.Build();
+ auto hlo_module = CreateNewVerifiedModule();
+ hlo_module->AddEmbeddedComputation(std::move(param_f32));
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ Inliner inliner;
+ EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
+
+ // Verify execution on CPU.
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
+ auto expected = LiteralUtil::CreateR0<float>(4);
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
+}
} // namespace
} // namespace xla