diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-02 18:03:11 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-02 18:07:05 -0800 |
commit | 09d2efecf44bf313d2e03abdc1c8884cf48e23ae (patch) | |
tree | b834c1d8fc1a0e6a16904be00071fc3b682f00d7 | |
parent | ca22df219cee6ae1daf137fa174f9d0f3874d364 (diff) |
Propagate outfeed sharding, if specified from TensorFlow.
PiperOrigin-RevId: 184361221
4 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index a57b7e5717..98dfc89867 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1453,9 +1453,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddInfeedInstruction(arg->infeed_request()); break; case OpRequest::kOutfeedRequest: - TF_RETURN_IF_ERROR( - computation->AddOutfeedInstruction(arg->outfeed_request())); - return tensorflow::Status::OK(); + handle_status = + computation->AddOutfeedInstruction(arg->outfeed_request()); + break; case OpRequest::kMapRequest: { TF_ASSIGN_OR_RETURN( UserComputation * to_apply, diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 2ea6507900..ef9c80b043 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1185,7 +1185,7 @@ StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction( return handle; } -Status UserComputation::AddOutfeedInstruction( +StatusOr<ComputationDataHandle> UserComputation::AddOutfeedInstruction( const OutfeedRequest& outfeed_request) { tensorflow::mutex_lock lock(mutex_); @@ -1197,8 +1197,6 @@ Status UserComputation::AddOutfeedInstruction( // Verify that operand is valid. TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. ComputationDataHandle handle = CreateComputationDataHandle(); OperationRequest& request = (*session_computation_.mutable_requests())[handle.handle()]; @@ -1209,7 +1207,7 @@ Status UserComputation::AddOutfeedInstruction( VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() << "), data handle " << handle.handle() << ": " << outfeed_request.ShortDebugString(); - return Status::OK(); + return handle; } StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction( diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 4f92e58877..54bb24d6d7 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -146,7 +146,8 @@ class UserComputation { const InfeedRequest& infeed_request); // Enqueues an outfeed instruction onto this user computation. - Status AddOutfeedInstruction(const OutfeedRequest& outfeed_request); + StatusOr<ComputationDataHandle> AddOutfeedInstruction( + const OutfeedRequest& outfeed_request); // Enqueues a call instruction onto this user computation. StatusOr<ComputationDataHandle> AddCallInstruction( diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ca02115863..2fa163953f 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -67,7 +67,8 @@ TEST_F(UserComputationTest, SimpleComputation) { *outfeed_request.mutable_operand() = constant_handle; *outfeed_request.mutable_shape() = kVectorShape; outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, + computation.AddOutfeedInstruction(outfeed_request)); auto hlo_resolver = [](const VersionedComputationHandle& handle) { return nullptr; |