aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc67
-rw-r--r--tensorflow/core/common_runtime/direct_session.h22
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc98
-rw-r--r--tensorflow/core/distributed_runtime/call_options.cc10
-rw-r--r--tensorflow/core/distributed_runtime/call_options.h7
-rw-r--r--tensorflow/core/distributed_runtime/master_interface.h18
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc30
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.cc81
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.h22
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc88
-rw-r--r--tensorflow/core/framework/config.proto8
-rw-r--r--tensorflow/core/lib/core/notification.h9
-rw-r--r--tensorflow/core/public/session.h30
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py16
15 files changed, 445 insertions, 63 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 3dd713d99a..e607e226e6 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -127,7 +127,8 @@ DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr)
: options_(options),
device_mgr_(device_mgr),
- cancellation_manager_(new CancellationManager()) {
+ cancellation_manager_(new CancellationManager()),
+ operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
if (options_.config.use_per_session_threads()) {
thread_pool_ = NewThreadPool(options_);
} else {
@@ -254,16 +255,16 @@ Status DirectSession::Run(const NamedTensorList& inputs,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs) {
RunOutputs run_outputs;
- return RunWithOpts(RunOptions(), inputs, output_names, target_nodes, outputs,
- &run_outputs);
+ return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
+ &run_outputs);
}
-Status DirectSession::RunWithOpts(const RunOptions& run_options,
- const NamedTensorList& inputs,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- std::vector<Tensor>* outputs,
- RunOutputs* run_outputs) {
+Status DirectSession::Run(const RunOptions& run_options,
+ const NamedTensorList& inputs,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ std::vector<Tensor>* outputs,
+ RunOutputs* run_outputs) {
{
mutex_lock l(graph_def_lock_);
if (!graph_created_) {
@@ -297,7 +298,10 @@ Status DirectSession::RunWithOpts(const RunOptions& run_options,
const int num_executors = executors_and_keys->items.size();
ExecutorBarrier* barrier = new ExecutorBarrier(
num_executors, run_state.rendez, [&run_state](const Status& ret) {
- run_state.status = ret;
+ {
+ mutex_lock l(run_state.mu_);
+ run_state.status.Update(ret);
+ }
run_state.executors_done.Notify();
});
@@ -319,13 +323,18 @@ Status DirectSession::RunWithOpts(const RunOptions& run_options,
item.executor->RunAsync(args, barrier->Get());
}
- run_state.executors_done.WaitForNotification();
+ WaitForNotification(&run_state, run_options.timeout_in_ms() > 0
+ ? run_options.timeout_in_ms()
+ : operation_timeout_in_ms_);
if (run_options.trace_level() == RunOptions::FULL_TRACE) {
delete args.stats_collector;
}
- TF_RETURN_IF_ERROR(run_state.status);
+ {
+ mutex_lock l(run_state.mu_);
+ TF_RETURN_IF_ERROR(run_state.status);
+ }
// Receive outputs.
TF_RETURN_IF_ERROR(
@@ -459,9 +468,12 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
mutex_lock l(executor_lock_);
bool done = true;
if (s.ok()) {
- if (!run_state->status.ok()) {
- LOG(WARNING) << "An error unrelated to this prun has been detected. "
- << run_state->status;
+ {
+ mutex_lock l(run_state->mu_);
+ if (!run_state->status.ok()) {
+ LOG(WARNING) << "An error unrelated to this prun has been detected. "
+ << run_state->status;
+ }
}
for (const auto& it : inputs) {
run_state->pending_inputs.erase(it.first);
@@ -472,7 +484,7 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
done = run_state->pending_outputs.size() == 0;
}
if (done) {
- run_state->executors_done.WaitForNotification();
+ WaitForNotification(run_state, operation_timeout_in_ms_);
partial_runs_.erase(handle);
delete run_state;
}
@@ -899,6 +911,29 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
return ::tensorflow::Status::OK();
}
+void DirectSession::WaitForNotification(RunState* run_state,
+ int64 timeout_in_ms) {
+ if (timeout_in_ms > 0) {
+ bool timed_out =
+ run_state->executors_done.WaitForNotificationWithTimeout(timeout_in_ms);
+ if (timed_out) {
+ {
+ mutex_lock l(run_state->mu_);
+ run_state->status.Update(Status(error::DEADLINE_EXCEEDED,
+ "Timed out waiting for notification"));
+ }
+ // TODO(sherrym): This cancels all steps in the session, even ones that
+ // have not exceeded their deadline. An alternative would be to use a
+ // two-level cancellation manager with a Session-global one containing
+ // several step-local ones. Probably the RunState should have its own
+ // CancellationManager.
+ cancellation_manager_->StartCancel();
+ }
+ } else {
+ run_state->executors_done.WaitForNotification();
+ }
+}
+
class DirectSessionFactory : public SessionFactory {
public:
DirectSessionFactory() {}
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index cf6c1f2eb3..c172265008 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -61,12 +61,12 @@ class DirectSession : public Session {
std::vector<Tensor>* outputs) override;
// NOTE: Experimental and subject to change.
- ::tensorflow::Status RunWithOpts(const RunOptions& run_options,
- const NamedTensorList& inputs,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- std::vector<Tensor>* outputs,
- RunOutputs* run_outputs) override;
+ ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
+ const NamedTensorList& inputs,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ std::vector<Tensor>* outputs,
+ RunOutputs* run_outputs) override;
// NOTE: PRunSetup and PRun are added to support partial execution. This
// feature is experimental and subject to change.
@@ -121,7 +121,8 @@ class DirectSession : public Session {
// is "notified" when all executors are done. 'pending_inputs' are the set
// of pending feeds and 'pending_outputs' are the set of pending fetches.
struct RunState {
- Status status;
+ mutex mu_;
+ Status status GUARDED_BY(mu_);
IntraProcessRendezvous* rendez = nullptr;
Notification executors_done;
std::unordered_set<string> pending_inputs;
@@ -194,6 +195,10 @@ class DirectSession : public Session {
const std::vector<string>& fetches,
const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);
+ // Use the appropriate WaitForNotification function based on whether
+ // operation_timeout_in_ms is greater than 0.
+ void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
+
const SessionOptions options_;
// Device structures.
@@ -242,6 +247,9 @@ class DirectSession : public Session {
// For generating step ids that are unique across all sessions.
static std::atomic_int_fast64_t step_id_counter_;
+ // Global timeout for all blocking operations in this session.
+ const int64 operation_timeout_in_ms_ = 0;
+
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
};
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 5c3f83b399..0cad812300 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -273,8 +273,8 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
RunOutputs run_outputs;
EXPECT_EQ(run_outputs.step_stats().dev_stats_size(), 0);
- Status s = session->RunWithOpts(run_options, inputs, output_names,
- target_nodes, &outputs, &run_outputs);
+ Status s = session->Run(run_options, inputs, output_names, target_nodes,
+ &outputs, &run_outputs);
TF_ASSERT_OK(s);
ASSERT_EQ(1, outputs.size());
@@ -560,5 +560,99 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
ASSERT_EQ(true, outputs[0].flat<bool>()(0));
}
+TEST(DirectSessionTest, TimeoutSession) {
+ GraphDef graph;
+ // Creates a graph with one FIFOQueue and one dequeue op.
+ protobuf::TextFormat::ParseFromString(R"proto(
+ node {
+ name: 'fifo_queue'
+ op: 'FIFOQueue'
+ device: '/device:CPU:0'
+ attr {
+ key: 'capacity'
+ value {
+ i: 10
+ }
+ }
+ attr {
+ key: 'component_types'
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ key: 'container'
+ value {
+ s: ''
+ }
+ }
+ attr {
+ key: 'shapes'
+ value {
+ list {
+ }
+ }
+ }
+ attr {
+ key: 'shared_name'
+ value {
+ s: ''
+ }
+ }
+ }
+ node {
+ name: 'fifo_queue_Dequeue'
+ op: 'QueueDequeue'
+ input: 'fifo_queue'
+ device: '/device:CPU:0'
+ attr {
+ key: 'component_types'
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ key: 'timeout_ms'
+ value {
+ i: -1
+ }
+ }
+ }
+ versions {
+ producer: 9
+ }
+ )proto",
+ &graph);
+
+ // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
+ SessionOptions options;
+ (*options.config.mutable_device_count())["CPU"] = 2;
+ options.config.set_operation_timeout_in_ms(100);
+ std::unique_ptr<Session> session(NewSession(options));
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(graph));
+
+ // Verifies that the error code is DEADLINE_EXCEEDED.
+ Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
+ ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
+ session->Close();
+
+ // Creates a session with no operation_timeout_in_ms.
+ session.reset(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(graph));
+ RunOptions run_options;
+ run_options.set_timeout_in_ms(20);
+ // Verifies that the error code is DEADLINE_EXCEEDED.
+ Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"}, nullptr,
+ nullptr);
+ ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
+ session->Close();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/call_options.cc b/tensorflow/core/distributed_runtime/call_options.cc
index b9d583b754..a99cec9205 100644
--- a/tensorflow/core/distributed_runtime/call_options.cc
+++ b/tensorflow/core/distributed_runtime/call_options.cc
@@ -41,4 +41,14 @@ void CallOptions::ClearCancelCallback() {
cancel_func_ = nullptr;
}
+int64 CallOptions::GetTimeout() {
+ mutex_lock l(mu_);
+ return timeout_in_ms_;
+}
+
+void CallOptions::SetTimeout(int64 ms) {
+ mutex_lock l(mu_);
+ timeout_in_ms_ = ms;
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/call_options.h b/tensorflow/core/distributed_runtime/call_options.h
index de0b85f692..6a9967a35b 100644
--- a/tensorflow/core/distributed_runtime/call_options.h
+++ b/tensorflow/core/distributed_runtime/call_options.h
@@ -60,10 +60,17 @@ class CallOptions {
void SetCancelCallback(CancelFunction cancel_func);
void ClearCancelCallback();
+ // Get and set operation timeout. Timeout value is in milliseconds.
+ int64 GetTimeout();
+ void SetTimeout(int64 ms);
+
private:
mutex mu_;
CancelFunction cancel_func_ GUARDED_BY(mu_);
+ // RPC operation timeout in milliseconds.
+  int64 timeout_in_ms_ GUARDED_BY(mu_) = 0;
+
TF_DISALLOW_COPY_AND_ASSIGN(CallOptions);
};
diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h
index 602cfbd8a3..365447b657 100644
--- a/tensorflow/core/distributed_runtime/master_interface.h
+++ b/tensorflow/core/distributed_runtime/master_interface.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
+#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
@@ -28,22 +29,27 @@ namespace tensorflow {
class MasterInterface {
public:
virtual ~MasterInterface() {}
- virtual Status CreateSession(const CreateSessionRequest* request,
+ virtual Status CreateSession(CallOptions* call_options,
+ const CreateSessionRequest* request,
CreateSessionResponse* response) = 0;
- virtual Status ExtendSession(const ExtendSessionRequest* request,
+ virtual Status ExtendSession(CallOptions* call_options,
+ const ExtendSessionRequest* request,
ExtendSessionResponse* response) = 0;
- virtual Status RunStep(const RunStepRequest* request,
+ virtual Status RunStep(CallOptions* call_options,
+ const RunStepRequest* request,
RunStepResponse* response) = 0;
- virtual Status CloseSession(const CloseSessionRequest* request,
+ virtual Status CloseSession(CallOptions* call_options,
+ const CloseSessionRequest* request,
CloseSessionResponse* response) = 0;
- virtual Status ListDevices(const ListDevicesRequest* request,
+ virtual Status ListDevices(CallOptions* call_options,
+ const ListDevicesRequest* request,
ListDevicesResponse* response) = 0;
- virtual Status Reset(const ResetRequest* request,
+ virtual Status Reset(CallOptions* call_options, const ResetRequest* request,
ResetResponse* response) = 0;
};
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index b66719620d..7e2091d114 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -165,6 +165,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
+ "//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:master_interface",
],
alwayslink = 1,
@@ -313,6 +314,7 @@ cc_library(
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
+ "//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:master_interface",
],
alwayslink = 1,
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
index e358aed31f..0b084a0b02 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -33,43 +34,60 @@ class GrpcRemoteMaster : public MasterInterface {
~GrpcRemoteMaster() override {}
- Status CreateSession(const CreateSessionRequest* request,
+ Status CreateSession(CallOptions* call_options,
+ const CreateSessionRequest* request,
CreateSessionResponse* response) override {
::grpc::ClientContext ctx;
+ SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->CreateSession(&ctx, *request, response));
}
- Status ExtendSession(const ExtendSessionRequest* request,
+ Status ExtendSession(CallOptions* call_options,
+ const ExtendSessionRequest* request,
ExtendSessionResponse* response) override {
::grpc::ClientContext ctx;
+ SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response));
}
- Status RunStep(const RunStepRequest* request,
+ Status RunStep(CallOptions* call_options, const RunStepRequest* request,
RunStepResponse* response) override {
::grpc::ClientContext ctx;
+ SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->RunStep(&ctx, *request, response));
}
- Status CloseSession(const CloseSessionRequest* request,
+ Status CloseSession(CallOptions* call_options,
+ const CloseSessionRequest* request,
CloseSessionResponse* response) override {
::grpc::ClientContext ctx;
+ SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->CloseSession(&ctx, *request, response));
}
- Status ListDevices(const ListDevicesRequest* request,
+ Status ListDevices(CallOptions* call_options,
+ const ListDevicesRequest* request,
ListDevicesResponse* response) override {
::grpc::ClientContext ctx;
+ SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->ListDevices(&ctx, *request, response));
}
- Status Reset(const ResetRequest* request, ResetResponse* response) override {
+ Status Reset(CallOptions* call_options, const ResetRequest* request,
+ ResetResponse* response) override {
::grpc::ClientContext ctx;
+ SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->Reset(&ctx, *request, response));
}
private:
std::unique_ptr<grpc::MasterService::Stub> stub_;
+
+ void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms) {
+ if (time_in_ms > 0) {
+ ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
+ }
+ }
};
MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 6924fc5537..30edc5d062 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/session_factory.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
@@ -67,7 +68,8 @@ void ReEncodeConsts(GraphDef* gdef) {
}
} // namespace
-Status GrpcSession::Create(const GraphDef& graph) {
+Status GrpcSession::CreateImpl(CallOptions* call_options,
+ const GraphDef& graph) {
if (!handle_.empty()) {
return errors::InvalidArgument("A session is alive.");
}
@@ -76,7 +78,7 @@ Status GrpcSession::Create(const GraphDef& graph) {
*req.mutable_graph_def() = graph;
ReEncodeConsts(req.mutable_graph_def());
CreateSessionResponse resp;
- Status s = master_->CreateSession(&req, &resp);
+ Status s = master_->CreateSession(call_options, &req, &resp);
if (s.ok()) {
mutex_lock l(mu_);
swap(handle_, *(resp.mutable_session_handle()));
@@ -85,7 +87,21 @@ Status GrpcSession::Create(const GraphDef& graph) {
return s;
}
-Status GrpcSession::Extend(const GraphDef& graph) {
+Status GrpcSession::Create(const GraphDef& graph) {
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ return CreateImpl(&call_options, graph);
+}
+
+Status GrpcSession::Create(const RunOptions& run_options,
+ const GraphDef& graph) {
+ CallOptions call_options;
+ call_options.SetTimeout(run_options.timeout_in_ms());
+ return CreateImpl(&call_options, graph);
+}
+
+Status GrpcSession::ExtendImpl(CallOptions* call_options,
+ const GraphDef& graph) {
if (handle_.empty()) {
// Session was uninitialized, so simply initialize the session with 'graph'.
return Create(graph);
@@ -96,17 +112,31 @@ Status GrpcSession::Extend(const GraphDef& graph) {
*req.mutable_graph_def() = graph;
req.set_current_graph_version(current_graph_version_);
ExtendSessionResponse resp;
- Status s = master_->ExtendSession(&req, &resp);
+ Status s = master_->ExtendSession(call_options, &req, &resp);
if (s.ok()) {
current_graph_version_ = resp.new_graph_version();
}
return s;
}
-Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- std::vector<Tensor>* outputs) {
+Status GrpcSession::Extend(const GraphDef& graph) {
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ return ExtendImpl(&call_options, graph);
+}
+
+Status GrpcSession::Extend(const RunOptions& run_options,
+ const GraphDef& graph) {
+ CallOptions call_options;
+ call_options.SetTimeout(run_options.timeout_in_ms());
+ return ExtendImpl(&call_options, graph);
+}
+
+Status GrpcSession::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunOutputs* run_outputs) {
// Convert to proto
RunStepRequest req;
RunStepResponse resp;
@@ -121,19 +151,21 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
// Build an index from fetch tensor name to offset.
std::unordered_map<string, int> output_name_to_offset;
- for (const string& output_name : output_names) {
+ for (const string& output_name : output_tensor_names) {
req.add_fetch(output_name);
output_name_to_offset.insert(
std::make_pair(output_name, output_name_to_offset.size()));
}
- for (const string& target : target_nodes) {
+ for (const string& target : target_node_names) {
req.add_target(target);
}
- TF_RETURN_IF_ERROR(RunProto(&req, &resp));
+ CallOptions call_options;
+ call_options.SetTimeout(run_options.timeout_in_ms());
+ TF_RETURN_IF_ERROR(RunProto(&call_options, &req, &resp));
- if (!output_names.empty()) {
- outputs->resize(output_names.size());
+ if (!output_tensor_names.empty()) {
+ outputs->resize(output_tensor_names.size());
}
// Convert response back to Tensors in the correct order.
@@ -156,13 +188,24 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
return Status::OK();
}
-Status GrpcSession::RunProto(RunStepRequest* req, RunStepResponse* resp) {
+Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ RunOptions run_options;
+ run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
+ return Run(run_options, inputs, output_tensor_names, target_node_names,
+ outputs, nullptr);
+}
+
+Status GrpcSession::RunProto(CallOptions* call_options, RunStepRequest* req,
+ RunStepResponse* resp) {
if (handle_.empty()) {
return errors::InvalidArgument("A session is not created yet....");
}
req->set_session_handle(handle_);
- return master_->RunStep(req, resp);
+ return master_->RunStep(call_options, req, resp);
}
Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
@@ -187,7 +230,9 @@ Status GrpcSession::Close() {
req.set_session_handle(handle_);
handle_.clear();
CloseSessionResponse resp;
- return master_->CloseSession(&req, &resp);
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ return master_->CloseSession(&call_options, &req, &resp);
}
std::vector<DeviceAttributes> GrpcSession::ListDevices() {
@@ -195,7 +240,9 @@ std::vector<DeviceAttributes> GrpcSession::ListDevices() {
ListDevicesRequest req;
ListDevicesResponse resp;
- Status s = master_->ListDevices(&req, &resp);
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ Status s = master_->ListDevices(&call_options, &req, &resp);
if (!s.ok()) {
LOG(ERROR) << "Could not list devices: " << s;
return devices;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index 9bc6034ba6..abf7b2a44a 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -19,6 +19,8 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/framework/config.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -52,13 +54,22 @@ class GrpcSession : public Session {
// the graph computation defined by "graph", and will have version
// number "initial_version".
Status Create(const GraphDef& graph) override;
+ Status Create(const RunOptions& run_options, const GraphDef& graph) override;
+ // Runs with and without RunOptions.
Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) override;
+ Status Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunOutputs* run_outputs);
Status Extend(const GraphDef& graph) override;
+ Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
+
Status Close() override;
// NOTE: This API is still experimental and may change.
@@ -87,7 +98,12 @@ class GrpcSession : public Session {
// The current version of the graph.
int64 current_graph_version_ GUARDED_BY(mu_);
- Status RunProto(RunStepRequest* req, RunStepResponse* resp);
+ Status RunProto(CallOptions* call_options, RunStepRequest* req,
+ RunStepResponse* resp);
+
+ // Implementations for all the public interfaces.
+ Status CreateImpl(CallOptions* call_options, const GraphDef& graph);
+ Status ExtendImpl(CallOptions* call_options, const GraphDef& graph);
TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession);
};
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 86a9b07c2c..7b7507515e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -746,5 +746,93 @@ TEST(SessionTest, ExtendValidation) {
EXPECT_NE(s.error_message().find("'b', which was created by a previous call"),
string::npos);
}
+// Tests that Create() with "operation_timeout_in_ms" set times out.
+TEST(SessionTest, CreateTimeoutWithSessionOptions) {
+ // Creates a RemoteSession with "operation_timeout_in_ms" set to 100.
+ SessionOptions options = Options("example.org", 1);
+ options.config.set_operation_timeout_in_ms(100);
+ std::unique_ptr<Session> session(NewRemote(options));
+
+ // Creates a long running op.
+ Graph graph(OpRegistry::Global());
+ Node* b = test::graph::Constant(&graph, Tensor());
+ test::graph::Delay(&graph, b, Microseconds(1000000));
+ GraphDef gdef;
+ test::graph::ToGraphDef(&graph, &gdef);
+ Status status = session->Create(gdef);
+ EXPECT_EQ(error::DEADLINE_EXCEEDED, status.code());
+}
+
+// Tests that Create() with "timeout_in_ms" in RunOptions set times out.
+TEST(SessionTest, CreateTimeoutWithRunOptions) {
+ SessionOptions options = Options("example.org", 1);
+ std::unique_ptr<Session> session(NewRemote(options));
+
+ // Creates a long running op.
+ Graph graph(OpRegistry::Global());
+ Node* b = test::graph::Constant(&graph, Tensor());
+ test::graph::Delay(&graph, b, Microseconds(1000000));
+ GraphDef gdef;
+ test::graph::ToGraphDef(&graph, &gdef);
+ RunOptions run_options;
+ // Sets RunOption timeout_in_ms to 20.
+ run_options.set_timeout_in_ms(20);
+ Status status = session->Create(run_options, gdef);
+ EXPECT_EQ(error::DEADLINE_EXCEEDED, status.code());
+}
+
+// Tests that Run() with "operation_timeout_in_ms" set times out.
+TEST(SessionTest, RunTimeoutWithSessionOptions) {
+ // Creates a RemoteSession with "operation_timeout_in_ms" set to 100.
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
+ SessionOptions options = Options(cluster->targets()[0], 100);
+ options.config.set_operation_timeout_in_ms(1);
+ std::unique_ptr<Session> session(NewRemote(options));
+
+ // Creates a long running op.
+ Graph graph(OpRegistry::Global());
+ Node* b = test::graph::Constant(&graph, Tensor());
+ Node* b_delay = test::graph::Delay(&graph, b, Microseconds(2000000));
+ GraphDef gdef;
+ test::graph::ToGraphDef(&graph, &gdef);
+ RunOptions run_options;
+ TF_CHECK_OK(session->Create(run_options, gdef));
+
+ // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED.
+ std::vector<std::pair<string, Tensor>> inputs;
+ Status status = session->Run(inputs, {}, {b_delay->name()}, nullptr);
+  // TODO(sherrym): Due to what is potentially a gRPC bug, we sometimes get
+ // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL.
+ EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
+ error::INTERNAL == status.code());
+}
+
+// Tests that Run() with "timeout_in_ms" set times out.
+TEST(SessionTest, RunTimeoutWithRunOptions) {
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
+ SessionOptions options = Options(cluster->targets()[0], 1);
+ std::unique_ptr<Session> session(NewRemote(options));
+
+ // Creates a long running op.
+ Graph graph(OpRegistry::Global());
+ Node* b = test::graph::Constant(&graph, Tensor());
+ Node* b_delay = test::graph::Delay(&graph, b, Microseconds(1000000));
+ GraphDef gdef;
+ test::graph::ToGraphDef(&graph, &gdef);
+ TF_CHECK_OK(session->Create(gdef));
+
+ // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED.
+ std::vector<std::pair<string, Tensor>> inputs;
+ RunOptions run_options;
+ run_options.set_timeout_in_ms(100);
+ Status status = session->Run(run_options, inputs, {}, {b_delay->name()},
+ nullptr, nullptr);
+  // TODO(sherrym): Due to what is potentially a gRPC bug, we sometimes get
+ // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL.
+ EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
+ error::INTERNAL == status.code());
+}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto
index 7482d14d75..c6da869cce 100644
--- a/tensorflow/core/framework/config.proto
+++ b/tensorflow/core/framework/config.proto
@@ -131,6 +131,11 @@ message ConfigProto {
// Options that apply to all graphs.
GraphOptions graph_options = 10;
+
+ // Global timeout for all blocking operations in this session. If non-zero,
+ // and not overridden on a per-operation basis, this value will be used as the
+ // deadline for all blocking operations.
+ int64 operation_timeout_in_ms = 11;
};
// EXPERIMENTAL. Options for a single Run() call.
@@ -140,6 +145,9 @@ message RunOptions {
FULL_TRACE = 1;
}
TraceLevel trace_level = 1;
+
+ // Time to wait for operation to complete in milliseconds.
+ int64 timeout_in_ms = 2;
}
// EXPERIMENTAL. Metadata output (i.e., non-Tensor) for a single Run() call.
diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h
index fe1e400ad5..464236a051 100644
--- a/tensorflow/core/lib/core/notification.h
+++ b/tensorflow/core/lib/core/notification.h
@@ -17,6 +17,8 @@ limitations under the License.
#define TENSORFLOW_UTIL_NOTIFICATION_H_
#include <assert.h>
+#include <chrono> // NOLINT
+#include <condition_variable> // NOLINT
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -47,6 +49,13 @@ class Notification {
}
}
+  bool WaitForNotificationWithTimeout(int64 timeout_in_ms) {
+    mutex_lock l(mu_);
+    // Wait on a predicate over notified_ so that (a) a notification that
+    // arrived before this call returns immediately instead of blocking for
+    // the full timeout, and (b) a spurious wakeup is not misreported as a
+    // notification. Returns true iff the wait timed out.
+    return !cv_.wait_for(l, std::chrono::milliseconds(timeout_in_ms),
+                         [this]() { return notified_; });
+  }
+
private:
mutex mu_;
condition_variable cv_;
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
index 6b64f4671e..3c084b73cf 100644
--- a/tensorflow/core/public/session.h
+++ b/tensorflow/core/public/session.h
@@ -115,16 +115,34 @@ class Session {
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) = 0;
+ /// \brief Implementations which support `RunOptions`.
+ //
+ /// NOTE: This API is still experimental and may change.
+ virtual Status Create(const RunOptions& run_options, const GraphDef& graph) {
+ return errors::Unimplemented(
+ "Create(const RunOptions& run_options, const GraphDef& graph) is not "
+ "supported for this session.");
+ }
+ virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) {
+ return errors::Unimplemented(
+ "Extend(const RunOptions& run_options, const GraphDef& graph) is not "
+ "supported for this session.");
+ }
+ virtual Status Close(const RunOptions& run_options) {
+ return errors::Unimplemented(
+ "Close(const RunOptions& run_options) is not supported for this "
+ "session.");
+ }
+
/// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and
/// to retrieve non-Tensor metadata output via a `RunOutputs` proto for this
/// step.
/// NOTE: This API is still experimental and may change.
- virtual Status RunWithOpts(
- const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs, RunOutputs* run_outputs) {
+ virtual Status Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunOutputs* run_outputs) {
-    return errors::Unimplemented(
-        "RunWithOpts() is not supported for this session.");
+    return errors::Unimplemented(
+        "Run with options is not supported for this session.");
}
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 22a8b07dba..af232f65cc 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -25,6 +25,7 @@ import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+from tensorflow.python.pywrap_tensorflow import StatusNotOK
class FIFOQueueTest(tf.test.TestCase):
@@ -1150,5 +1151,20 @@ class FIFOQueueTest(tf.test.TestCase):
self.assertAllEqual(input_elem, output_elem)
+class FIFOQueueWithTimeoutTest(tf.test.TestCase):
+
+ def testDequeueWithTimeout(self):
+ with self.test_session(
+ config=tf.ConfigProto(operation_timeout_in_ms=20)) as sess:
+ q = tf.FIFOQueue(10, tf.float32)
+ dequeued_t = q.dequeue()
+
+ # Intentionally do not run any enqueue_ops so that dequeue will block
+ # until operation_timeout_in_ms.
+ with self.assertRaisesRegexp(StatusNotOK,
+ "Timed out waiting for notification"):
+ sess.run(dequeued_t)
+
+
if __name__ == "__main__":
tf.test.main()