aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-27 10:26:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-27 10:30:07 -0800
commitbaf6a0c183bb2897ff55ec884b1a6f26cc389bc2 (patch)
tree7ac5e3869f15fa0c61ab456885e99467f40215aa /tensorflow/core
parent2bf95689e547690321c0c5ef237f7df391190b5c (diff)
Optionally store the status code/message in the response
body for RunGraph and RunStep RPCs, to workaround the fact that the RPC subsystem truncates long metadata messages. PiperOrigin-RevId: 180203356
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc9
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.cc120
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.h62
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc10
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.cc8
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc60
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc6
-rw-r--r--tensorflow/core/protobuf/master.proto15
-rw-r--r--tensorflow/core/protobuf/worker.proto17
9 files changed, 304 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 03b65d8cba..dcc25e4426 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -446,7 +446,13 @@ class RunManyGraphs {
// When the index-th call is done, updates the overall status.
void WhenDone(int index, const Status& s) {
TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
- if (!s.ok()) {
+ auto resp = get(index)->resp.get();
+ if (resp->status_code() != error::Code::OK) {
+ // resp->status_code will only be non-OK if s.ok().
+ mutex_lock l(mu_);
+ UpdateStatusLocked(
+ Status(resp->status_code(), resp->status_error_message()));
+ } else if (!s.ok()) {
mutex_lock l(mu_);
UpdateStatusLocked(s);
}
@@ -539,6 +545,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
c->req->set_graph_handle(part.graph_handle);
c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts;
+ c->req->set_store_errors_in_response_body(true);
// If any feeds are provided, send the feed values together
// in the RunGraph request.
// In the partial case, we only want to include feeds provided in the req.
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc
index a4a88e6e3b..66ebb3080a 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers.cc
@@ -93,6 +93,15 @@ const RunOptions& InMemoryRunStepRequest::options() const { return options_; }
RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }
+bool InMemoryRunStepRequest::store_errors_in_response_body() const {
+ return store_errors_in_response_body_;
+}
+
+void InMemoryRunStepRequest::set_store_errors_in_response_body(
+ bool store_errors) {
+ store_errors_in_response_body_ = store_errors;
+}
+
string InMemoryRunStepRequest::DebugString() const {
return ToProto().DebugString();
}
@@ -192,6 +201,15 @@ RunOptions* MutableProtoRunStepRequest::mutable_options() {
return request_.mutable_options();
}
+bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
+ return request_.store_errors_in_response_body();
+}
+
+void MutableProtoRunStepRequest::set_store_errors_in_response_body(
+ bool store_errors) {
+ request_.set_store_errors_in_response_body(store_errors);
+}
+
string MutableProtoRunStepRequest::DebugString() const {
return request_.DebugString();
}
@@ -250,6 +268,10 @@ const RunOptions& ProtoRunStepRequest::options() const {
return request_->options();
}
+bool ProtoRunStepRequest::store_errors_in_response_body() const {
+ return request_->store_errors_in_response_body();
+}
+
string ProtoRunStepRequest::DebugString() const {
return request_->DebugString();
}
@@ -329,6 +351,15 @@ void InMemoryRunGraphRequest::set_is_last_partial_run(
is_last_partial_run_ = is_last_partial_run;
}
+bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
+ return store_errors_in_response_body_;
+}
+
+void InMemoryRunGraphRequest::set_store_errors_in_response_body(
+ bool store_errors) {
+ store_errors_in_response_body_ = store_errors;
+}
+
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) {
proto_version_.reset(new RunGraphRequest);
@@ -437,6 +468,15 @@ void MutableProtoRunGraphRequest::set_is_last_partial_run(
request_.set_is_last_partial_run(is_last_partial_run);
}
+bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
+ return request_.store_errors_in_response_body();
+}
+
+void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
+ bool store_errors) {
+ request_.set_store_errors_in_response_body(store_errors);
+}
+
const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
return request_;
}
@@ -486,6 +526,10 @@ bool ProtoRunGraphRequest::is_last_partial_run() const {
return request_->is_last_partial_run();
}
+bool ProtoRunGraphRequest::store_errors_in_response_body() const {
+ return request_->store_errors_in_response_body();
+}
+
const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
return *request_;
}
@@ -518,6 +562,18 @@ CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
return &cost_graph_;
}
+errors::Code InMemoryRunGraphResponse::status_code() const {
+ return status_.code();
+}
+
+const string& InMemoryRunGraphResponse::status_error_message() const {
+ return status_.error_message();
+}
+
+void InMemoryRunGraphResponse::set_status(const Status& status) {
+ status_ = status;
+}
+
RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
return nullptr;
@@ -574,6 +630,19 @@ CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
return response_.mutable_cost_graph();
}
+errors::Code OwnedProtoRunGraphResponse::status_code() const {
+ return response_.status_code();
+}
+
+const string& OwnedProtoRunGraphResponse::status_error_message() const {
+ return response_.status_error_message();
+}
+
+void OwnedProtoRunGraphResponse::set_status(const Status& status) {
+ response_.set_status_code(status.code());
+ response_.set_status_error_message(status.error_message());
+}
+
RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
@@ -632,6 +701,19 @@ CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
return response_->mutable_cost_graph();
}
+errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
+ return response_->status_code();
+}
+
+const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
+ return response_->status_error_message();
+}
+
+void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
+ response_->set_status_code(status.code());
+ response_->set_status_error_message(status.error_message());
+}
+
RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
return response_;
}
@@ -678,6 +760,18 @@ Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
+errors::Code InMemoryRunStepResponse::status_code() const {
+ return status_.code();
+}
+
+const string& InMemoryRunStepResponse::status_error_message() const {
+ return status_.error_message();
+}
+
+void InMemoryRunStepResponse::set_status(const Status& status) {
+ status_ = status;
+}
+
RunStepResponse* InMemoryRunStepResponse::get_proto() {
LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
return nullptr;
@@ -716,6 +810,19 @@ RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
return response_.mutable_metadata();
}
+errors::Code OwnedProtoRunStepResponse::status_code() const {
+ return response_.status_code();
+}
+
+const string& OwnedProtoRunStepResponse::status_error_message() const {
+ return response_.status_error_message();
+}
+
+void OwnedProtoRunStepResponse::set_status(const Status& status) {
+ response_.set_status_code(status.code());
+ response_.set_status_error_message(status.error_message());
+}
+
RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
@@ -755,6 +862,19 @@ RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
return response_->mutable_metadata();
}
+errors::Code NonOwnedProtoRunStepResponse::status_code() const {
+ return response_->status_code();
+}
+
+const string& NonOwnedProtoRunStepResponse::status_error_message() const {
+ return response_->status_error_message();
+}
+
+void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
+ response_->set_status_code(status.code());
+ response_->set_status_error_message(status.error_message());
+}
+
RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 0e3f5b98cb..7113d73dd7 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -80,6 +80,13 @@ class RunStepRequestWrapper {
// Options for the run call.
virtual const RunOptions& options() const = 0;
+ // If true then some errors, e.g., execution errors that have long
+ // error messages, may return an OK RunStepResponse with the actual
+ // error saved in the status_code/status_error_message fields of the
+ // response body. This is a workaround since the RPC subsystem may
+ // truncate long metadata messages.
+ virtual bool store_errors_in_response_body() const = 0;
+
// Returns a human-readable representation of this message for debugging.
virtual string DebugString() const = 0;
@@ -98,6 +105,7 @@ class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
virtual void add_fetch(const string& name) = 0;
virtual void add_target(const string& name) = 0;
virtual RunOptions* mutable_options() = 0;
+ virtual void set_store_errors_in_response_body(bool store_errors) = 0;
};
// Specialized (and mutable) wrapper for RunStep requests between a client and
@@ -118,6 +126,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
const RunOptions& options() const override;
string DebugString() const override;
const RunStepRequest& ToProto() const override;
+ bool store_errors_in_response_body() const override;
// MutableRunStepRequestWrapper methods.
void set_session_handle(const string& handle) override;
@@ -126,6 +135,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
void add_fetch(const string& name) override;
void add_target(const string& name) override;
RunOptions* mutable_options() override;
+ void set_store_errors_in_response_body(bool store_errors) override;
private:
string session_handle_;
@@ -134,6 +144,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
gtl::InlinedVector<string, 4> fetches_;
gtl::InlinedVector<string, 4> targets_;
RunOptions options_;
+ bool store_errors_in_response_body_ = false;
// Holds a cached and owned representation of the proto
// representation of this request, if needed, so that `ToProto()`
@@ -165,6 +176,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
const RunOptions& options() const override;
string DebugString() const override;
const RunStepRequest& ToProto() const override;
+ bool store_errors_in_response_body() const override;
// MutableRunStepRequestWrapper methods.
void set_session_handle(const string& handle) override;
@@ -173,6 +185,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
void add_fetch(const string& name) override;
void add_target(const string& name) override;
RunOptions* mutable_options() override;
+ void set_store_errors_in_response_body(bool store_errors) override;
private:
RunStepRequest request_;
@@ -202,6 +215,7 @@ class ProtoRunStepRequest : public RunStepRequestWrapper {
const RunOptions& options() const override;
string DebugString() const override;
const RunStepRequest& ToProto() const override;
+ bool store_errors_in_response_body() const override;
private:
const RunStepRequest* const request_; // Not owned.
@@ -262,6 +276,13 @@ class RunGraphRequestWrapper {
// True if this is the last partial run request in a sequence of requests.
virtual bool is_last_partial_run() const = 0;
+ // If true then some errors, e.g., execution errors that have long
+ // error messages, may return an OK RunStepResponse with the actual
+ // error saved in the status_code/status_error_message fields of the
+ // response body. This is a workaround since the RPC subsystem may
+ // truncate long metadata messages.
+ virtual bool store_errors_in_response_body() const = 0;
+
// Returns the wrapped data as a protocol buffer message.
virtual const RunGraphRequest& ToProto() const = 0;
};
@@ -285,6 +306,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
virtual void add_recv_key(const string& recv_key) = 0;
virtual void set_is_partial(bool is_partial) = 0;
virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
+ virtual void set_store_errors_in_response_body(bool store_errors) = 0;
};
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
@@ -302,6 +324,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
bool is_partial() const override;
bool is_last_partial_run() const override;
const RunGraphRequest& ToProto() const override;
+ bool store_errors_in_response_body() const override;
// MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
@@ -314,6 +337,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
void add_recv_key(const string& recv_key) override;
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
+ void set_store_errors_in_response_body(bool store_errors) override;
private:
string session_handle_;
@@ -324,6 +348,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
gtl::InlinedVector<string, 4> recvs_;
bool is_partial_ = false;
bool is_last_partial_run_ = false;
+ bool store_errors_in_response_body_ = false;
// Holds a cached and owned representation of the proto
// representation of this request, if needed, so that `ToProto()`
@@ -349,6 +374,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
const string& recv_key(size_t i) const override;
bool is_partial() const override;
bool is_last_partial_run() const override;
+ bool store_errors_in_response_body() const override;
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
@@ -362,6 +388,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
void add_recv_key(const string& recv_key) override;
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
+ void set_store_errors_in_response_body(bool store_errors) override;
private:
RunGraphRequest request_;
@@ -383,6 +410,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
const string& recv_key(size_t i) const override;
bool is_partial() const override;
bool is_last_partial_run() const override;
+ bool store_errors_in_response_body() const override;
const RunGraphRequest& ToProto() const override;
private:
@@ -429,6 +457,11 @@ class MutableRunGraphResponseWrapper {
virtual GraphDef* mutable_partition_graph(size_t i) = 0;
virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;
+ // Returned status if requested.
+ virtual errors::Code status_code() const = 0;
+ virtual const string& status_error_message() const = 0;
+ virtual void set_status(const Status& status) = 0;
+
protected:
// Returns a mutable protobuf message that represents the contents of
// this wrapper, for passing to an RPC subsystem that will populate
@@ -458,6 +491,9 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
size_t num_partition_graphs() const override;
GraphDef* mutable_partition_graph(size_t i) override;
void AddPartitionGraph(const GraphDef& partition_graph) override;
+ errors::Code status_code() const override;
+ const string& status_error_message() const override;
+ void set_status(const Status& status) override;
protected:
// NOTE: This method is not implemented. See
@@ -469,6 +505,9 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
StepStats step_stats_;
CostGraphDef cost_graph_;
std::vector<GraphDef> partition_graphs_;
+ // Store the code and message separately so that they can be updated
+ // independently by setters.
+ Status status_;
};
// Proto-based message wrapper for use on the client side of the RunGraph RPC.
@@ -485,6 +524,9 @@ class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
size_t num_partition_graphs() const override;
GraphDef* mutable_partition_graph(size_t i) override;
void AddPartitionGraph(const GraphDef& partition_graph) override;
+ errors::Code status_code() const override;
+ const string& status_error_message() const override;
+ void set_status(const Status& status) override;
protected:
RunGraphResponse* get_proto() override;
@@ -509,6 +551,9 @@ class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
size_t num_partition_graphs() const override;
GraphDef* mutable_partition_graph(size_t i) override;
void AddPartitionGraph(const GraphDef& partition_graph) override;
+ errors::Code status_code() const override;
+ const string& status_error_message() const override;
+ void set_status(const Status& status) override;
protected:
RunGraphResponse* get_proto() override;
@@ -558,6 +603,11 @@ class MutableRunStepResponseWrapper {
virtual const RunMetadata& metadata() const = 0;
virtual RunMetadata* mutable_metadata() = 0;
+ // Returned status if requested.
+ virtual errors::Code status_code() const = 0;
+ virtual const string& status_error_message() const = 0;
+ virtual void set_status(const Status& status) = 0;
+
protected:
// Returns a mutable protobuf message that represents the contents of
// this wrapper, for passing to an RPC subsystem that will populate
@@ -585,6 +635,9 @@ class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
size_t i) override;
const RunMetadata& metadata() const override;
RunMetadata* mutable_metadata() override;
+ errors::Code status_code() const override;
+ const string& status_error_message() const override;
+ void set_status(const Status& status) override;
protected:
// NOTE: This method is not implemented. See
@@ -594,6 +647,9 @@ class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
private:
gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
RunMetadata metadata_;
+ // Store the code and message separately so that they can be updated
+ // independently by setters.
+ Status status_;
};
// Proto-based message wrapper for use on the client side of the RunStep RPC.
@@ -608,6 +664,9 @@ class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
size_t i) override;
const RunMetadata& metadata() const override;
RunMetadata* mutable_metadata() override;
+ errors::Code status_code() const override;
+ const string& status_error_message() const override;
+ void set_status(const Status& status) override;
protected:
RunStepResponse* get_proto() override;
@@ -630,6 +689,9 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
size_t i) override;
const RunMetadata& metadata() const override;
RunMetadata* mutable_metadata() override;
+ errors::Code status_code() const override;
+ const string& status_error_message() const override;
+ void set_status(const Status& status) override;
protected:
RunStepResponse* get_proto() override;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index 41ee81c01d..ac27993773 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -192,7 +192,15 @@ class GrpcMasterService : public AsyncServiceInterface {
delete call_opts;
delete wrapped_request;
delete trace;
- call->SendResponse(ToGrpcStatus(status));
+ if (call->request.store_errors_in_response_body() &&
+ !status.ok()) {
+ call->response.set_status_code(status.code());
+ call->response.set_status_error_message(
+ status.error_message());
+ call->SendResponse(ToGrpcStatus(Status::OK()));
+ } else {
+ call->SendResponse(ToGrpcStatus(status));
+ }
});
ENQUEUE_REQUEST(RunStep, true);
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 9a08335c1c..c3325ed2a9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -190,6 +190,9 @@ Status GrpcSession::RunHelper(
req->add_feed(it.first, it.second);
}
+ // Support long error messages by storing the error code in the response body.
+ req->set_store_errors_in_response_body(true);
+
// Build an index from fetch tensor name to first index in
// output_tensor_names.
std::unordered_map<string, int> output_name_to_offset;
@@ -207,6 +210,11 @@ Status GrpcSession::RunHelper(
call_options.SetTimeout(req->options().timeout_in_ms());
TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));
+ // Look for an extended error returned in the response body.
+ if (resp->status_code() != error::Code::OK) {
+ return Status(resp->status_code(), resp->status_error_message());
+ }
+
if (!output_tensor_names.empty()) {
outputs->resize(output_tensor_names.size());
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index b673f200cc..335c3febe2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -572,6 +572,66 @@ TEST(GrpcSessionTest, Error) {
Env::Default()->SleepForMicroseconds(2000000);
}
+TEST(GrpcSessionTest, LongErrorMessage) {
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+ const string& master = cluster->targets()[0];
+ const string& dev_a = cluster->devices()[0].name();
+ const string& dev_b = cluster->devices()[1].name();
+ LOG(INFO) << "master " << master << "dev_a " << dev_a << "dev_b " << dev_b;
+ GraphDef gdef;
+ std::vector<string> fetches;
+ {
+ Graph g(OpRegistry::Global());
+
+ // a2 = a + error(a)
+ //
+ // Subgraph for "a" fails. The master will cancel the subgraph for
+ // "b" and then returns the Session::Run.
+ auto a = test::graph::Constant(&g, Tensor());
+ a->set_assigned_device_name(dev_a);
+ std::vector<char> long_string_buffer(1024 * 1024, 'x');
+ StringPiece long_string(long_string_buffer.data(), 1024 * 1024);
+ string name = strings::StrCat(long_string, "fantasia!");
+ auto a_err = test::graph::Error(&g, a, name);
+ a_err->set_assigned_device_name(dev_a);
+ auto a2 = test::graph::Add(&g, a, a_err);
+ a2->set_assigned_device_name(dev_a);
+ fetches.push_back(a2->name());
+
+ // b2 = b + delay(b)
+ //
+ // Subgraph for "b" sleeps at the node "b_delay". When the sleep
+ // finishes, the subgraph "b" will continue execution till it
+ // notices that it is canceled. Meanwhile, subgraph's executor
+ // and its related state (registered ops) should still be alive.
+ auto b = test::graph::Constant(&g, Tensor());
+ b->set_assigned_device_name(dev_b);
+ auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
+ b_delay->set_assigned_device_name(dev_b);
+ auto b2 = test::graph::Add(&g, b, b_delay);
+ b2->set_assigned_device_name(dev_b);
+ fetches.push_back(b2->name());
+ test::graph::ToGraphDef(&g, &gdef);
+ }
+ std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ TF_CHECK_OK(session->Create(gdef));
+ {
+ Status status = session->Run({}, fetches, {}, nullptr);
+ EXPECT_FALSE(status.ok());
+ EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
+ }
+  // session->Close() shall clean up all states related to the session.
+ // E.g., deregisters subgraph with workers, etc.
+ TF_CHECK_OK(session->Close());
+
+ // Sleep a bit so that most of asynchronous works finishes before
+ // the test process finishes.
+ Env::Default()->SleepForMicroseconds(2000000);
+}
+
TEST(SessionTest, SharedVar) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 6cd92f5fe7..8e303d6c51 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -109,6 +109,12 @@ Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
MutableRunGraphResponseWrapper* response,
StatusCallback done) {
+ if (request->store_errors_in_response_body()) {
+ done = [response, done](const Status& status) {
+ response->set_status(status);
+ done(Status::OK());
+ };
+ }
if (request->is_partial()) {
DoPartialRunGraph(opts, request, response, std::move(done));
} else {
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
index 6b25a86ba4..0437cb1b83 100644
--- a/tensorflow/core/protobuf/master.proto
+++ b/tensorflow/core/protobuf/master.proto
@@ -23,6 +23,7 @@ option java_package = "org.tensorflow.distruntime";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/lib/core/error_codes.proto";
import "tensorflow/core/protobuf/config.proto";
import "tensorflow/core/protobuf/named_tensor.proto";
@@ -129,6 +130,13 @@ message RunStepRequest {
// Partial run handle (optional). If specified, this will be a partial run
// execution, run up to the specified fetches.
string partial_run_handle = 6;
+
+ // If true then some errors, e.g., execution errors that have long
+ // error messages, may return an OK RunStepResponse with the actual
+ // error saved in the status_code/status_error_message fields of the
+ // response body. This is a workaround since the RPC subsystem may
+ // truncate long metadata messages.
+ bool store_errors_in_response_body = 7;
}
message RunStepResponse {
@@ -138,6 +146,13 @@ message RunStepResponse {
// Returned metadata if requested in the options.
RunMetadata metadata = 2;
+
+ // If store_errors_in_response_body is true in the request, then
+ // optionally the server may return an OK status for the RPC and
+ // fill the true status into the fields below, to allow for messages
+ // that are too long to fit in metadata.
+ error.Code status_code = 3;
+ string status_error_message = 4;
}
////////////////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 385e2dd163..9b51db1362 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -27,6 +27,7 @@ import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/lib/core/error_codes.proto";
import "tensorflow/core/protobuf/config.proto";
import "tensorflow/core/protobuf/debug.proto";
import "tensorflow/core/protobuf/named_tensor.proto";
@@ -226,7 +227,14 @@ message RunGraphRequest {
// True if this is the last partial run request in a sequence of requests.
bool is_last_partial_run = 7;
- // Next: 9
+ // If true then some errors, e.g., execution errors that have long
+ // error messages, may return an OK RunGraphResponse with the actual
+ // error saved in the status_code/status_error_message fields of the
+ // response body. This is a workaround since the RPC subsystem may
+ // truncate long metadata messages.
+ bool store_errors_in_response_body = 9;
+
+ // Next: 10
}
message RunGraphResponse {
@@ -240,6 +248,13 @@ message RunGraphResponse {
StepStats step_stats = 2;
CostGraphDef cost_graph = 3;
repeated GraphDef partition_graph = 4;
+
+ // If store_errors_in_response_body is true in the request, then
+ // optionally the server may return an OK status for the RPC and
+ // fill the true status into the fields below, to allow for messages
+ // that are too long to fit in metadata.
+ error.Code status_code = 5;
+ string status_error_message = 6;
}
////////////////////////////////////////////////////////////////////////////////