diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-09 16:24:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-09 16:27:39 -0700 |
commit | e346ac4faec2246c2d3972f158dea6aec858b904 (patch) | |
tree | 4a78a3888fbf281d1b0ecfa3c9946f49c2fe77c7 | |
parent | 130f44932fbfb3bef20911931de1eb263d55e992 (diff) |
[XLA] Redesign: implement infeed and outfeed.
- XlaBuilder::Infeed is basically ComputationBuilder::Infeed + UserComputation::AddInfeedInstruction + ComputationLowerer::Visit + HloInstruction::CreateInfeed.
- Similar for Outfeed.
PiperOrigin-RevId: 192206502
3 files changed, 42 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 3d0cb35b48..ed9f994d39 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -781,12 +781,41 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, } XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Infeed must have a layout"); + } + *instr.mutable_shape() = shape; + instr.set_infeed_config(config); + return AddInstruction(std::move(instr), HloOpcode::kInfeed); + }); } void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config) { - UnimplementedOp(); + NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + *instr.mutable_shape() = ShapeUtil::MakeNil(); + + // Check and set outfeed shape. + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "Outfeed shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + *instr.mutable_outfeed_shape() = shape_with_layout; + + instr.set_outfeed_config(outfeed_config); + + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}); + }); } XlaOp XlaBuilder::CustomCall(const string& call_target_name, diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 9124ccdb46..c2e3cd2350 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -74,9 +74,9 @@ string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } +template <typename BuilderT> StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice<GlobalData*> arguments) { + BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); @@ -651,4 +651,11 @@ template void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error); +template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice<GlobalData*> arguments); + +template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute( + XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments); + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 80e1bbbae8..0572acff88 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -92,9 +92,9 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience methods for building and running a computation with the member // execution options. Modify execution_options_ in your test if you want to // customize the options. + template <typename BuilderT> StatusOr<std::unique_ptr<GlobalData>> Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice<GlobalData*> arguments); + BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments); // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once // the migration to XlaBuilder is complete. |