diff options
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 67 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.h | 22 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session_test.cc | 98 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/call_options.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/call_options.h | 7 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/master_interface.h | 18 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc | 30 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_session.cc | 81 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_session.h | 22 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc | 88 | ||||
-rw-r--r-- | tensorflow/core/framework/config.proto | 8 | ||||
-rw-r--r-- | tensorflow/core/lib/core/notification.h | 9 | ||||
-rw-r--r-- | tensorflow/core/public/session.h | 30 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/fifo_queue_test.py | 16 |
15 files changed, 445 insertions, 63 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 3dd713d99a..e607e226e6 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -127,7 +127,8 @@ DirectSession::DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr) : options_(options), device_mgr_(device_mgr), - cancellation_manager_(new CancellationManager()) { + cancellation_manager_(new CancellationManager()), + operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) { if (options_.config.use_per_session_threads()) { thread_pool_ = NewThreadPool(options_); } else { @@ -254,16 +255,16 @@ Status DirectSession::Run(const NamedTensorList& inputs, const std::vector<string>& target_nodes, std::vector<Tensor>* outputs) { RunOutputs run_outputs; - return RunWithOpts(RunOptions(), inputs, output_names, target_nodes, outputs, - &run_outputs); + return Run(RunOptions(), inputs, output_names, target_nodes, outputs, + &run_outputs); } -Status DirectSession::RunWithOpts(const RunOptions& run_options, - const NamedTensorList& inputs, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - std::vector<Tensor>* outputs, - RunOutputs* run_outputs) { +Status DirectSession::Run(const RunOptions& run_options, + const NamedTensorList& inputs, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + std::vector<Tensor>* outputs, + RunOutputs* run_outputs) { { mutex_lock l(graph_def_lock_); if (!graph_created_) { @@ -297,7 +298,10 @@ Status DirectSession::RunWithOpts(const RunOptions& run_options, const int num_executors = executors_and_keys->items.size(); ExecutorBarrier* barrier = new ExecutorBarrier( num_executors, run_state.rendez, [&run_state](const Status& ret) { - run_state.status = ret; + { + mutex_lock l(run_state.mu_); + run_state.status.Update(ret); + } run_state.executors_done.Notify(); }); @@ -319,13 +323,18 @@ Status DirectSession::RunWithOpts(const RunOptions& run_options, item.executor->RunAsync(args, barrier->Get()); } - run_state.executors_done.WaitForNotification(); + WaitForNotification(&run_state, run_options.timeout_in_ms() > 0 + ? run_options.timeout_in_ms() + : operation_timeout_in_ms_); if (run_options.trace_level() == RunOptions::FULL_TRACE) { delete args.stats_collector; } - TF_RETURN_IF_ERROR(run_state.status); + { + mutex_lock l(run_state.mu_); + TF_RETURN_IF_ERROR(run_state.status); + } // Receive outputs. TF_RETURN_IF_ERROR( @@ -459,9 +468,12 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, mutex_lock l(executor_lock_); bool done = true; if (s.ok()) { - if (!run_state->status.ok()) { - LOG(WARNING) << "An error unrelated to this prun has been detected. " - << run_state->status; + { + mutex_lock l(run_state->mu_); + if (!run_state->status.ok()) { + LOG(WARNING) << "An error unrelated to this prun has been detected. " + << run_state->status; + } } for (const auto& it : inputs) { run_state->pending_inputs.erase(it.first); @@ -472,7 +484,7 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, done = run_state->pending_outputs.size() == 0; } if (done) { - run_state->executors_done.WaitForNotification(); + WaitForNotification(run_state, operation_timeout_in_ms_); partial_runs_.erase(handle); delete run_state; } @@ -899,6 +911,29 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds, return ::tensorflow::Status::OK(); } +void DirectSession::WaitForNotification(RunState* run_state, + int64 timeout_in_ms) { + if (timeout_in_ms > 0) { + bool timed_out = + run_state->executors_done.WaitForNotificationWithTimeout(timeout_in_ms); + if (timed_out) { + { + mutex_lock l(run_state->mu_); + run_state->status.Update(Status(error::DEADLINE_EXCEEDED, + "Timed out waiting for notification")); + } + // TODO(sherrym): This cancels all steps in the session, even ones that + // have not exceeded their deadline. An alternative would be to use a + // two-level cancellation manager with a Session-global one containing + // several step-local ones. Probably the RunState should have its own + // CancellationManager. + cancellation_manager_->StartCancel(); + } + } else { + run_state->executors_done.WaitForNotification(); + } +} + class DirectSessionFactory : public SessionFactory { public: DirectSessionFactory() {} diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index cf6c1f2eb3..c172265008 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -61,12 +61,12 @@ class DirectSession : public Session { std::vector<Tensor>* outputs) override; // NOTE: Experimental and subject to change. - ::tensorflow::Status RunWithOpts(const RunOptions& run_options, - const NamedTensorList& inputs, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - std::vector<Tensor>* outputs, - RunOutputs* run_outputs) override; + ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options, + const NamedTensorList& inputs, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + std::vector<Tensor>* outputs, + RunOutputs* run_outputs) override; // NOTE: PRunSetup and PRun are added to support partial execution. This // feature is experimental and subject to change. @@ -121,7 +121,8 @@ class DirectSession : public Session { // is "notified" when all executors are done. 'pending_inputs' are the set // of pending feeds and 'pending_outputs' are the set of pending fetches. struct RunState { - Status status; + mutex mu_; + Status status GUARDED_BY(mu_); IntraProcessRendezvous* rendez = nullptr; Notification executors_done; std::unordered_set<string> pending_inputs; @@ -194,6 +195,10 @@ class DirectSession : public Session { const std::vector<string>& fetches, const ExecutorsAndKeys* executors_and_keys, const RunState* run_state); + // Use the appropriate WaitForNotification function based on whether + // operation_timeout_in_ms is greater than 0. + void WaitForNotification(RunState* run_state, int64 timeout_in_ms); + const SessionOptions options_; // Device structures. @@ -242,6 +247,9 @@ class DirectSession : public Session { // For generating step ids that are unique across all sessions. static std::atomic_int_fast64_t step_id_counter_; + // Global timeout for all blocking operations in this session. + const int64 operation_timeout_in_ms_ = 0; + TF_DISALLOW_COPY_AND_ASSIGN(DirectSession); }; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 5c3f83b399..0cad812300 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -273,8 +273,8 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) { RunOutputs run_outputs; EXPECT_EQ(run_outputs.step_stats().dev_stats_size(), 0); - Status s = session->RunWithOpts(run_options, inputs, output_names, - target_nodes, &outputs, &run_outputs); + Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, &run_outputs); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); @@ -560,5 +560,99 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) { ASSERT_EQ(true, outputs[0].flat<bool>()(0)); } +TEST(DirectSessionTest, TimeoutSession) { + GraphDef graph; + // Creates a graph with one FIFOQueue and one dequeue op. + protobuf::TextFormat::ParseFromString(R"proto( + node { + name: 'fifo_queue' + op: 'FIFOQueue' + device: '/device:CPU:0' + attr { + key: 'capacity' + value { + i: 10 + } + } + attr { + key: 'component_types' + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: 'container' + value { + s: '' + } + } + attr { + key: 'shapes' + value { + list { + } + } + } + attr { + key: 'shared_name' + value { + s: '' + } + } + } + node { + name: 'fifo_queue_Dequeue' + op: 'QueueDequeue' + input: 'fifo_queue' + device: '/device:CPU:0' + attr { + key: 'component_types' + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: 'timeout_ms' + value { + i: -1 + } + } + } + versions { + producer: 9 + } + )proto", + &graph); + + // Creates a session with operation_timeout_in_ms set to 100 milliseconds. + SessionOptions options; + (*options.config.mutable_device_count())["CPU"] = 2; + options.config.set_operation_timeout_in_ms(100); + std::unique_ptr<Session> session(NewSession(options)); + ASSERT_TRUE(session != nullptr); + TF_ASSERT_OK(session->Create(graph)); + + // Verifies that the error code is DEADLINE_EXCEEDED. + Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr); + ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code()); + session->Close(); + + // Creates a session with no operation_timeout_in_ms. + session.reset(CreateSession()); + ASSERT_TRUE(session != nullptr); + TF_ASSERT_OK(session->Create(graph)); + RunOptions run_options; + run_options.set_timeout_in_ms(20); + // Verifies that the error code is DEADLINE_EXCEEDED. + Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"}, nullptr, + nullptr); + ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code()); + session->Close(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/call_options.cc b/tensorflow/core/distributed_runtime/call_options.cc index b9d583b754..a99cec9205 100644 --- a/tensorflow/core/distributed_runtime/call_options.cc +++ b/tensorflow/core/distributed_runtime/call_options.cc @@ -41,4 +41,14 @@ void CallOptions::ClearCancelCallback() { cancel_func_ = nullptr; } +int64 CallOptions::GetTimeout() { + mutex_lock l(mu_); + return timeout_in_ms_; +} + +void CallOptions::SetTimeout(int64 ms) { + mutex_lock l(mu_); + timeout_in_ms_ = ms; +} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/call_options.h b/tensorflow/core/distributed_runtime/call_options.h index de0b85f692..6a9967a35b 100644 --- a/tensorflow/core/distributed_runtime/call_options.h +++ b/tensorflow/core/distributed_runtime/call_options.h @@ -60,10 +60,17 @@ class CallOptions { void SetCancelCallback(CancelFunction cancel_func); void ClearCancelCallback(); + // Get and set operation timeout. Timeout value is in milliseconds. + int64 GetTimeout(); + void SetTimeout(int64 ms); + private: mutex mu_; CancelFunction cancel_func_ GUARDED_BY(mu_); + // RPC operation timeout in milliseconds. + int64 timeout_in_ms_ GUARDED_BY(mu_); + TF_DISALLOW_COPY_AND_ASSIGN(CallOptions); }; diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h index 602cfbd8a3..365447b657 100644 --- a/tensorflow/core/distributed_runtime/master_interface.h +++ b/tensorflow/core/distributed_runtime/master_interface.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/master.pb.h" @@ -28,22 +29,27 @@ namespace tensorflow { class MasterInterface { public: virtual ~MasterInterface() {} - virtual Status CreateSession(const CreateSessionRequest* request, + virtual Status CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, CreateSessionResponse* response) = 0; - virtual Status ExtendSession(const ExtendSessionRequest* request, + virtual Status ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, ExtendSessionResponse* response) = 0; - virtual Status RunStep(const RunStepRequest* request, + virtual Status RunStep(CallOptions* call_options, + const RunStepRequest* request, RunStepResponse* response) = 0; - virtual Status CloseSession(const CloseSessionRequest* request, + virtual Status CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, CloseSessionResponse* response) = 0; - virtual Status ListDevices(const ListDevicesRequest* request, + virtual Status ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, ListDevicesResponse* response) = 0; - virtual Status Reset(const ResetRequest* request, + virtual Status Reset(CallOptions* call_options, const ResetRequest* request, ResetResponse* response) = 0; }; diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index b66719620d..7e2091d114 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -165,6 +165,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:master_proto_cc", "//tensorflow/core:master_service_proto_cc", + "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:master_interface", ], alwayslink = 1, @@ -313,6 +314,7 @@ cc_library( "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", + "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:master_interface", ], alwayslink = 1, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index e358aed31f..0b084a0b02 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -33,43 +34,60 @@ class GrpcRemoteMaster : public MasterInterface { ~GrpcRemoteMaster() override {} - Status CreateSession(const CreateSessionRequest* request, + Status CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, CreateSessionResponse* response) override { ::grpc::ClientContext ctx; + SetDeadline(&ctx, call_options->GetTimeout()); return FromGrpcStatus(stub_->CreateSession(&ctx, *request, response)); } - Status ExtendSession(const ExtendSessionRequest* request, + Status ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, ExtendSessionResponse* response) override { ::grpc::ClientContext ctx; + SetDeadline(&ctx, call_options->GetTimeout()); return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response)); } - Status RunStep(const RunStepRequest* request, + Status RunStep(CallOptions* call_options, const RunStepRequest* request, RunStepResponse* response) override { ::grpc::ClientContext ctx; + SetDeadline(&ctx, call_options->GetTimeout()); return FromGrpcStatus(stub_->RunStep(&ctx, *request, response)); } - Status CloseSession(const CloseSessionRequest* request, + Status CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, CloseSessionResponse* response) override { ::grpc::ClientContext ctx; + SetDeadline(&ctx, call_options->GetTimeout()); return FromGrpcStatus(stub_->CloseSession(&ctx, *request, response)); } - Status ListDevices(const ListDevicesRequest* request, + Status ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, ListDevicesResponse* response) override { ::grpc::ClientContext ctx; + SetDeadline(&ctx, call_options->GetTimeout()); return FromGrpcStatus(stub_->ListDevices(&ctx, *request, response)); } - Status Reset(const ResetRequest* request, ResetResponse* response) override { + Status Reset(CallOptions* call_options, const ResetRequest* request, + ResetResponse* response) override { ::grpc::ClientContext ctx; + SetDeadline(&ctx, call_options->GetTimeout()); return FromGrpcStatus(stub_->Reset(&ctx, *request, response)); } private: std::unique_ptr<grpc::MasterService::Stub> stub_; + + void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms) { + if (time_in_ms > 0) { + ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN)); + } + } }; MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 6924fc5537..30edc5d062 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -18,6 +18,7 @@ limitations under the License. #include <unordered_map> #include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h" @@ -67,7 +68,8 @@ void ReEncodeConsts(GraphDef* gdef) { } } // namespace -Status GrpcSession::Create(const GraphDef& graph) { +Status GrpcSession::CreateImpl(CallOptions* call_options, + const GraphDef& graph) { if (!handle_.empty()) { return errors::InvalidArgument("A session is alive."); } @@ -76,7 +78,7 @@ Status GrpcSession::Create(const GraphDef& graph) { *req.mutable_graph_def() = graph; ReEncodeConsts(req.mutable_graph_def()); CreateSessionResponse resp; - Status s = master_->CreateSession(&req, &resp); + Status s = master_->CreateSession(call_options, &req, &resp); if (s.ok()) { mutex_lock l(mu_); swap(handle_, *(resp.mutable_session_handle())); @@ -85,7 +87,21 @@ Status GrpcSession::Create(const GraphDef& graph) { return s; } -Status GrpcSession::Extend(const GraphDef& graph) { +Status GrpcSession::Create(const GraphDef& graph) { + CallOptions call_options; + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + return CreateImpl(&call_options, graph); +} + +Status GrpcSession::Create(const RunOptions& run_options, + const GraphDef& graph) { + CallOptions call_options; + call_options.SetTimeout(run_options.timeout_in_ms()); + return CreateImpl(&call_options, graph); +} + +Status GrpcSession::ExtendImpl(CallOptions* call_options, + const GraphDef& graph) { if (handle_.empty()) { // Session was unitialized, so simply initialize the session with 'graph'. return Create(graph); @@ -96,17 +112,31 @@ Status GrpcSession::Extend(const GraphDef& graph) { *req.mutable_graph_def() = graph; req.set_current_graph_version(current_graph_version_); ExtendSessionResponse resp; - Status s = master_->ExtendSession(&req, &resp); + Status s = master_->ExtendSession(call_options, &req, &resp); if (s.ok()) { current_graph_version_ = resp.new_graph_version(); } return s; } -Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - std::vector<Tensor>* outputs) { +Status GrpcSession::Extend(const GraphDef& graph) { + CallOptions call_options; + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + return ExtendImpl(&call_options, graph); +} + +Status GrpcSession::Extend(const RunOptions& run_options, + const GraphDef& graph) { + CallOptions call_options; + call_options.SetTimeout(run_options.timeout_in_ms()); + return ExtendImpl(&call_options, graph); +} + +Status GrpcSession::Run(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor>>& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, RunOutputs* run_outputs) { // Convert to proto RunStepRequest req; RunStepResponse resp; @@ -121,19 +151,21 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, // Build an index from fetch tensor name to offset. std::unordered_map<string, int> output_name_to_offset; - for (const string& output_name : output_names) { + for (const string& output_name : output_tensor_names) { req.add_fetch(output_name); output_name_to_offset.insert( std::make_pair(output_name, output_name_to_offset.size())); } - for (const string& target : target_nodes) { + for (const string& target : target_node_names) { req.add_target(target); } - TF_RETURN_IF_ERROR(RunProto(&req, &resp)); + CallOptions call_options; + call_options.SetTimeout(run_options.timeout_in_ms()); + TF_RETURN_IF_ERROR(RunProto(&call_options, &req, &resp)); - if (!output_names.empty()) { - outputs->resize(output_names.size()); + if (!output_tensor_names.empty()) { + outputs->resize(output_tensor_names.size()); } // Convert response back to Tensors in the correct order. @@ -156,13 +188,24 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, return Status::OK(); } -Status GrpcSession::RunProto(RunStepRequest* req, RunStepResponse* resp) { +Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs) { + RunOptions run_options; + run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms()); + return Run(run_options, inputs, output_tensor_names, target_node_names, + outputs, nullptr); +} + +Status GrpcSession::RunProto(CallOptions* call_options, RunStepRequest* req, + RunStepResponse* resp) { if (handle_.empty()) { return errors::InvalidArgument("A session is not created yet...."); } req->set_session_handle(handle_); - return master_->RunStep(req, resp); + return master_->RunStep(call_options, req, resp); } Status GrpcSession::PRunSetup(const std::vector<string>& input_names, @@ -187,7 +230,9 @@ Status GrpcSession::Close() { req.set_session_handle(handle_); handle_.clear(); CloseSessionResponse resp; - return master_->CloseSession(&req, &resp); + CallOptions call_options; + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + return master_->CloseSession(&call_options, &req, &resp); } std::vector<DeviceAttributes> GrpcSession::ListDevices() { @@ -195,7 +240,9 @@ std::vector<DeviceAttributes> GrpcSession::ListDevices() { ListDevicesRequest req; ListDevicesResponse resp; - Status s = master_->ListDevices(&req, &resp); + CallOptions call_options; + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + Status s = master_->ListDevices(&call_options, &req, &resp); if (!s.ok()) { LOG(ERROR) << "Could not list devices: " << s; return devices; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h index 9bc6034ba6..abf7b2a44a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -19,6 +19,8 @@ limitations under the License. #include <string> #include <vector> +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/framework/config.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" @@ -52,13 +54,22 @@ class GrpcSession : public Session { // the graph computation defined by "graph", and will have version // number "initial_version". Status Create(const GraphDef& graph) override; + Status Create(const RunOptions& run_options, const GraphDef& graph) override; + // Runs with and without RunOptions. Status Run(const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, std::vector<Tensor>* outputs) override; + Status Run(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, RunOutputs* run_outputs); Status Extend(const GraphDef& graph) override; + Status Extend(const RunOptions& run_options, const GraphDef& graph) override; + Status Close() override; // NOTE: This API is still experimental and may change. @@ -87,7 +98,12 @@ class GrpcSession : public Session { // The current version of the graph. int64 current_graph_version_ GUARDED_BY(mu_); - Status RunProto(RunStepRequest* req, RunStepResponse* resp); + Status RunProto(CallOptions* call_options, RunStepRequest* req, + RunStepResponse* resp); + + // Implementations for all the public interfaces. + Status CreateImpl(CallOptions* call_options, const GraphDef& graph); + Status ExtendImpl(CallOptions* call_options, const GraphDef& graph); TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession); }; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 86a9b07c2c..7b7507515e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -746,5 +746,93 @@ TEST(SessionTest, ExtendValidation) { EXPECT_NE(s.error_message().find("'b', which was created by a previous call"), string::npos); } +// Tests that Create() with "operation_timeout_in_ms" set times out. +TEST(SessionTest, CreateTimeoutWithSessionOptions) { + // Creates a RemoteSession with "operation_timeout_in_ms" set to 100. + SessionOptions options = Options("example.org", 1); + options.config.set_operation_timeout_in_ms(100); + std::unique_ptr<Session> session(NewRemote(options)); + + // Creates a long running op. + Graph graph(OpRegistry::Global()); + Node* b = test::graph::Constant(&graph, Tensor()); + test::graph::Delay(&graph, b, Microseconds(1000000)); + GraphDef gdef; + test::graph::ToGraphDef(&graph, &gdef); + Status status = session->Create(gdef); + EXPECT_EQ(error::DEADLINE_EXCEEDED, status.code()); +} + +// Tests that Create() with "timeout_in_ms" in RunOptions set times out. +TEST(SessionTest, CreateTimeoutWithRunOptions) { + SessionOptions options = Options("example.org", 1); + std::unique_ptr<Session> session(NewRemote(options)); + + // Creates a long running op. + Graph graph(OpRegistry::Global()); + Node* b = test::graph::Constant(&graph, Tensor()); + test::graph::Delay(&graph, b, Microseconds(1000000)); + GraphDef gdef; + test::graph::ToGraphDef(&graph, &gdef); + RunOptions run_options; + // Sets RunOption timeout_in_ms to 20. + run_options.set_timeout_in_ms(20); + Status status = session->Create(run_options, gdef); + EXPECT_EQ(error::DEADLINE_EXCEEDED, status.code()); +} + +// Tests that Run() with "operation_timeout_in_ms" set times out. +TEST(SessionTest, RunTimeoutWithSessionOptions) { + // Creates a RemoteSession with "operation_timeout_in_ms" set to 100. + std::unique_ptr<test::TestCluster> cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster)); + SessionOptions options = Options(cluster->targets()[0], 100); + options.config.set_operation_timeout_in_ms(1); + std::unique_ptr<Session> session(NewRemote(options)); + + // Creates a long running op. + Graph graph(OpRegistry::Global()); + Node* b = test::graph::Constant(&graph, Tensor()); + Node* b_delay = test::graph::Delay(&graph, b, Microseconds(2000000)); + GraphDef gdef; + test::graph::ToGraphDef(&graph, &gdef); + RunOptions run_options; + TF_CHECK_OK(session->Create(run_options, gdef)); + + // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED. + std::vector<std::pair<string, Tensor>> inputs; + Status status = session->Run(inputs, {}, {b_delay->name()}, nullptr); + // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get + // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL. + EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() || + error::INTERNAL == status.code()); +} + +// Tests that Run() with "timeout_in_ms" set times out. +TEST(SessionTest, RunTimeoutWithRunOptions) { + std::unique_ptr<test::TestCluster> cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster)); + SessionOptions options = Options(cluster->targets()[0], 1); + std::unique_ptr<Session> session(NewRemote(options)); + + // Creates a long running op. + Graph graph(OpRegistry::Global()); + Node* b = test::graph::Constant(&graph, Tensor()); + Node* b_delay = test::graph::Delay(&graph, b, Microseconds(1000000)); + GraphDef gdef; + test::graph::ToGraphDef(&graph, &gdef); + TF_CHECK_OK(session->Create(gdef)); + + // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED. + std::vector<std::pair<string, Tensor>> inputs; + RunOptions run_options; + run_options.set_timeout_in_ms(100); + Status status = session->Run(run_options, inputs, {}, {b_delay->name()}, + nullptr, nullptr); + // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get + // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL. + EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() || + error::INTERNAL == status.code()); +} } // namespace tensorflow diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto index 7482d14d75..c6da869cce 100644 --- a/tensorflow/core/framework/config.proto +++ b/tensorflow/core/framework/config.proto @@ -131,6 +131,11 @@ message ConfigProto { // Options that apply to all graphs. GraphOptions graph_options = 10; + + // Global timeout for all blocking operations in this session. If non-zero, + // and not overridden on a per-operation basis, this value will be used as the + // deadline for all blocking operations. + int64 operation_timeout_in_ms = 11; }; // EXPERIMENTAL. Options for a single Run() call. @@ -140,6 +145,9 @@ message RunOptions { FULL_TRACE = 1; } TraceLevel trace_level = 1; + + // Time to wait for operation to complete in milliseconds. + int64 timeout_in_ms = 2; } // EXPERIMENTAL. Metadata output (i.e., non-Tensor) for a single Run() call. diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h index fe1e400ad5..464236a051 100644 --- a/tensorflow/core/lib/core/notification.h +++ b/tensorflow/core/lib/core/notification.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_UTIL_NOTIFICATION_H_ #include <assert.h> +#include <chrono> // NOLINT +#include <condition_variable> // NOLINT #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -47,6 +49,13 @@ class Notification { } } + bool WaitForNotificationWithTimeout(int64 timeout_in_ms) { + mutex_lock l(mu_); + std::cv_status s = + cv_.wait_for(l, std::chrono::milliseconds(timeout_in_ms)); + return (s == std::cv_status::timeout) ? true : false; + } + private: mutex mu_; condition_variable cv_; diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h index 6b64f4671e..3c084b73cf 100644 --- a/tensorflow/core/public/session.h +++ b/tensorflow/core/public/session.h @@ -115,16 +115,34 @@ class Session { const std::vector<string>& target_node_names, std::vector<Tensor>* outputs) = 0; + /// \brief Implementations which support `RunOptions`. + // + /// NOTE: This API is still experimental and may change. + virtual Status Create(const RunOptions& run_options, const GraphDef& graph) { + return errors::Unimplemented( + "Create(const RunOptions& run_options, const GraphDef& graph) is not " + "supported for this session."); + } + virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) { + return errors::Unimplemented( + "Extend(const RunOptions& run_options, const GraphDef& graph) is not " + "supported for this session."); + } + virtual Status Close(const RunOptions& run_options) { + return errors::Unimplemented( + "Close(const RunOptions& run_options) is not supported for this " + "session."); + } + /// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and /// to retrieve non-Tensor metadata output via a `RunOutputs` proto for this /// step. /// NOTE: This API is still experimental and may change. - virtual Status RunWithOpts( - const RunOptions& run_options, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs, RunOutputs* run_outputs) { + virtual Status Run(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, RunOutputs* run_outputs) { return errors::Unimplemented( "RunWithOpts() is not supported for this session."); } diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index 22a8b07dba..af232f65cc 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -25,6 +25,7 @@ import time import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +from tensorflow.python.pywrap_tensorflow import StatusNotOK class FIFOQueueTest(tf.test.TestCase): @@ -1150,5 +1151,20 @@ class FIFOQueueTest(tf.test.TestCase): self.assertAllEqual(input_elem, output_elem) +class FIFOQueueWithTimeoutTest(tf.test.TestCase): + + def testDequeueWithTimeout(self): + with self.test_session( + config=tf.ConfigProto(operation_timeout_in_ms=20)) as sess: + q = tf.FIFOQueue(10, tf.float32) + dequeued_t = q.dequeue() + + # Intentionally do not run any enqueue_ops so that dequeue will block + # until operation_timeout_in_ms. + with self.assertRaisesRegexp(StatusNotOK, + "Timed out waiting for notification"): + sess.run(dequeued_t) + + if __name__ == "__main__": tf.test.main() |