about · summary · refs · log · tree · commit · diff · homepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2016-11-18 18:04:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-18 18:24:12 -0800
commitd94457f55b23dcc8eedb7c236a923031d7b51409 (patch)
tree2f41cda608e31a5a587e7ced514f416392f9eb22
parentaf5b2a0b154763e6e1cea7b282af0e099ef37869 (diff)
Partial run support for GRPC runtime.
Tests for distributed partial run added in session_test.py. Change: 139644597
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc11
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc20
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h23
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc9
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.cc57
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.h7
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc149
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/client/session_test.py201
10 files changed, 373 insertions, 109 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 1918eae875..0adfd69ab2 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1306,9 +1306,9 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequest* req,
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
- mutex_lock l(mu_);
- partial_runs_.erase(prun_handle);
});
+ mutex_lock l(mu_);
+ partial_runs_.erase(prun_handle);
}
return s;
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index c8a0892842..0f2990e588 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -105,6 +105,7 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(CreateSession, true);
ENQUEUE_REQUEST(ExtendSession, false);
for (int i = 0; i < 100; ++i) {
+ ENQUEUE_REQUEST(PartialRunSetup, false);
ENQUEUE_REQUEST(RunStep, true);
}
ENQUEUE_REQUEST(CloseSession, false);
@@ -159,6 +160,16 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(ExtendSession, false);
}
+ // RPC handler for setting up a partial run call.
+ void PartialRunSetupHandler(
+ MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
+ master_impl_->PartialRunSetup(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(PartialRunSetup, false);
+ }
+
// RPC handler for running one step in a session.
void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
CallOptions* call_opts = new CallOptions;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
index d3cb72730c..c42622dd50 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
@@ -31,6 +31,7 @@ namespace grpc {
static const char* grpcMasterService_method_names[] = {
"/tensorflow.MasterService/CreateSession",
"/tensorflow.MasterService/ExtendSession",
+ "/tensorflow.MasterService/PartialRunSetup",
"/tensorflow.MasterService/RunStep",
"/tensorflow.MasterService/CloseSession",
"/tensorflow.MasterService/ListDevices",
@@ -51,13 +52,15 @@ MasterService::Stub::Stub(
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_ExtendSession_(grpcMasterService_method_names[1],
::grpc::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_RunStep_(grpcMasterService_method_names[2],
+ rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2],
+ ::grpc::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_RunStep_(grpcMasterService_method_names[3],
::grpc::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_CloseSession_(grpcMasterService_method_names[3],
+ rpcmethod_CloseSession_(grpcMasterService_method_names[4],
::grpc::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_ListDevices_(grpcMasterService_method_names[4],
+ rpcmethod_ListDevices_(grpcMasterService_method_names[5],
::grpc::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_Reset_(grpcMasterService_method_names[5],
+ rpcmethod_Reset_(grpcMasterService_method_names[6],
::grpc::RpcMethod::NORMAL_RPC, channel) {}
::grpc::Status MasterService::Stub::CreateSession(
@@ -74,6 +77,13 @@ MasterService::Stub::Stub(
context, request, response);
}
+::grpc::Status MasterService::Stub::PartialRunSetup(
+ ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
+ PartialRunSetupResponse* response) {
+ return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_PartialRunSetup_,
+ context, request, response);
+}
+
::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) {
@@ -103,7 +113,7 @@ MasterService::Stub::Stub(
}
MasterService::AsyncService::AsyncService() {
- for (int i = 0; i < 6; ++i) {
+ for (int i = 0; i < 7; ++i) {
AddMethod(new ::grpc::RpcServiceMethod(grpcMasterService_method_names[i],
::grpc::RpcMethod::NORMAL_RPC,
nullptr));
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
index afe4b583f8..a3a2ac8020 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
@@ -64,6 +64,9 @@ class MasterService GRPC_FINAL {
virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context,
const ExtendSessionRequest& request,
ExtendSessionResponse* response) = 0;
+ virtual ::grpc::Status PartialRunSetup(
+ ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
+ PartialRunSetupResponse* response) = 0;
virtual ::grpc::Status RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) = 0;
@@ -86,6 +89,9 @@ class MasterService GRPC_FINAL {
::grpc::Status ExtendSession(::grpc::ClientContext* context,
const ExtendSessionRequest& request,
ExtendSessionResponse* response) GRPC_OVERRIDE;
+ ::grpc::Status PartialRunSetup(
+ ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
+ PartialRunSetupResponse* response) GRPC_OVERRIDE;
::grpc::Status RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) GRPC_OVERRIDE;
@@ -103,6 +109,7 @@ class MasterService GRPC_FINAL {
std::shared_ptr< ::grpc::ChannelInterface> channel_;
const ::grpc::RpcMethod rpcmethod_CreateSession_;
const ::grpc::RpcMethod rpcmethod_ExtendSession_;
+ const ::grpc::RpcMethod rpcmethod_PartialRunSetup_;
const ::grpc::RpcMethod rpcmethod_RunStep_;
const ::grpc::RpcMethod rpcmethod_CloseSession_;
const ::grpc::RpcMethod rpcmethod_ListDevices_;
@@ -132,12 +139,20 @@ class MasterService GRPC_FINAL {
::grpc::Service::RequestAsyncUnary(1, context, request, response,
new_call_cq, notification_cq, tag);
}
+ void RequestPartialRunSetup(
+ ::grpc::ServerContext* context, PartialRunSetupRequest* request,
+ ::grpc::ServerAsyncResponseWriter<PartialRunSetupResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(2, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
void RequestRunStep(
::grpc::ServerContext* context, RunStepRequest* request,
::grpc::ServerAsyncResponseWriter<RunStepResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(2, context, request, response,
+ ::grpc::Service::RequestAsyncUnary(3, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestCloseSession(
@@ -145,7 +160,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<CloseSessionResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(3, context, request, response,
+ ::grpc::Service::RequestAsyncUnary(4, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestListDevices(
@@ -153,7 +168,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<ListDevicesResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(4, context, request, response,
+ ::grpc::Service::RequestAsyncUnary(5, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestReset(
@@ -161,7 +176,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<ResetResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(5, context, request, response,
+ ::grpc::Service::RequestAsyncUnary(6, context, request, response,
new_call_cq, notification_cq, tag);
}
};
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
index 3af1ee9cce..879f177da5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -52,6 +52,15 @@ class GrpcRemoteMaster : public MasterInterface {
return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response));
}
+ Status PartialRunSetup(CallOptions* call_options,
+ const PartialRunSetupRequest* request,
+ PartialRunSetupResponse* response) override {
+ ::grpc::ClientContext ctx;
+ ctx.set_fail_fast(false);
+ SetDeadline(&ctx, call_options->GetTimeout());
+ return FromGrpcStatus(stub_->PartialRunSetup(&ctx, *request, response));
+ }
+
Status RunStep(CallOptions* call_options, const RunStepRequest* request,
RunStepResponse* response) override {
::grpc::ClientContext ctx;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 886f04a756..bb2fbb7fea 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -152,18 +152,22 @@ Status GrpcSession::Extend(const RunOptions& run_options,
return ExtendImpl(&call_options, graph);
}
-Status GrpcSession::Run(const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor>>& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs,
- RunMetadata* run_metadata) {
+Status GrpcSession::RunHelper(
+ const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata, const string& prun_handle) {
// Convert to proto
RunStepRequest req;
RunStepResponse resp;
*req.mutable_options() = run_options;
+ if (!prun_handle.empty()) {
+ req.set_partial_run_handle(prun_handle);
+ }
+
for (const auto& it : inputs) {
Tensor input_tensor = it.second;
auto feed = req.add_feed();
@@ -215,6 +219,16 @@ Status GrpcSession::Run(const RunOptions& run_options,
return Status::OK();
}
+Status GrpcSession::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) {
+ return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
+ outputs, run_metadata, /* prun_handle */ "");
+}
+
Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
@@ -242,14 +256,41 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) {
- return errors::Internal("Partial run is not supported for remote session.");
+ // Convert to proto
+ PartialRunSetupRequest req;
+ PartialRunSetupResponse resp;
+ CallOptions call_options;
+ {
+ mutex_lock l(mu_);
+ if (handle_.empty()) {
+ return errors::InvalidArgument("A session is not created yet....");
+ }
+
+ req.set_session_handle(handle_);
+ }
+ for (const string& feed : input_names) {
+ req.add_feed(feed);
+ }
+ for (const string& fetch : output_names) {
+ req.add_fetch(fetch);
+ }
+ for (const string& target : target_nodes) {
+ req.add_target(target);
+ }
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
+ *handle = resp.partial_run_handle();
+ return Status::OK();
}
Status GrpcSession::PRun(const string& handle,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) {
- return errors::Internal("Partial run is not supported for remote session.");
+ RunOptions run_options;
+ run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
+ return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs,
+ /* run_metadata */ nullptr, handle);
}
Status GrpcSession::Close() {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index 0d532520cf..60a7b8334f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -110,6 +110,13 @@ class GrpcSession : public Session {
// The current version of the graph.
int64 current_graph_version_ GUARDED_BY(mu_);
+ Status RunHelper(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata,
+ const string& prun_handle);
+
Status RunProto(CallOptions* call_options, RunStepRequest* req,
RunStepResponse* resp);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index ec8c06abb4..f6580cfd42 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -163,6 +163,41 @@ class GrpcWorkerService : public AsyncServiceInterface {
mutex mu_;
CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
+ struct PartialRunState {
+ CancellationManager* cancellation_manager;
+ Notification executor_done;
+
+ explicit PartialRunState(CancellationManager* cm)
+ : cancellation_manager(cm) {}
+ };
+ std::unordered_map<std::pair<string, int>, std::unique_ptr<PartialRunState>,
+ hash<std::pair<string, int>>>
+ partial_runs_ GUARDED_BY(mu_);
+
+ PartialRunState* FindPartialRun(const string& graph_handle, int step_id) {
+ std::pair<string, int> k(graph_handle, step_id);
+ PartialRunState* prun_state = nullptr;
+ mutex_lock l(mu_);
+ auto it = partial_runs_.find(k);
+ if (it != partial_runs_.end()) {
+ prun_state = it->second.get();
+ }
+ return prun_state;
+ }
+
+ void InsertPartialRunLocked(const string& graph_handle, int step_id,
+ PartialRunState* partial_run_state)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::pair<string, int> k(graph_handle, step_id);
+ partial_runs_.emplace(
+ std::make_pair(k, std::unique_ptr<PartialRunState>(partial_run_state)));
+ }
+
+ void RemovePartialRun(const string& graph_handle, int step_id) {
+ std::pair<string, int> k(graph_handle, step_id);
+ mutex_lock l(mu_);
+ partial_runs_.erase(partial_runs_.find(k));
+ }
mutex shutdown_mu_;
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
@@ -225,7 +260,11 @@ class GrpcWorkerService : public AsyncServiceInterface {
}
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
- env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
+ if (call->request.is_partial()) {
+ env_->compute_pool->Schedule([this, call]() { DoPartialRunGraph(call); });
+ } else {
+ env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
+ }
ENQUEUE_REQUEST(RunGraph, true);
}
@@ -294,10 +333,6 @@ class GrpcWorkerService : public AsyncServiceInterface {
Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in,
GraphMgr::NamedTensors* out) {
- if (req.is_partial()) {
- return errors::Unimplemented(
- "Partial run not implemented for GRPC worker service");
- }
if (req.send_size() > 0) {
// TODO(zhifengc): Let the caller decide on which device to
// allocate the tensor.
@@ -378,6 +413,110 @@ class GrpcWorkerService : public AsyncServiceInterface {
});
}
+ // TODO(suharshs): Add stats collection support to partial run.
+ void DoPartialRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
+ const int64 step_id = call->request.step_id();
+ const string& graph_handle = call->request.graph_handle();
+ TRACEPRINTF("PartialRunGraph: %lld", step_id);
+ GraphMgr::NamedTensors in;
+ GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
+ Status s = PrepareRunGraph(call->request, &in, out);
+ auto finish = [this, call, out](const Status& s) {
+ delete out;
+ call->ClearCancelCallback();
+ call->SendResponse(ToGrpcStatus(s));
+ };
+ if (!s.ok()) {
+ finish(s);
+ return;
+ }
+
+ PartialRunState* partial_run_state = FindPartialRun(graph_handle, step_id);
+
+ CancellationManager* cm = nullptr;
+ // If this is a new partial run call we need to create a new cancellation
+ // manager.
+ // Otherwise we use the cancellation manager stored in the found partial
+ // run state.
+ if (partial_run_state == nullptr) {
+ cm = new CancellationManager;
+ } else {
+ cm = partial_run_state->cancellation_manager;
+ }
+
+ // Before we start doing anything, we set the RPC cancellation.
+ call->SetCancelCallback([this, cm, step_id]() {
+ cm->StartCancel();
+ AbortStep(step_id);
+ });
+
+ // If this is a new partial run request, the request will need to start the
+ // executors.
+ if (partial_run_state == nullptr) {
+ CancellationToken token;
+ {
+ mutex_lock l(mu_);
+ // Insert the new partial run into the partial_runs_ map.
+ partial_run_state = new PartialRunState(cm);
+ InsertPartialRunLocked(graph_handle, step_id, partial_run_state);
+ token = cancellation_manager_->get_cancellation_token();
+ cancellation_manager_->RegisterCallback(token,
+ [cm]() { cm->StartCancel(); });
+ }
+ env_->graph_mgr->ExecuteAsync(
+ graph_handle, step_id, call->request.exec_opts(),
+ nullptr /* collector */, nullptr /* cost_graph */, cm, in,
+ [this, step_id, graph_handle, token, partial_run_state](Status s) {
+ {
+ mutex_lock l(mu_);
+ cancellation_manager_->DeregisterCallback(token);
+ }
+ partial_run_state->executor_done.Notify();
+ // TODO(suharshs): Propagate the status once we keep state for
+ // each partial run call.
+ });
+ } else {
+ // Send the partial run's new inputs.
+ s = env_->graph_mgr->SendInputs(step_id, in);
+ if (!s.ok()) {
+ finish(s);
+ return;
+ }
+ }
+
+ // Receive the partial run's outputs.
+ s = env_->graph_mgr->RecvOutputs(step_id, out);
+ if (!s.ok()) {
+ finish(s);
+ return;
+ }
+
+ // Construct and return the resp.
+ for (const auto& p : *out) {
+ const string& key = p.first;
+ const Tensor& val = p.second;
+ auto* recv = call->response.add_recv();
+ recv->set_key(key);
+ // TODO(zhifengc): Deal with gpu -> cpu copy.
+ TensorProto* proto = recv->mutable_val();
+ val.AsProtoField(proto);
+ }
+
+ // If this is the last partial run request we must also wait for the entire
+ // graph execution to be completed.
+ if (call->request.is_last_partial_run()) {
+ partial_run_state->executor_done.WaitForNotification();
+ RemovePartialRun(graph_handle, step_id);
+ // Before deleting the cancellation manager on the final call, ensure
+ // that we clear the RPC cancel callback, which has a reference to the
+ // cancellation manager.
+ call->ClearCancelCallback();
+ delete cm;
+ }
+
+ finish(s);
+ }
+
// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index cfa0e15d48..9f721274c9 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2055,6 +2055,7 @@ py_test(
":math_ops",
":session",
":state_ops",
+ ":training",
":variables",
],
)
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index a20376b91d..8e3dd28c21 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
@@ -1310,91 +1311,121 @@ class SessionTest(test_util.TensorFlowTestCase):
sess_2.run(c_1.op)
self.assertEqual(2.0, sess_2.run(c_2))
- def testPartialRun(self):
- with session.Session() as sess:
- a = array_ops.placeholder(dtypes.float32, shape=[])
- b = array_ops.placeholder(dtypes.float32, shape=[])
- c = array_ops.placeholder(dtypes.float32, shape=[])
- r1 = math_ops.add(a, b)
- r2 = math_ops.mul(r1, c)
-
- h = sess.partial_run_setup([r1, r2], [a, b, c])
- res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
- self.assertEqual(3, res)
- temp = res * 17
- res = sess.partial_run(h, r2, feed_dict={c: temp})
- self.assertEqual(153, res)
-
- # Call again on the same graph.
- h2 = sess.partial_run_setup([r1, r2], [a, b, c])
- res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2})
- self.assertEqual(3, res)
- temp = res * 18
- res = sess.partial_run(h2, r2, feed_dict={c: temp})
- self.assertEqual(162, res)
-
- def testPartialRunIncomplete(self):
- with session.Session() as sess:
- a = array_ops.placeholder(dtypes.float32, shape=[])
- b = array_ops.placeholder(dtypes.float32, shape=[])
- c = array_ops.placeholder(dtypes.float32, shape=[])
- r1 = math_ops.add(a, b)
- r2 = math_ops.mul(r1, c)
-
- h = sess.partial_run_setup([r1, r2], [a, b, c])
- res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
- self.assertEqual(3, res)
-
- def testConcurrentPartialRun(self):
- with session.Session() as sess:
- a = array_ops.placeholder(dtypes.float32, shape=[])
- b = array_ops.placeholder(dtypes.float32, shape=[])
- c = array_ops.placeholder(dtypes.float32, shape=[])
- r1 = math_ops.add(a, b)
- r2 = math_ops.mul(r1, c)
-
- h1 = sess.partial_run_setup([r1], [a, b, c])
- h2 = sess.partial_run_setup([r1, r2], [a, b, c])
- res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2})
- self.assertEqual(3, res)
- temp = res * 19
- res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9})
- self.assertEqual(66, res)
- res = sess.partial_run(h2, r2, feed_dict={c: 7})
- self.assertEqual(462, res)
-
- def testManyPartialRun(self):
- with session.Session() as sess:
- steps = 200
- inputs = []
- outputs = []
- a = constant_op.constant(2.0, dtypes.float32)
- for i in xrange(steps):
- inputs.append(array_ops.placeholder(dtypes.float32, shape=[]))
- a = math_ops.mul(a, inputs[i])
- outputs.append(a)
-
- h = sess.partial_run_setup(outputs, inputs)
- for i in xrange(steps):
- res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0})
- self.assertEqual(2.0, res)
-
- feed_dict = {}
- for i in xrange(steps):
- feed_dict[inputs[i]] = 1.0
- res = sess.run(outputs, feed_dict)
- self.assertEqual(steps, len(res))
- self.assertEqual(2.0, res[-1])
-
- def testRunAndPartialRun(self):
- with session.Session() as sess:
- a = constant_op.constant(2.0, dtypes.float32)
- b = a * 2
- c = b * 3
- r1 = sess.run([b, c])
- h = sess.partial_run_setup([b, c], [])
- r2 = sess.partial_run(h, [b, c])
- self.assertEqual(r1, r2)
+ def runTestPartialRun(self, sess):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+ temp = res * 17
+ res = sess.partial_run(h, r2, feed_dict={c: temp})
+ self.assertEqual(153, res)
+
+ # Call again on the same graph.
+ h2 = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+ temp = res * 18
+ res = sess.partial_run(h2, r2, feed_dict={c: temp})
+ self.assertEqual(162, res)
+
+ def runTestPartialRunIncomplete(self, sess):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+
+ def runTestConcurrentPartialRun(self, sess):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h1 = sess.partial_run_setup([r1], [a, b, c])
+ h2 = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+ temp = res * 19
+ res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9})
+ self.assertEqual(66, res)
+ res = sess.partial_run(h2, r2, feed_dict={c: 7})
+ self.assertEqual(462, res)
+
+ def runTestManyPartialRun(self, sess):
+ steps = 200
+ inputs = []
+ outputs = []
+ a = constant_op.constant(2.0, dtypes.float32)
+ for i in xrange(steps):
+ inputs.append(array_ops.placeholder(dtypes.float32, shape=[]))
+ a = math_ops.mul(a, inputs[i])
+ outputs.append(a)
+
+ h = sess.partial_run_setup(outputs, inputs)
+ for i in xrange(steps):
+ res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0})
+ self.assertEqual(2.0, res)
+
+ feed_dict = {}
+ for i in xrange(steps):
+ feed_dict[inputs[i]] = 1.0
+ res = sess.run(outputs, feed_dict)
+ self.assertEqual(steps, len(res))
+ self.assertEqual(2.0, res[-1])
+
+ def runTestRunAndPartialRun(self, sess):
+ a = constant_op.constant(2.0, dtypes.float32)
+ b = a * 2
+ c = b * 3
+ r1 = sess.run([b, c])
+ h = sess.partial_run_setup([b, c], [])
+ r2 = sess.partial_run(h, [b, c])
+ self.assertEqual(r1, r2)
+
+ def testPartialRunDirect(self):
+ self.runTestPartialRun(session.Session())
+
+ def testPartialRunIncompleteDirect(self):
+ self.runTestPartialRunIncomplete(session.Session())
+
+ def testConcurrentPartialRunDirect(self):
+ self.runTestConcurrentPartialRun(session.Session())
+
+ def testManyPartialRunDirect(self):
+ self.runTestManyPartialRun(session.Session())
+
+ def testRunAndPartialRunDirect(self):
+ self.runTestRunAndPartialRun(session.Session())
+
+ def testPartialRunDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestPartialRun(session.Session(server.target))
+
+ def testPartialRunIncompleteDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestPartialRunIncomplete(session.Session(server.target))
+
+ def testConcurrentPartialRunDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestConcurrentPartialRun(session.Session(server.target))
+
+ def testManyPartialRunDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestManyPartialRun(session.Session(server.target))
+
+ def testRunAndPartialRunDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestRunAndPartialRun(session.Session(server.target))
def testFeedDictKeyException(self):
with session.Session() as sess: