aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-02 18:03:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-02 18:07:05 -0800
commit09d2efecf44bf313d2e03abdc1c8884cf48e23ae (patch)
treeb834c1d8fc1a0e6a16904be00071fc3b682f00d7
parentca22df219cee6ae1daf137fa174f9d0f3874d364 (diff)
Propagate outfeed sharding, if specified from TensorFlow.
PiperOrigin-RevId: 184361221
-rw-r--r--tensorflow/compiler/xla/service/service.cc6
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc6
-rw-r--r--tensorflow/compiler/xla/service/user_computation.h3
-rw-r--r--tensorflow/compiler/xla/service/user_computation_test.cc3
4 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index a57b7e5717..98dfc89867 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1453,9 +1453,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddInfeedInstruction(arg->infeed_request());
break;
case OpRequest::kOutfeedRequest:
- TF_RETURN_IF_ERROR(
- computation->AddOutfeedInstruction(arg->outfeed_request()));
- return tensorflow::Status::OK();
+ handle_status =
+ computation->AddOutfeedInstruction(arg->outfeed_request());
+ break;
case OpRequest::kMapRequest: {
TF_ASSIGN_OR_RETURN(
UserComputation * to_apply,
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 2ea6507900..ef9c80b043 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -1185,7 +1185,7 @@ StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction(
return handle;
}
-Status UserComputation::AddOutfeedInstruction(
+StatusOr<ComputationDataHandle> UserComputation::AddOutfeedInstruction(
const OutfeedRequest& outfeed_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -1197,8 +1197,6 @@ Status UserComputation::AddOutfeedInstruction(
// Verify that operand is valid.
TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status());
- // No handle is returned, but a handle must be assigned to this instruction
- // for computation versioning.
ComputationDataHandle handle = CreateComputationDataHandle();
OperationRequest& request =
(*session_computation_.mutable_requests())[handle.handle()];
@@ -1209,7 +1207,7 @@ Status UserComputation::AddOutfeedInstruction(
VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal()
<< "), data handle " << handle.handle() << ": "
<< outfeed_request.ShortDebugString();
- return Status::OK();
+ return handle;
}
StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction(
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 4f92e58877..54bb24d6d7 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -146,7 +146,8 @@ class UserComputation {
const InfeedRequest& infeed_request);
// Enqueues an outfeed instruction onto this user computation.
- Status AddOutfeedInstruction(const OutfeedRequest& outfeed_request);
+ StatusOr<ComputationDataHandle> AddOutfeedInstruction(
+ const OutfeedRequest& outfeed_request);
// Enqueues a call instruction onto this user computation.
StatusOr<ComputationDataHandle> AddCallInstruction(
diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc
index ca02115863..2fa163953f 100644
--- a/tensorflow/compiler/xla/service/user_computation_test.cc
+++ b/tensorflow/compiler/xla/service/user_computation_test.cc
@@ -67,7 +67,8 @@ TEST_F(UserComputationTest, SimpleComputation) {
*outfeed_request.mutable_operand() = constant_handle;
*outfeed_request.mutable_shape() = kVectorShape;
outfeed_request.set_outfeed_config("abc");
- TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request));
+ TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle,
+ computation.AddOutfeedInstruction(outfeed_request));
auto hlo_resolver = [](const VersionedComputationHandle& handle) {
return nullptr;