diff options
author | 2016-11-18 18:04:45 -0800 | |
---|---|---|
committer | 2016-11-18 18:24:12 -0800 | |
commit | d94457f55b23dcc8eedb7c236a923031d7b51409 (patch) | |
tree | 2f41cda608e31a5a587e7ced514f416392f9eb22 | |
parent | af5b2a0b154763e6e1cea7b282af0e099ef37869 (diff) |
Partial run support for GRPC runtime.
Tests for distributed partial run added in session_test.py.
Change: 139644597
10 files changed, 373 insertions, 109 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 1918eae875..0adfd69ab2 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1306,9 +1306,9 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequest* req, LOG(ERROR) << "Cleanup partition error: " << s; } rcg->Unref(); - mutex_lock l(mu_); - partial_runs_.erase(prun_handle); }); + mutex_lock l(mu_); + partial_runs_.erase(prun_handle); } return s; } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index c8a0892842..0f2990e588 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -105,6 +105,7 @@ class GrpcMasterService : public AsyncServiceInterface { ENQUEUE_REQUEST(CreateSession, true); ENQUEUE_REQUEST(ExtendSession, false); for (int i = 0; i < 100; ++i) { + ENQUEUE_REQUEST(PartialRunSetup, false); ENQUEUE_REQUEST(RunStep, true); } ENQUEUE_REQUEST(CloseSession, false); @@ -159,6 +160,16 @@ class GrpcMasterService : public AsyncServiceInterface { ENQUEUE_REQUEST(ExtendSession, false); } + // RPC handler for setting up a partial run call. + void PartialRunSetupHandler( + MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) { + master_impl_->PartialRunSetup(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(PartialRunSetup, false); + } + // RPC handler for running one step in a session. void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) { CallOptions* call_opts = new CallOptions; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc index d3cb72730c..c42622dd50 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc @@ -31,6 +31,7 @@ namespace grpc { static const char* grpcMasterService_method_names[] = { "/tensorflow.MasterService/CreateSession", "/tensorflow.MasterService/ExtendSession", + "/tensorflow.MasterService/PartialRunSetup", "/tensorflow.MasterService/RunStep", "/tensorflow.MasterService/CloseSession", "/tensorflow.MasterService/ListDevices", @@ -51,13 +52,15 @@ MasterService::Stub::Stub( ::grpc::RpcMethod::NORMAL_RPC, channel), rpcmethod_ExtendSession_(grpcMasterService_method_names[1], ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_RunStep_(grpcMasterService_method_names[2], + rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2], + ::grpc::RpcMethod::NORMAL_RPC, channel), + rpcmethod_RunStep_(grpcMasterService_method_names[3], ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_CloseSession_(grpcMasterService_method_names[3], + rpcmethod_CloseSession_(grpcMasterService_method_names[4], ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_ListDevices_(grpcMasterService_method_names[4], + rpcmethod_ListDevices_(grpcMasterService_method_names[5], ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_Reset_(grpcMasterService_method_names[5], + rpcmethod_Reset_(grpcMasterService_method_names[6], ::grpc::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status MasterService::Stub::CreateSession( @@ -74,6 +77,13 @@ MasterService::Stub::Stub( context, request, response); } +::grpc::Status MasterService::Stub::PartialRunSetup( + ::grpc::ClientContext* context, const PartialRunSetupRequest& request, + PartialRunSetupResponse* response) { + return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_PartialRunSetup_, + context, request, response); +} + ::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context, const RunStepRequest& request, RunStepResponse* response) { @@ -103,7 +113,7 @@ MasterService::Stub::Stub( } MasterService::AsyncService::AsyncService() { - for (int i = 0; i < 6; ++i) { + for (int i = 0; i < 7; ++i) { AddMethod(new ::grpc::RpcServiceMethod(grpcMasterService_method_names[i], ::grpc::RpcMethod::NORMAL_RPC, nullptr)); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h index afe4b583f8..a3a2ac8020 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -64,6 +64,9 @@ class MasterService GRPC_FINAL { virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context, const ExtendSessionRequest& request, ExtendSessionResponse* response) = 0; + virtual ::grpc::Status PartialRunSetup( + ::grpc::ClientContext* context, const PartialRunSetupRequest& request, + PartialRunSetupResponse* response) = 0; virtual ::grpc::Status RunStep(::grpc::ClientContext* context, const RunStepRequest& request, RunStepResponse* response) = 0; @@ -86,6 +89,9 @@ class MasterService GRPC_FINAL { ::grpc::Status ExtendSession(::grpc::ClientContext* context, const ExtendSessionRequest& request, ExtendSessionResponse* response) GRPC_OVERRIDE; + ::grpc::Status PartialRunSetup( + ::grpc::ClientContext* context, const PartialRunSetupRequest& request, + PartialRunSetupResponse* response) GRPC_OVERRIDE; ::grpc::Status RunStep(::grpc::ClientContext* context, const RunStepRequest& request, RunStepResponse* response) GRPC_OVERRIDE; @@ -103,6 +109,7 @@ class MasterService GRPC_FINAL { std::shared_ptr< ::grpc::ChannelInterface> channel_; const ::grpc::RpcMethod rpcmethod_CreateSession_; const ::grpc::RpcMethod rpcmethod_ExtendSession_; + const ::grpc::RpcMethod rpcmethod_PartialRunSetup_; const ::grpc::RpcMethod rpcmethod_RunStep_; const ::grpc::RpcMethod rpcmethod_CloseSession_; const ::grpc::RpcMethod rpcmethod_ListDevices_; @@ -132,12 +139,20 @@ class MasterService GRPC_FINAL { ::grpc::Service::RequestAsyncUnary(1, context, request, response, new_call_cq, notification_cq, tag); } + void RequestPartialRunSetup( + ::grpc::ServerContext* context, PartialRunSetupRequest* request, + ::grpc::ServerAsyncResponseWriter<PartialRunSetupResponse>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(2, context, request, response, + new_call_cq, notification_cq, tag); + } void RequestRunStep( ::grpc::ServerContext* context, RunStepRequest* request, ::grpc::ServerAsyncResponseWriter<RunStepResponse>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(2, context, request, response, + ::grpc::Service::RequestAsyncUnary(3, context, request, response, new_call_cq, notification_cq, tag); } void RequestCloseSession( @@ -145,7 +160,7 @@ class MasterService GRPC_FINAL { ::grpc::ServerAsyncResponseWriter<CloseSessionResponse>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(3, context, request, response, + ::grpc::Service::RequestAsyncUnary(4, context, request, response, new_call_cq, notification_cq, tag); } void RequestListDevices( @@ -153,7 +168,7 @@ class MasterService GRPC_FINAL { ::grpc::ServerAsyncResponseWriter<ListDevicesResponse>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(4, context, request, response, + ::grpc::Service::RequestAsyncUnary(5, context, request, response, new_call_cq, notification_cq, tag); } void RequestReset( @@ -161,7 +176,7 @@ class MasterService GRPC_FINAL { ::grpc::ServerAsyncResponseWriter<ResetResponse>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(5, context, request, response, + ::grpc::Service::RequestAsyncUnary(6, context, request, response, new_call_cq, notification_cq, tag); } }; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 3af1ee9cce..879f177da5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -52,6 +52,15 @@ class GrpcRemoteMaster : public MasterInterface { return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response)); } + Status PartialRunSetup(CallOptions* call_options, + const PartialRunSetupRequest* request, + PartialRunSetupResponse* response) override { + ::grpc::ClientContext ctx; + ctx.set_fail_fast(false); + SetDeadline(&ctx, call_options->GetTimeout()); + return FromGrpcStatus(stub_->PartialRunSetup(&ctx, *request, response)); + } + Status RunStep(CallOptions* call_options, const RunStepRequest* request, RunStepResponse* response) override { ::grpc::ClientContext ctx; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 886f04a756..bb2fbb7fea 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -152,18 +152,22 @@ Status GrpcSession::Extend(const RunOptions& run_options, return ExtendImpl(&call_options, graph); } -Status GrpcSession::Run(const RunOptions& run_options, - const std::vector<std::pair<string, Tensor>>& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs, - RunMetadata* run_metadata) { +Status GrpcSession::RunHelper( + const RunOptions& run_options, + const std::vector<std::pair<string, Tensor>>& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, std::vector<Tensor>* outputs, + RunMetadata* run_metadata, const string& prun_handle) { // Convert to proto RunStepRequest req; RunStepResponse resp; *req.mutable_options() = run_options; + if (!prun_handle.empty()) { + req.set_partial_run_handle(prun_handle); + } + for (const auto& it : inputs) { Tensor input_tensor = it.second; auto feed = req.add_feed(); @@ -215,6 +219,16 @@ Status GrpcSession::Run(const RunOptions& run_options, return Status::OK(); } +Status GrpcSession::Run(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor>>& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, + RunMetadata* run_metadata) { + return RunHelper(run_options, inputs, output_tensor_names, target_node_names, + outputs, run_metadata, /* prun_handle */ ""); +} + Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, const std::vector<string>& output_tensor_names, const std::vector<string>& target_node_names, @@ -242,14 +256,41 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names, const std::vector<string>& output_names, const std::vector<string>& target_nodes, string* handle) { - return errors::Internal("Partial run is not supported for remote session."); + // Convert to proto + PartialRunSetupRequest req; + PartialRunSetupResponse resp; + CallOptions call_options; + { + mutex_lock l(mu_); + if (handle_.empty()) { + return errors::InvalidArgument("A session is not created yet...."); + } + + req.set_session_handle(handle_); + } + for (const string& feed : input_names) { + req.add_feed(feed); + } + for (const string& fetch : output_names) { + req.add_fetch(fetch); + } + for (const string& target : target_nodes) { + req.add_target(target); + } + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp)); + *handle = resp.partial_run_handle(); + return Status::OK(); } Status GrpcSession::PRun(const string& handle, const std::vector<std::pair<string, Tensor>>& inputs, const std::vector<string>& output_names, std::vector<Tensor>* outputs) { - return errors::Internal("Partial run is not supported for remote session."); + RunOptions run_options; + run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms()); + return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs, + /* run_metadata */ nullptr, handle); } Status GrpcSession::Close() { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h index 0d532520cf..60a7b8334f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -110,6 +110,13 @@ class GrpcSession : public Session { // The current version of the graph. int64 current_graph_version_ GUARDED_BY(mu_); + Status RunHelper(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, RunMetadata* run_metadata, + const string& prun_handle); + Status RunProto(CallOptions* call_options, RunStepRequest* req, RunStepResponse* resp); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index ec8c06abb4..f6580cfd42 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -163,6 +163,41 @@ class GrpcWorkerService : public AsyncServiceInterface { mutex mu_; CancellationManager* cancellation_manager_ GUARDED_BY(mu_); + struct PartialRunState { + CancellationManager* cancellation_manager; + Notification executor_done; + + explicit PartialRunState(CancellationManager* cm) + : cancellation_manager(cm) {} + }; + std::unordered_map<std::pair<string, int>, std::unique_ptr<PartialRunState>, + hash<std::pair<string, int>>> + partial_runs_ GUARDED_BY(mu_); + + PartialRunState* FindPartialRun(const string& graph_handle, int step_id) { + std::pair<string, int> k(graph_handle, step_id); + PartialRunState* prun_state = nullptr; + mutex_lock l(mu_); + auto it = partial_runs_.find(k); + if (it != partial_runs_.end()) { + prun_state = it->second.get(); + } + return prun_state; + } + + void InsertPartialRunLocked(const string& graph_handle, int step_id, + PartialRunState* partial_run_state) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::pair<string, int> k(graph_handle, step_id); + partial_runs_.emplace( + std::make_pair(k, std::unique_ptr<PartialRunState>(partial_run_state))); + } + + void RemovePartialRun(const string& graph_handle, int step_id) { + std::pair<string, int> k(graph_handle, step_id); + mutex_lock l(mu_); + partial_runs_.erase(partial_runs_.find(k)); + } mutex shutdown_mu_; bool is_shutdown_ GUARDED_BY(shutdown_mu_); @@ -225,7 +260,11 @@ class GrpcWorkerService : public AsyncServiceInterface { } void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) { - env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); }); + if (call->request.is_partial()) { + env_->compute_pool->Schedule([this, call]() { DoPartialRunGraph(call); }); + } else { + env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); }); + } ENQUEUE_REQUEST(RunGraph, true); } @@ -294,10 +333,6 @@ class GrpcWorkerService : public AsyncServiceInterface { Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in, GraphMgr::NamedTensors* out) { - if (req.is_partial()) { - return errors::Unimplemented( - "Partial run not implemented for GRPC worker service"); - } if (req.send_size() > 0) { // TODO(zhifengc): Let the caller decide on which device to // allocate the tensor. @@ -378,6 +413,110 @@ class GrpcWorkerService : public AsyncServiceInterface { }); } + // TODO(suharshs): Add stats collection support to partial run. + void DoPartialRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) { + const int64 step_id = call->request.step_id(); + const string& graph_handle = call->request.graph_handle(); + TRACEPRINTF("PartialRunGraph: %lld", step_id); + GraphMgr::NamedTensors in; + GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; + Status s = PrepareRunGraph(call->request, &in, out); + auto finish = [this, call, out](const Status& s) { + delete out; + call->ClearCancelCallback(); + call->SendResponse(ToGrpcStatus(s)); + }; + if (!s.ok()) { + finish(s); + return; + } + + PartialRunState* partial_run_state = FindPartialRun(graph_handle, step_id); + + CancellationManager* cm = nullptr; + // If this is a new partial run call we need to create a new cancellation + // manager. + // Otherwise we use the cancellation manager stored in the found partial + // run state. + if (partial_run_state == nullptr) { + cm = new CancellationManager; + } else { + cm = partial_run_state->cancellation_manager; + } + + // Before we start doing anything, we set the RPC cancellation. + call->SetCancelCallback([this, cm, step_id]() { + cm->StartCancel(); + AbortStep(step_id); + }); + + // If this is a new partial run request, the request will need to start the + // executors. + if (partial_run_state == nullptr) { + CancellationToken token; + { + mutex_lock l(mu_); + // Insert the new partial run into the partial_runs_ map. + partial_run_state = new PartialRunState(cm); + InsertPartialRunLocked(graph_handle, step_id, partial_run_state); + token = cancellation_manager_->get_cancellation_token(); + cancellation_manager_->RegisterCallback(token, + [cm]() { cm->StartCancel(); }); + } + env_->graph_mgr->ExecuteAsync( + graph_handle, step_id, call->request.exec_opts(), + nullptr /* collector */, nullptr /* cost_graph */, cm, in, + [this, step_id, graph_handle, token, partial_run_state](Status s) { + { + mutex_lock l(mu_); + cancellation_manager_->DeregisterCallback(token); + } + partial_run_state->executor_done.Notify(); + // TODO(suharshs): Propagate the status once we keep state for + // each partial run call. + }); + } else { + // Send the partial run's new inputs. + s = env_->graph_mgr->SendInputs(step_id, in); + if (!s.ok()) { + finish(s); + return; + } + } + + // Receive the partial run's outputs. + s = env_->graph_mgr->RecvOutputs(step_id, out); + if (!s.ok()) { + finish(s); + return; + } + + // Construct and return the resp. + for (const auto& p : *out) { + const string& key = p.first; + const Tensor& val = p.second; + auto* recv = call->response.add_recv(); + recv->set_key(key); + // TODO(zhifengc): Deal with gpu -> cpu copy. + TensorProto* proto = recv->mutable_val(); + val.AsProtoField(proto); + } + + // If this is the last partial run request we must also wait for the entire + // graph execution to be completed. + if (call->request.is_last_partial_run()) { + partial_run_state->executor_done.WaitForNotification(); + RemovePartialRun(graph_handle, step_id); + // Before deleting the cancellation manager on the final call, ensure + // that we clear the RPC cancel callback, which has a reference to the + // cancellation manager. + call->ClearCancelCallback(); + delete cm; + } + + finish(s); + } + // Helper for RecvTensor. Validates "key" and returns the source // device in "*src_dev". Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index cfa0e15d48..9f721274c9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2055,6 +2055,7 @@ py_test( ":math_ops", ":session", ":state_ops", + ":training", ":variables", ], ) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index a20376b91d..8e3dd28c21 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -45,6 +45,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.training import server_lib from tensorflow.python.util import compat @@ -1310,91 +1311,121 @@ class SessionTest(test_util.TensorFlowTestCase): sess_2.run(c_1.op) self.assertEqual(2.0, sess_2.run(c_2)) - def testPartialRun(self): - with session.Session() as sess: - a = array_ops.placeholder(dtypes.float32, shape=[]) - b = array_ops.placeholder(dtypes.float32, shape=[]) - c = array_ops.placeholder(dtypes.float32, shape=[]) - r1 = math_ops.add(a, b) - r2 = math_ops.mul(r1, c) - - h = sess.partial_run_setup([r1, r2], [a, b, c]) - res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) - self.assertEqual(3, res) - temp = res * 17 - res = sess.partial_run(h, r2, feed_dict={c: temp}) - self.assertEqual(153, res) - - # Call again on the same graph. - h2 = sess.partial_run_setup([r1, r2], [a, b, c]) - res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2}) - self.assertEqual(3, res) - temp = res * 18 - res = sess.partial_run(h2, r2, feed_dict={c: temp}) - self.assertEqual(162, res) - - def testPartialRunIncomplete(self): - with session.Session() as sess: - a = array_ops.placeholder(dtypes.float32, shape=[]) - b = array_ops.placeholder(dtypes.float32, shape=[]) - c = array_ops.placeholder(dtypes.float32, shape=[]) - r1 = math_ops.add(a, b) - r2 = math_ops.mul(r1, c) - - h = sess.partial_run_setup([r1, r2], [a, b, c]) - res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) - self.assertEqual(3, res) - - def testConcurrentPartialRun(self): - with session.Session() as sess: - a = array_ops.placeholder(dtypes.float32, shape=[]) - b = array_ops.placeholder(dtypes.float32, shape=[]) - c = array_ops.placeholder(dtypes.float32, shape=[]) - r1 = math_ops.add(a, b) - r2 = math_ops.mul(r1, c) - - h1 = sess.partial_run_setup([r1], [a, b, c]) - h2 = sess.partial_run_setup([r1, r2], [a, b, c]) - res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2}) - self.assertEqual(3, res) - temp = res * 19 - res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9}) - self.assertEqual(66, res) - res = sess.partial_run(h2, r2, feed_dict={c: 7}) - self.assertEqual(462, res) - - def testManyPartialRun(self): - with session.Session() as sess: - steps = 200 - inputs = [] - outputs = [] - a = constant_op.constant(2.0, dtypes.float32) - for i in xrange(steps): - inputs.append(array_ops.placeholder(dtypes.float32, shape=[])) - a = math_ops.mul(a, inputs[i]) - outputs.append(a) - - h = sess.partial_run_setup(outputs, inputs) - for i in xrange(steps): - res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0}) - self.assertEqual(2.0, res) - - feed_dict = {} - for i in xrange(steps): - feed_dict[inputs[i]] = 1.0 - res = sess.run(outputs, feed_dict) - self.assertEqual(steps, len(res)) - self.assertEqual(2.0, res[-1]) - - def testRunAndPartialRun(self): - with session.Session() as sess: - a = constant_op.constant(2.0, dtypes.float32) - b = a * 2 - c = b * 3 - r1 = sess.run([b, c]) - h = sess.partial_run_setup([b, c], []) - r2 = sess.partial_run(h, [b, c]) - self.assertEqual(r1, r2) + def runTestPartialRun(self, sess): + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.mul(r1, c) + + h = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + temp = res * 17 + res = sess.partial_run(h, r2, feed_dict={c: temp}) + self.assertEqual(153, res) + + # Call again on the same graph. + h2 = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + temp = res * 18 + res = sess.partial_run(h2, r2, feed_dict={c: temp}) + self.assertEqual(162, res) + + def runTestPartialRunIncomplete(self, sess): + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.mul(r1, c) + + h = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + + def runTestConcurrentPartialRun(self, sess): + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.mul(r1, c) + + h1 = sess.partial_run_setup([r1], [a, b, c]) + h2 = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + temp = res * 19 + res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9}) + self.assertEqual(66, res) + res = sess.partial_run(h2, r2, feed_dict={c: 7}) + self.assertEqual(462, res) + + def runTestManyPartialRun(self, sess): + steps = 200 + inputs = [] + outputs = [] + a = constant_op.constant(2.0, dtypes.float32) + for i in xrange(steps): + inputs.append(array_ops.placeholder(dtypes.float32, shape=[])) + a = math_ops.mul(a, inputs[i]) + outputs.append(a) + + h = sess.partial_run_setup(outputs, inputs) + for i in xrange(steps): + res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0}) + self.assertEqual(2.0, res) + + feed_dict = {} + for i in xrange(steps): + feed_dict[inputs[i]] = 1.0 + res = sess.run(outputs, feed_dict) + self.assertEqual(steps, len(res)) + self.assertEqual(2.0, res[-1]) + + def runTestRunAndPartialRun(self, sess): + a = constant_op.constant(2.0, dtypes.float32) + b = a * 2 + c = b * 3 + r1 = sess.run([b, c]) + h = sess.partial_run_setup([b, c], []) + r2 = sess.partial_run(h, [b, c]) + self.assertEqual(r1, r2) + + def testPartialRunDirect(self): + self.runTestPartialRun(session.Session()) + + def testPartialRunIncompleteDirect(self): + self.runTestPartialRunIncomplete(session.Session()) + + def testConcurrentPartialRunDirect(self): + self.runTestConcurrentPartialRun(session.Session()) + + def testManyPartialRunDirect(self): + self.runTestManyPartialRun(session.Session()) + + def testRunAndPartialRunDirect(self): + self.runTestRunAndPartialRun(session.Session()) + + def testPartialRunDist(self): + server = server_lib.Server.create_local_server() + self.runTestPartialRun(session.Session(server.target)) + + def testPartialRunIncompleteDist(self): + server = server_lib.Server.create_local_server() + self.runTestPartialRunIncomplete(session.Session(server.target)) + + def testConcurrentPartialRunDist(self): + server = server_lib.Server.create_local_server() + self.runTestConcurrentPartialRun(session.Session(server.target)) + + def testManyPartialRunDist(self): + server = server_lib.Server.create_local_server() + self.runTestManyPartialRun(session.Session(server.target)) + + def testRunAndPartialRunDist(self): + server = server_lib.Server.create_local_server() + self.runTestRunAndPartialRun(session.Session(server.target)) def testFeedDictKeyException(self): with session.Session() as sess: |