diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-03 23:03:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 23:07:04 -0700 |
commit | e57874169fca3cfdd15cf0dda3717a6374a7dcb9 (patch) | |
tree | 8c84491ba2200c19d3a6291c26dea6f196ff33c4 /tensorflow/compiler/xla | |
parent | 6795491bcc0c276e27be6a9e1a4a14c019c2ba37 (diff) |
[XLA] Update Tf2Xla bridge to use Scatter HLO.
PiperOrigin-RevId: 215687800
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/inliner.cc | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/inliner_test.cc | 30 |
4 files changed, 54 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index e0ec91dba1..d196252db1 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. + case HloOpcode::kScatter: + // TODO(b/32495713): We aren't checking the embedded computation in + // Scatter. case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kParameter: diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 7527e35c95..93e04eb3db 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -146,7 +146,8 @@ void HloModule::ReplaceComputations( case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: { + case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: { HloComputation* new_arg = tensorflow::gtl::FindWithDefault( replacements, instruction->to_apply(), nullptr); if (new_arg != nullptr) { diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc index 5fd779ebf9..50c408f5bb 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/inliner.cc @@ -71,26 +71,23 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { // profitability model for inlining is defined. if (hlo_query::AllOperandsAreParameters(root)) { if (root.opcode() == HloOpcode::kFusion || - root.opcode() == HloOpcode::kParameter || root.opcode() == HloOpcode::kTrace) { // Cloning not supported for these instructions. return Status::OK(); } VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " << root.ToShortString(); - // If the input is a constant then the shape of the constant could be - // different than the map shape. Hence, a broadcast is needed, else the - // cloned operand with new shape and operands work. - if (root.opcode() != HloOpcode::kConstant) { - std::vector<HloInstruction*> params; - for (int64 o = 0; o < root.operands().size(); o++) { - params.push_back(map->operands()[root.operand(o)->parameter_number()]); - } - HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), params)); + if (root.opcode() == HloOpcode::kParameter) { + // If the root is a parameter, then use the corresponding operand as the + // result of the computation. TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(map, placed_instruction)); - } else { + map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); + } else if (root.opcode() == HloOpcode::kConstant) { + // If the input is a constant then the shape of the constant could be + // different than the map shape. Hence, a broadcast is needed, else the + // cloned operand with new shape and operands work. + // // The constant is in an embedded computation and needs to be recreated // as part of the computation that the broadcast is inserted into. HloInstruction* constant = computation_->AddInstruction(root.Clone()); @@ -98,6 +95,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { HloInstruction::CreateBroadcast(map->shape(), constant, {})); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); + } else { + std::vector<HloInstruction*> params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(map->operands()[root.operand(o)->parameter_number()]); + } + HloInstruction* placed_instruction = computation_->AddInstruction( + root.CloneWithNewOperands(map->shape(), params)); + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(map, placed_instruction)); } changed_ = true; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 7e967f035c..98e0f2cfd7 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } +TEST_F(InlinerTest, MapParameter) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto param_builder = HloComputation::Builder(TestName()); + param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0")); + param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1")); + auto param_f32 = param_builder.Build(); + + auto builder = HloComputation::Builder("MapParamFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEmbeddedComputation(std::move(param_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + Inliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR0<float>(4); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} } // namespace } // namespace xla |