/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/local_master.h"

#include <unordered_map>

#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

namespace {

// Blocks until `n` has been notified, honoring an optional deadline.
//
// The effective timeout is `call_options->GetTimeout()` (milliseconds), or
// `default_timeout_in_ms` when that value is 0. If the effective timeout is
// not positive, this waits on `n` without any deadline.
//
// Returns OK if `n` was notified in time. If the deadline passes first, the
// call is cancelled via `call_options->StartCancel()`, this function still
// blocks until `n` is notified, and then returns a DeadlineExceeded error.
Status WaitForNotification(CallOptions* call_options,
                           const int64 default_timeout_in_ms, Notification* n) {
  int64 timeout_in_ms = call_options->GetTimeout();
  if (timeout_in_ms == 0) {
    timeout_in_ms = default_timeout_in_ms;
  }
  if (timeout_in_ms > 0) {
    int64 timeout_in_us = timeout_in_ms * 1000;
    bool notified = WaitForNotificationWithTimeout(n, timeout_in_us);
    if (!notified) {
      call_options->StartCancel();
      // The call has borrowed pointers to the request and response
      // messages, so we must still wait for the call to complete.
      n->WaitForNotification();
      return errors::DeadlineExceeded("Operation timed out.");
    }
  } else {
    n->WaitForNotification();
  }
  return Status::OK();
}

}  // namespace

// Stores the (non-owned here) master pointer and the timeout, in milliseconds,
// that is used whenever a call's own timeout is 0.
LocalMaster::LocalMaster(Master* master_impl, const int64 default_timeout_in_ms)
    : master_impl_(master_impl),
      default_timeout_in_ms_(default_timeout_in_ms) {}

// Every RPC-style method below follows the same shape:
//   1. start the callback-based call on `master_impl_`;
//   2. the callback records the call's status in `ret` and notifies `n`;
//   3. block in WaitForNotification() until the callback has run.
// Capturing `n` and `ret` by reference is safe because WaitForNotification()
// never returns before `n` has been notified, even after a timeout. When the
// wait reports an error (deadline exceeded), that error is returned instead of
// the status recorded by the callback.

// Synchronously forwards a CreateSession call to the wrapped master.
Status LocalMaster::CreateSession(CallOptions* call_options,
                                  const CreateSessionRequest* request,
                                  CreateSessionResponse* response) {
  Notification n;
  Status ret;
  master_impl_->CreateSession(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards an ExtendSession call to the wrapped master.
Status LocalMaster::ExtendSession(CallOptions* call_options,
                                  const ExtendSessionRequest* request,
                                  ExtendSessionResponse* response) {
  Notification n;
  Status ret;
  master_impl_->ExtendSession(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards a PartialRunSetup call to the wrapped master.
Status LocalMaster::PartialRunSetup(CallOptions* call_options,
                                    const PartialRunSetupRequest* request,
                                    PartialRunSetupResponse* response) {
  Notification n;
  Status ret;
  master_impl_->PartialRunSetup(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards a RunStep call to the wrapped master. Unlike most of
// the other methods, `call_options` is also handed to the master itself.
Status LocalMaster::RunStep(CallOptions* call_options,
                            RunStepRequestWrapper* request,
                            MutableRunStepResponseWrapper* response) {
  Notification n;
  Status ret;
  master_impl_->RunStep(call_options, request, response,
                        [&n, &ret](const Status& s) {
                          ret.Update(s);
                          n.Notify();
                        });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Returns a newly heap-allocated InMemoryRunStepRequest.
// NOTE(review): presumably the caller takes ownership of the returned object;
// confirm against the interface declaration.
MutableRunStepRequestWrapper* LocalMaster::CreateRunStepRequest() {
  return new InMemoryRunStepRequest;
}

// Returns a newly heap-allocated InMemoryRunStepResponse.
// NOTE(review): presumably the caller takes ownership of the returned object;
// confirm against the interface declaration.
MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() {
  return new InMemoryRunStepResponse;
}

// Synchronously forwards a CloseSession call to the wrapped master.
Status LocalMaster::CloseSession(CallOptions* call_options,
                                 const CloseSessionRequest* request,
                                 CloseSessionResponse* response) {
  Notification n;
  Status ret;
  master_impl_->CloseSession(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards a ListDevices call to the wrapped master.
Status LocalMaster::ListDevices(CallOptions* call_options,
                                const ListDevicesRequest* request,
                                ListDevicesResponse* response) {
  Notification n;
  Status ret;
  master_impl_->ListDevices(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards a Reset call to the wrapped master.
Status LocalMaster::Reset(CallOptions* call_options,
                          const ResetRequest* request,
                          ResetResponse* response) {
  Notification n;
  Status ret;
  master_impl_->Reset(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards a MakeCallable call to the wrapped master.
Status LocalMaster::MakeCallable(CallOptions* call_options,
                                 const MakeCallableRequest* request,
                                 MakeCallableResponse* response) {
  Notification n;
  Status ret;
  master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards a RunCallable call to the wrapped master. As with
// RunStep, `call_options` is also handed to the master itself.
Status LocalMaster::RunCallable(CallOptions* call_options,
                                const RunCallableRequest* request,
                                RunCallableResponse* response) {
  Notification n;
  Status ret;
  master_impl_->RunCallable(call_options, request, response,
                            [&n, &ret](const Status& s) {
                              ret.Update(s);
                              n.Notify();
                            });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

// Synchronously forwards a ReleaseCallable call to the wrapped master.
Status LocalMaster::ReleaseCallable(CallOptions* call_options,
                                    const ReleaseCallableRequest* request,
                                    ReleaseCallableResponse* response) {
  Notification n;
  Status ret;
  master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) {
    ret.Update(s);
    n.Notify();
  });
  TF_RETURN_IF_ERROR(
      WaitForNotification(call_options, default_timeout_in_ms_, &n));
  return ret;
}

namespace {

// Returns the mutex that guards the registry returned by
// local_master_registry().
mutex* get_local_master_registry_lock() {
  static mutex local_master_registry_lock(LINKER_INITIALIZED);
  return &local_master_registry_lock;
}

// One registry entry: the master registered for a target together with the
// default timeout (in milliseconds) to use for LocalMasters created from it.
struct MasterInfo {
  Master* master;
  const int64 default_timeout_in_ms;

  MasterInfo(Master* master, const int64 default_timeout_in_ms)
      : master(master), default_timeout_in_ms(default_timeout_in_ms) {}
};

// Maps a target string to the master registered under it.
typedef std::unordered_map<string, MasterInfo> LocalMasterRegistry;

// Returns the process-wide registry. It is allocated on first use and is
// never deleted in this file. Callers must hold
// get_local_master_registry_lock() while accessing it.
LocalMasterRegistry* local_master_registry() {
  static LocalMasterRegistry* local_master_registry_ = new LocalMasterRegistry;
  return local_master_registry_;
}

}  // namespace

// Registers `master` under `target`, under the registry lock.
//
// This uses insert(), which leaves an existing entry untouched: if `target`
// is already registered, the earlier registration is kept and this call has
// no effect.
/* static */
void LocalMaster::Register(const string& target, Master* master,
                           int64 default_timeout_in_ms) {
  mutex_lock l(*get_local_master_registry_lock());
  local_master_registry()->insert(
      {target, MasterInfo(master, default_timeout_in_ms)});
}

// Looks up `target` under the registry lock. Returns a new LocalMaster that
// wraps the registered master and its default timeout, or a null pointer when
// nothing is registered under `target`.
/* static */
std::unique_ptr<LocalMaster> LocalMaster::Lookup(const string& target) {
  std::unique_ptr<LocalMaster> ret;
  mutex_lock l(*get_local_master_registry_lock());
  auto iter = local_master_registry()->find(target);
  if (iter != local_master_registry()->end()) {
    ret.reset(new LocalMaster(iter->second.master,
                              iter->second.default_timeout_in_ms));
  }
  return ret;
}

}  // namespace tensorflow