aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-09 16:24:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-09 16:27:39 -0700
commite346ac4faec2246c2d3972f158dea6aec858b904 (patch)
tree4a78a3888fbf281d1b0ecfa3c9946f49c2fe77c7
parent130f44932fbfb3bef20911931de1eb263d55e992 (diff)
[XLA] Redesign: implement infeed and outfeed.
- XlaBuilder::Infeed is basically ComputationBuilder::Infeed + UserComputation::AddInfeedInstruction + ComputationLowerer::Visit + HloInstruction::CreateInfeed. - Similar for Outfeed. PiperOrigin-RevId: 192206502
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc33
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc11
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h4
3 files changed, 42 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 3d0cb35b48..ed9f994d39 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -781,12 +781,41 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
}
XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Given shape to Infeed must have a layout");
+ }
+ *instr.mutable_shape() = shape;
+ instr.set_infeed_config(config);
+ return AddInstruction(std::move(instr), HloOpcode::kInfeed);
+ });
}
void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config) {
- UnimplementedOp();
+ NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ *instr.mutable_shape() = ShapeUtil::MakeNil();
+
+ // Check and set outfeed shape.
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument("Given shape to Outfeed must have a layout");
+ }
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
+ return InvalidArgument(
+ "Outfeed shape %s must be compatible with operand shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ }
+ *instr.mutable_outfeed_shape() = shape_with_layout;
+
+ instr.set_outfeed_config(outfeed_config);
+
+ return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand});
+ });
}
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 9124ccdb46..c2e3cd2350 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -74,9 +74,9 @@ string ClientLibraryTestBase::TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
+template <typename BuilderT>
StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
return client_->Execute(computation, arguments, &execution_options_);
@@ -651,4 +651,11 @@ template void ClientLibraryTestBase::ComputeAndCompareTuple(
XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 80e1bbbae8..0572acff88 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -92,9 +92,9 @@ class ClientLibraryTestBase : public ::testing::Test {
// Convenience methods for building and running a computation with the member
// execution options. Modify execution_options_ in your test if you want to
// customize the options.
+ template <typename BuilderT>
StatusOr<std::unique_ptr<GlobalData>> Execute(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
// TODO(b/74197823): Remove the template type 'BuilderT' in all methods once
// the migration to XlaBuilder is complete.