diff options
author | 2017-12-27 10:26:23 -0800 | |
---|---|---|
committer | 2017-12-27 10:30:07 -0800 | |
commit | baf6a0c183bb2897ff55ec884b1a6f26cc389bc2 (patch) | |
tree | 7ac5e3869f15fa0c61ab456885e99467f40215aa /tensorflow/core | |
parent | 2bf95689e547690321c0c5ef237f7df391190b5c (diff) |
Optionally store the status code/message in the response
body for RunGraph and RunStep RPCs, to work around the fact that the
RPC subsystem truncates long metadata messages.
PiperOrigin-RevId: 180203356
Diffstat (limited to 'tensorflow/core')
9 files changed, 304 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 03b65d8cba..dcc25e4426 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -446,7 +446,13 @@ class RunManyGraphs { // When the index-th call is done, updates the overall status. void WhenDone(int index, const Status& s) { TRACEPRINTF("Partition %d %s", index, s.ToString().c_str()); - if (!s.ok()) { + auto resp = get(index)->resp.get(); + if (resp->status_code() != error::Code::OK) { + // resp->status_code will only be non-OK if s.ok(). + mutex_lock l(mu_); + UpdateStatusLocked( + Status(resp->status_code(), resp->status_error_message())); + } else if (!s.ok()) { mutex_lock l(mu_); UpdateStatusLocked(s); } @@ -539,6 +545,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( c->req->set_graph_handle(part.graph_handle); c->req->set_step_id(step_id); *c->req->mutable_exec_opts() = exec_opts; + c->req->set_store_errors_in_response_body(true); // If any feeds are provided, send the feed values together // in the RunGraph request. // In the partial case, we only want to include feeds provided in the req. 
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index a4a88e6e3b..66ebb3080a 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -93,6 +93,15 @@ const RunOptions& InMemoryRunStepRequest::options() const { return options_; } RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; } +bool InMemoryRunStepRequest::store_errors_in_response_body() const { + return store_errors_in_response_body_; +} + +void InMemoryRunStepRequest::set_store_errors_in_response_body( + bool store_errors) { + store_errors_in_response_body_ = store_errors; +} + string InMemoryRunStepRequest::DebugString() const { return ToProto().DebugString(); } @@ -192,6 +201,15 @@ RunOptions* MutableProtoRunStepRequest::mutable_options() { return request_.mutable_options(); } +bool MutableProtoRunStepRequest::store_errors_in_response_body() const { + return request_.store_errors_in_response_body(); +} + +void MutableProtoRunStepRequest::set_store_errors_in_response_body( + bool store_errors) { + request_.set_store_errors_in_response_body(store_errors); +} + string MutableProtoRunStepRequest::DebugString() const { return request_.DebugString(); } @@ -250,6 +268,10 @@ const RunOptions& ProtoRunStepRequest::options() const { return request_->options(); } +bool ProtoRunStepRequest::store_errors_in_response_body() const { + return request_->store_errors_in_response_body(); +} + string ProtoRunStepRequest::DebugString() const { return request_->DebugString(); } @@ -329,6 +351,15 @@ void InMemoryRunGraphRequest::set_is_last_partial_run( is_last_partial_run_ = is_last_partial_run; } +bool InMemoryRunGraphRequest::store_errors_in_response_body() const { + return store_errors_in_response_body_; +} + +void InMemoryRunGraphRequest::set_store_errors_in_response_body( + bool store_errors) { + store_errors_in_response_body_ = store_errors; +} + 
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { if (!proto_version_) { proto_version_.reset(new RunGraphRequest); @@ -437,6 +468,15 @@ void MutableProtoRunGraphRequest::set_is_last_partial_run( request_.set_is_last_partial_run(is_last_partial_run); } +bool MutableProtoRunGraphRequest::store_errors_in_response_body() const { + return request_.store_errors_in_response_body(); +} + +void MutableProtoRunGraphRequest::set_store_errors_in_response_body( + bool store_errors) { + request_.set_store_errors_in_response_body(store_errors); +} + const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const { return request_; } @@ -486,6 +526,10 @@ bool ProtoRunGraphRequest::is_last_partial_run() const { return request_->is_last_partial_run(); } +bool ProtoRunGraphRequest::store_errors_in_response_body() const { + return request_->store_errors_in_response_body(); +} + const RunGraphRequest& ProtoRunGraphRequest::ToProto() const { return *request_; } @@ -518,6 +562,18 @@ CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() { return &cost_graph_; } +errors::Code InMemoryRunGraphResponse::status_code() const { + return status_.code(); +} + +const string& InMemoryRunGraphResponse::status_error_message() const { + return status_.error_message(); +} + +void InMemoryRunGraphResponse::set_status(const Status& status) { + status_ = status; +} + RunGraphResponse* InMemoryRunGraphResponse::get_proto() { LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse"; return nullptr; @@ -574,6 +630,19 @@ CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() { return response_.mutable_cost_graph(); } +errors::Code OwnedProtoRunGraphResponse::status_code() const { + return response_.status_code(); +} + +const string& OwnedProtoRunGraphResponse::status_error_message() const { + return response_.status_error_message(); +} + +void OwnedProtoRunGraphResponse::set_status(const Status& status) { + response_.set_status_code(status.code()); + 
response_.set_status_error_message(status.error_message()); +} + RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; } size_t OwnedProtoRunGraphResponse::num_partition_graphs() const { @@ -632,6 +701,19 @@ CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() { return response_->mutable_cost_graph(); } +errors::Code NonOwnedProtoRunGraphResponse::status_code() const { + return response_->status_code(); +} + +const string& NonOwnedProtoRunGraphResponse::status_error_message() const { + return response_->status_error_message(); +} + +void NonOwnedProtoRunGraphResponse::set_status(const Status& status) { + response_->set_status_code(status.code()); + response_->set_status_error_message(status.error_message()); +} + RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() { return response_; } @@ -678,6 +760,18 @@ Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse( RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; } +errors::Code InMemoryRunStepResponse::status_code() const { + return status_.code(); +} + +const string& InMemoryRunStepResponse::status_error_message() const { + return status_.error_message(); +} + +void InMemoryRunStepResponse::set_status(const Status& status) { + status_ = status; +} + RunStepResponse* InMemoryRunStepResponse::get_proto() { LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse"; return nullptr; @@ -716,6 +810,19 @@ RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() { return response_.mutable_metadata(); } +errors::Code OwnedProtoRunStepResponse::status_code() const { + return response_.status_code(); +} + +const string& OwnedProtoRunStepResponse::status_error_message() const { + return response_.status_error_message(); +} + +void OwnedProtoRunStepResponse::set_status(const Status& status) { + response_.set_status_code(status.code()); + response_.set_status_error_message(status.error_message()); +} + RunStepResponse* 
OwnedProtoRunStepResponse::get_proto() { return &response_; } NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse( @@ -755,6 +862,19 @@ RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() { return response_->mutable_metadata(); } +errors::Code NonOwnedProtoRunStepResponse::status_code() const { + return response_->status_code(); +} + +const string& NonOwnedProtoRunStepResponse::status_error_message() const { + return response_->status_error_message(); +} + +void NonOwnedProtoRunStepResponse::set_status(const Status& status) { + response_->set_status_code(status.code()); + response_->set_status_error_message(status.error_message()); +} + RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 0e3f5b98cb..7113d73dd7 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -80,6 +80,13 @@ class RunStepRequestWrapper { // Options for the run call. virtual const RunOptions& options() const = 0; + // If true then some errors, e.g., execution errors that have long + // error messages, may return an OK RunStepResponse with the actual + // error saved in the status_code/status_error_message fields of the + // response body. This is a workaround since the RPC subsystem may + // truncate long metadata messages. + virtual bool store_errors_in_response_body() const = 0; + // Returns a human-readable representation of this message for debugging. 
virtual string DebugString() const = 0; @@ -98,6 +105,7 @@ class MutableRunStepRequestWrapper : public RunStepRequestWrapper { virtual void add_fetch(const string& name) = 0; virtual void add_target(const string& name) = 0; virtual RunOptions* mutable_options() = 0; + virtual void set_store_errors_in_response_body(bool store_errors) = 0; }; // Specialized (and mutable) wrapper for RunStep requests between a client and @@ -118,6 +126,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { const RunOptions& options() const override; string DebugString() const override; const RunStepRequest& ToProto() const override; + bool store_errors_in_response_body() const override; // MutableRunStepRequestWrapper methods. void set_session_handle(const string& handle) override; @@ -126,6 +135,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { void add_fetch(const string& name) override; void add_target(const string& name) override; RunOptions* mutable_options() override; + void set_store_errors_in_response_body(bool store_errors) override; private: string session_handle_; @@ -134,6 +144,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { gtl::InlinedVector<string, 4> fetches_; gtl::InlinedVector<string, 4> targets_; RunOptions options_; + bool store_errors_in_response_body_ = false; // Holds a cached and owned representation of the proto // representation of this request, if needed, so that `ToProto()` @@ -165,6 +176,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { const RunOptions& options() const override; string DebugString() const override; const RunStepRequest& ToProto() const override; + bool store_errors_in_response_body() const override; // MutableRunStepRequestWrapper methods. 
void set_session_handle(const string& handle) override; @@ -173,6 +185,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { void add_fetch(const string& name) override; void add_target(const string& name) override; RunOptions* mutable_options() override; + void set_store_errors_in_response_body(bool store_errors) override; private: RunStepRequest request_; @@ -202,6 +215,7 @@ class ProtoRunStepRequest : public RunStepRequestWrapper { const RunOptions& options() const override; string DebugString() const override; const RunStepRequest& ToProto() const override; + bool store_errors_in_response_body() const override; private: const RunStepRequest* const request_; // Not owned. @@ -262,6 +276,13 @@ class RunGraphRequestWrapper { // True if this is the last partial run request in a sequence of requests. virtual bool is_last_partial_run() const = 0; + // If true then some errors, e.g., execution errors that have long + // error messages, may return an OK RunStepResponse with the actual + // error saved in the status_code/status_error_message fields of the + // response body. This is a workaround since the RPC subsystem may + // truncate long metadata messages. + virtual bool store_errors_in_response_body() const = 0; + // Returns the wrapped data as a protocol buffer message. 
virtual const RunGraphRequest& ToProto() const = 0; }; @@ -285,6 +306,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { virtual void add_recv_key(const string& recv_key) = 0; virtual void set_is_partial(bool is_partial) = 0; virtual void set_is_last_partial_run(bool is_last_partial_run) = 0; + virtual void set_store_errors_in_response_body(bool store_errors) = 0; }; class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { @@ -302,6 +324,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { bool is_partial() const override; bool is_last_partial_run() const override; const RunGraphRequest& ToProto() const override; + bool store_errors_in_response_body() const override; // MutableRunGraphRequestWrapper methods. void set_session_handle(const string& handle) override; @@ -314,6 +337,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { void add_recv_key(const string& recv_key) override; void set_is_partial(bool is_partial) override; void set_is_last_partial_run(bool is_last_partial_run) override; + void set_store_errors_in_response_body(bool store_errors) override; private: string session_handle_; @@ -324,6 +348,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { gtl::InlinedVector<string, 4> recvs_; bool is_partial_ = false; bool is_last_partial_run_ = false; + bool store_errors_in_response_body_ = false; // Holds a cached and owned representation of the proto // representation of this request, if needed, so that `ToProto()` @@ -349,6 +374,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { const string& recv_key(size_t i) const override; bool is_partial() const override; bool is_last_partial_run() const override; + bool store_errors_in_response_body() const override; const RunGraphRequest& ToProto() const override; // MutableRunGraphRequestWrapper methods. 
@@ -362,6 +388,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { void add_recv_key(const string& recv_key) override; void set_is_partial(bool is_partial) override; void set_is_last_partial_run(bool is_last_partial_run) override; + void set_store_errors_in_response_body(bool store_errors) override; private: RunGraphRequest request_; @@ -383,6 +410,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper { const string& recv_key(size_t i) const override; bool is_partial() const override; bool is_last_partial_run() const override; + bool store_errors_in_response_body() const override; const RunGraphRequest& ToProto() const override; private: @@ -429,6 +457,11 @@ class MutableRunGraphResponseWrapper { virtual GraphDef* mutable_partition_graph(size_t i) = 0; virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; + // Returned status if requested. + virtual errors::Code status_code() const = 0; + virtual const string& status_error_message() const = 0; + virtual void set_status(const Status& status) = 0; + protected: // Returns a mutable protobuf message that represents the contents of // this wrapper, for passing to an RPC subsystem that will populate @@ -458,6 +491,9 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { size_t num_partition_graphs() const override; GraphDef* mutable_partition_graph(size_t i) override; void AddPartitionGraph(const GraphDef& partition_graph) override; + errors::Code status_code() const override; + const string& status_error_message() const override; + void set_status(const Status& status) override; protected: // NOTE: This method is not implemented. See @@ -469,6 +505,9 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { StepStats step_stats_; CostGraphDef cost_graph_; std::vector<GraphDef> partition_graphs_; + // Store the code and message separately so that they can be updated + // independently by setters. 
+ Status status_; }; // Proto-based message wrapper for use on the client side of the RunGraph RPC. @@ -485,6 +524,9 @@ class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { size_t num_partition_graphs() const override; GraphDef* mutable_partition_graph(size_t i) override; void AddPartitionGraph(const GraphDef& partition_graph) override; + errors::Code status_code() const override; + const string& status_error_message() const override; + void set_status(const Status& status) override; protected: RunGraphResponse* get_proto() override; @@ -509,6 +551,9 @@ class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { size_t num_partition_graphs() const override; GraphDef* mutable_partition_graph(size_t i) override; void AddPartitionGraph(const GraphDef& partition_graph) override; + errors::Code status_code() const override; + const string& status_error_message() const override; + void set_status(const Status& status) override; protected: RunGraphResponse* get_proto() override; @@ -558,6 +603,11 @@ class MutableRunStepResponseWrapper { virtual const RunMetadata& metadata() const = 0; virtual RunMetadata* mutable_metadata() = 0; + // Returned status if requested. + virtual errors::Code status_code() const = 0; + virtual const string& status_error_message() const = 0; + virtual void set_status(const Status& status) = 0; + protected: // Returns a mutable protobuf message that represents the contents of // this wrapper, for passing to an RPC subsystem that will populate @@ -585,6 +635,9 @@ class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { size_t i) override; const RunMetadata& metadata() const override; RunMetadata* mutable_metadata() override; + errors::Code status_code() const override; + const string& status_error_message() const override; + void set_status(const Status& status) override; protected: // NOTE: This method is not implemented. 
See @@ -594,6 +647,9 @@ class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { private: gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_; RunMetadata metadata_; + // Store the code and message separately so that they can be updated + // independently by setters. + Status status_; }; // Proto-based message wrapper for use on the client side of the RunStep RPC. @@ -608,6 +664,9 @@ class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { size_t i) override; const RunMetadata& metadata() const override; RunMetadata* mutable_metadata() override; + errors::Code status_code() const override; + const string& status_error_message() const override; + void set_status(const Status& status) override; protected: RunStepResponse* get_proto() override; @@ -630,6 +689,9 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { size_t i) override; const RunMetadata& metadata() const override; RunMetadata* mutable_metadata() override; + errors::Code status_code() const override; + const string& status_error_message() const override; + void set_status(const Status& status) override; protected: RunStepResponse* get_proto() override; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index 41ee81c01d..ac27993773 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -192,7 +192,15 @@ class GrpcMasterService : public AsyncServiceInterface { delete call_opts; delete wrapped_request; delete trace; - call->SendResponse(ToGrpcStatus(status)); + if (call->request.store_errors_in_response_body() && + !status.ok()) { + call->response.set_status_code(status.code()); + call->response.set_status_error_message( + status.error_message()); + call->SendResponse(ToGrpcStatus(Status::OK())); + } else { + call->SendResponse(ToGrpcStatus(status)); + } }); 
ENQUEUE_REQUEST(RunStep, true); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 9a08335c1c..c3325ed2a9 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -190,6 +190,9 @@ Status GrpcSession::RunHelper( req->add_feed(it.first, it.second); } + // Support long error messages by storing the error code in the response body. + req->set_store_errors_in_response_body(true); + // Build an index from fetch tensor name to first index in // output_tensor_names. std::unordered_map<string, int> output_name_to_offset; @@ -207,6 +210,11 @@ Status GrpcSession::RunHelper( call_options.SetTimeout(req->options().timeout_in_ms()); TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get())); + // Look for an extended error returned in the response body. + if (resp->status_code() != error::Code::OK) { + return Status(resp->status_code(), resp->status_error_message()); + } + if (!output_tensor_names.empty()) { outputs->resize(output_tensor_names.size()); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index b673f200cc..335c3febe2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -572,6 +572,66 @@ TEST(GrpcSessionTest, Error) { Env::Default()->SleepForMicroseconds(2000000); } +TEST(GrpcSessionTest, LongErrorMessage) { + std::unique_ptr<test::TestCluster> cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + const string& master = cluster->targets()[0]; + const string& dev_a = cluster->devices()[0].name(); + const string& dev_b = cluster->devices()[1].name(); + LOG(INFO) << "master " << master << "dev_a " << dev_a << "dev_b " << dev_b; + GraphDef gdef; + std::vector<string> fetches; + { + Graph 
g(OpRegistry::Global()); + + // a2 = a + error(a) + // + // Subgraph for "a" fails. The master will cancel the subgraph for + // "b" and then returns the Session::Run. + auto a = test::graph::Constant(&g, Tensor()); + a->set_assigned_device_name(dev_a); + std::vector<char> long_string_buffer(1024 * 1024, 'x'); + StringPiece long_string(long_string_buffer.data(), 1024 * 1024); + string name = strings::StrCat(long_string, "fantasia!"); + auto a_err = test::graph::Error(&g, a, name); + a_err->set_assigned_device_name(dev_a); + auto a2 = test::graph::Add(&g, a, a_err); + a2->set_assigned_device_name(dev_a); + fetches.push_back(a2->name()); + + // b2 = b + delay(b) + // + // Subgraph for "b" sleeps at the node "b_delay". When the sleep + // finishes, the subgraph "b" will continue execution till it + // notices that it is canceled. Meanwhile, subgraph's executor + // and its related state (registered ops) should still be alive. + auto b = test::graph::Constant(&g, Tensor()); + b->set_assigned_device_name(dev_b); + auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000)); + b_delay->set_assigned_device_name(dev_b); + auto b2 = test::graph::Add(&g, b, b_delay); + b2->set_assigned_device_name(dev_b); + fetches.push_back(b2->name()); + test::graph::ToGraphDef(&g, &gdef); + } + std::unique_ptr<Session> session(NewRemote(Options(master, 1))); + ASSERT_TRUE(session != nullptr); + + TF_CHECK_OK(session->Create(gdef)); + { + Status status = session->Run({}, fetches, {}, nullptr); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.ToString().find("fantasia!"), string::npos); + } + // session->Close() shall clean up all states related to the session-> + // E.g., deregisters subgraph with workers, etc. + TF_CHECK_OK(session->Close()); + + // Sleep a bit so that most of asynchronous works finishes before + // the test process finishes. 
+ Env::Default()->SleepForMicroseconds(2000000); +} + TEST(SessionTest, SharedVar) { std::unique_ptr<test::TestCluster> cluster; TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster)); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 6cd92f5fe7..8e303d6c51 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -109,6 +109,12 @@ Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req, void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, MutableRunGraphResponseWrapper* response, StatusCallback done) { + if (request->store_errors_in_response_body()) { + done = [response, done](const Status& status) { + response->set_status(status); + done(Status::OK()); + }; + } if (request->is_partial()) { DoPartialRunGraph(opts, request, response, std::move(done)); } else { diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto index 6b25a86ba4..0437cb1b83 100644 --- a/tensorflow/core/protobuf/master.proto +++ b/tensorflow/core/protobuf/master.proto @@ -23,6 +23,7 @@ option java_package = "org.tensorflow.distruntime"; import "tensorflow/core/framework/device_attributes.proto"; import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/lib/core/error_codes.proto"; import "tensorflow/core/protobuf/config.proto"; import "tensorflow/core/protobuf/named_tensor.proto"; @@ -129,6 +130,13 @@ message RunStepRequest { // Partial run handle (optional). If specified, this will be a partial run // execution, run up to the specified fetches. string partial_run_handle = 6; + + // If true then some errors, e.g., execution errors that have long + // error messages, may return an OK RunStepResponse with the actual + // error saved in the status_code/status_error_message fields of the + // response body. 
This is a workaround since the RPC subsystem may + // truncate long metadata messages. + bool store_errors_in_response_body = 7; } message RunStepResponse { @@ -138,6 +146,13 @@ message RunStepResponse { // Returned metadata if requested in the options. RunMetadata metadata = 2; + + // If store_errors_in_response_body is true in the request, then + // optionally the server may return an OK status for the RPC and + // fill the true status into the fields below, to allow for messages + // that are too long to fit in metadata. + error.Code status_code = 3; + string status_error_message = 4; } //////////////////////////////////////////////////////////////////////////////// diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 385e2dd163..9b51db1362 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -27,6 +27,7 @@ import "tensorflow/core/framework/step_stats.proto"; import "tensorflow/core/framework/device_attributes.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/lib/core/error_codes.proto"; import "tensorflow/core/protobuf/config.proto"; import "tensorflow/core/protobuf/debug.proto"; import "tensorflow/core/protobuf/named_tensor.proto"; @@ -226,7 +227,14 @@ message RunGraphRequest { // True if this is the last partial run request in a sequence of requests. bool is_last_partial_run = 7; - // Next: 9 + // If true then some errors, e.g., execution errors that have long + // error messages, may return an OK RunGraphResponse with the actual + // error saved in the status_code/status_error_message fields of the + // response body. This is a workaround since the RPC subsystem may + // truncate long metadata messages. 
+ bool store_errors_in_response_body = 9; + + // Next: 10 } message RunGraphResponse { @@ -240,6 +248,13 @@ message RunGraphResponse { StepStats step_stats = 2; CostGraphDef cost_graph = 3; repeated GraphDef partition_graph = 4; + + // If store_errors_in_response_body is true in the request, then + // optionally the server may return an OK status for the RPC and + // fill the true status into the fields below, to allow for messages + // that are too long to fit in metadata. + error.Code status_code = 5; + string status_error_message = 6; } //////////////////////////////////////////////////////////////////////////////// |