Diffstat (limited to 'tensorflow/core/common_runtime')
28 files changed, 622 insertions, 173 deletions
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 9cda17867b..3bf0532491 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -155,10 +155,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { region_manager_.set_handle(c->ptr, h); - // TODO(vrv): Try to merge this new region with an existing region, - // if the address space is contiguous, to avoid fragmentation - // across regions. - // Insert the chunk into the right bin. InsertFreeChunkIntoBin(h); @@ -465,49 +461,33 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) { Chunk* c = ChunkFromHandle(h); CHECK(c->in_use() && (c->bin_num == kInvalidBinNum)); - // Mark the chunk as no longer in use + // Mark the chunk as no longer in use. c->allocation_id = -1; // Updates the stats. stats_.bytes_in_use -= c->size; - // This chunk is no longer in-use, consider coalescing the chunk - // with adjacent chunks. - ChunkHandle chunk_to_reassign = h; - - // If the next chunk is free, coalesce the two - if (c->next != kInvalidChunkHandle) { - Chunk* cnext = ChunkFromHandle(c->next); - if (!cnext->in_use()) { - // VLOG(8) << "Chunk at " << cnext->ptr << " merging with c " << - // c->ptr; - - chunk_to_reassign = h; + ChunkHandle coalesced_chunk = h; - // Deletes c->next - RemoveFreeChunkFromBin(c->next); - Merge(h, ChunkFromHandle(h)->next); - } + // If the next chunk is free, merge it into c and delete it. + if (c->next != kInvalidChunkHandle && !ChunkFromHandle(c->next)->in_use()) { + // VLOG(8) << "Merging c->next " << ChunkFromHandle(c->next)->ptr + // << " with c " << c->ptr; + RemoveFreeChunkFromBin(c->next); + Merge(h, c->next); } - // If the previous chunk is free, coalesce the two - c = ChunkFromHandle(h); - if (c->prev != kInvalidChunkHandle) { - Chunk* cprev = ChunkFromHandle(c->prev); - if (!cprev->in_use()) { - // VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev " - // << cprev->ptr; - - chunk_to_reassign = c->prev; + // If the previous chunk is free, merge c into it and delete c. + if (c->prev != kInvalidChunkHandle && !ChunkFromHandle(c->prev)->in_use()) { + // VLOG(8) << "Merging c " << c->ptr << " into c->prev " + // << ChunkFromHandle(c->prev)->ptr; - // Deletes c - RemoveFreeChunkFromBin(c->prev); - Merge(ChunkFromHandle(h)->prev, h); - c = ChunkFromHandle(h); - } + coalesced_chunk = c->prev; + RemoveFreeChunkFromBin(c->prev); + Merge(c->prev, h); } - InsertFreeChunkIntoBin(chunk_to_reassign); + InsertFreeChunkIntoBin(coalesced_chunk); } void BFCAllocator::AddAllocVisitor(Visitor visitor) { diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index 52aedb1e9c..580e61e2ea 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -88,11 +88,20 @@ class BFCAllocator : public VisitableAllocator { static const int kInvalidBinNum = -1; static const int kNumBins = 21; - // Chunks point to memory. Their prev/next pointers form a - // doubly-linked list of addresses sorted by base address that - // must be contiguous. Chunks contain information about whether - // they are in use or whether they are free, and contain a pointer - // to the bin they are in. + // A Chunk points to a piece of memory that's either entirely free or entirely + // in use by one user memory allocation. 
+ // + // An AllocationRegion's memory is split up into one or more disjoint Chunks, + // which together cover the whole region without gaps. Chunks participate in + // a doubly-linked list, and the prev/next pointers point to the physically + // adjacent chunks. + // + // Since a chunk cannot be partially in use, we may need to split a free chunk + // in order to service a user allocation. We always merge adjacent free + // chunks. + // + // Chunks contain information about whether they are in use or whether they + // are free, and contain a pointer to the bin they are in. struct Chunk { size_t size = 0; // Full size of buffer. @@ -177,8 +186,12 @@ class BFCAllocator : public VisitableAllocator { static const size_t kMinAllocationBits = 8; static const size_t kMinAllocationSize = 1 << kMinAllocationBits; - // AllocationRegion maps pointers to ChunkHandles for a single - // contiguous memory region. + // BFCAllocator allocates memory into a collection of disjoint + // AllocationRegions. Each AllocationRegion corresponds to one call to + // SubAllocator::Alloc(). + // + // An AllocationRegion contains one or more Chunks, covering all of its + // memory. Its primary job is to map pointers to ChunkHandles. // // This class is thread-compatible. class AllocationRegion { public: AllocationRegion(void* ptr, size_t memory_size) : ptr_(ptr), memory_size_(memory_size), end_ptr_( static_cast<void*>(static_cast<char*>(ptr_) + memory_size_)) { DCHECK_EQ(0, memory_size % kMinAllocationSize); const size_t n_handles = (memory_size + kMinAllocationSize - 1) / kMinAllocationSize; - handles_ = new ChunkHandle[n_handles]; + handles_.reset(new ChunkHandle[n_handles]); for (size_t i = 0; i < n_handles; i++) { handles_[i] = kInvalidChunkHandle; } } - AllocationRegion() {} - - ~AllocationRegion() { delete[] handles_; } - + AllocationRegion() = default; AllocationRegion(AllocationRegion&& other) { Swap(other); } - AllocationRegion& operator=(AllocationRegion&& other) { Swap(other); return *this; } @@ -241,7 +250,7 @@ class BFCAllocator : public VisitableAllocator { // Array of size "memory_size / kMinAllocationSize". It is // indexed by (p-base) / kMinAllocationSize, contains ChunkHandle // for the memory allocation represented by "p" - ChunkHandle* handles_ = nullptr; + std::unique_ptr<ChunkHandle[]> handles_; TF_DISALLOW_COPY_AND_ASSIGN(AllocationRegion); }; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index f903faf1bd..d1fd930d25 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -146,18 +146,15 @@ class DirectSessionFactory : public SessionFactory { return options.target.empty(); } - Session* NewSession(const SessionOptions& options) override { + Status NewSession(const SessionOptions& options, + Session** out_session) override { // Must do this before the CPU allocator is created.
if (options.config.graph_options().build_cost_model() > 0) { EnableCPUAllocatorFullStats(true); } std::vector<Device*> devices; - const Status s = DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices); - if (!s.ok()) { - LOG(ERROR) << s; - return nullptr; - } + TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices)); DirectSession* session = new DirectSession(options, new DeviceMgr(devices), this); @@ -165,7 +162,8 @@ class DirectSessionFactory : public SessionFactory { mutex_lock l(sessions_lock_); sessions_.push_back(session); } - return session; + *out_session = session; + return Status::OK(); } Status Reset(const SessionOptions& options, @@ -237,7 +235,11 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool, // safe given the reasoning above. c(); #else - pool->Schedule(std::move(c)); + if (pool != nullptr) { + pool->Schedule(std::move(c)); + } else { + c(); + } #endif // __ANDROID__ } @@ -524,8 +526,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } } - if (run_options.inter_op_thread_pool() < 0 || - run_options.inter_op_thread_pool() >= thread_pools_.size()) { + if (run_options.inter_op_thread_pool() < -1 || + run_options.inter_op_thread_pool() >= + static_cast<int32>(thread_pools_.size())) { run_state.executors_done.Notify(); delete barrier; return errors::InvalidArgument("Invalid inter_op_thread_pool: ", @@ -550,7 +553,19 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } thread::ThreadPool* pool = - thread_pools_[run_options.inter_op_thread_pool()].first; + run_options.inter_op_thread_pool() >= 0 + ? thread_pools_[run_options.inter_op_thread_pool()].first + : nullptr; + + if (pool == nullptr) { + // We allow using the caller thread only when having a single executor + // specified. + if (executors_and_keys->items.size() > 1) { + pool = thread_pools_[0].first; + } else { + VLOG(1) << "Executing Session::Run() synchronously!"; + } + } Executor::Args::Runner default_runner = [this, pool](Executor::Args::Closure c) { @@ -702,7 +717,8 @@ Status DirectSession::Run(const RunOptions& run_options, // Receive outputs. if (outputs) { std::vector<Tensor> sorted_outputs; - const Status s = call_frame.ConsumeRetvals(&sorted_outputs); + const Status s = call_frame.ConsumeRetvals( + &sorted_outputs, /* allow_dead_tensors = */ false); if (errors::IsInternal(s)) { return errors::InvalidArgument(s.error_message()); } else if (!s.ok()) { @@ -1188,12 +1204,11 @@ Status DirectSession::CreateExecutors( delete kernel; } }; - params.node_outputs_cb = node_outputs_callback_; optimizer.Optimize(lib, options_.env, device, &iter->second, /*shape_map=*/nullptr); - // EXPERIMENTAL: tfdbg inserts debug nodes in the graph. + // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. const DebugOptions& debug_options = options.callable_options.run_options().debug_options(); if (!debug_options.debug_tensor_watch_opts().empty()) { diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 142d613129..4b51b20bb1 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include <map> #include <memory> #include <string> +#include <thread> #include <unordered_map> #include <vector> @@ -896,6 +897,125 @@ TEST(DirectSessionTest, FetchMultipleTimes) { } } +TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) { + GraphDef def; + Graph g(OpRegistry::Global()); + RunOptions run_options; + run_options.set_inter_op_thread_pool(-1); + + Tensor first_value(DT_FLOAT, TensorShape({})); + first_value.scalar<float>()() = 1.0; + Node* first_const = test::graph::Constant(&g, first_value); + Node* first_identity = test::graph::Identity(&g, first_const); + + Tensor second_value(DT_FLOAT, TensorShape({})); + second_value.scalar<float>()() = 2.0; + Node* second_const = test::graph::Constant(&g, second_value); + Node* second_identity = test::graph::Identity(&g, second_const); + + test::graph::ToGraphDef(&g, &def); + + auto session = CreateSession(); + ASSERT_TRUE(session != nullptr); + TF_ASSERT_OK(session->Create(def)); + + std::vector<Tensor> outputs; + + // Fetch without feeding. + Status s = session->Run( + run_options, {}, + {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, + &outputs, nullptr); + TF_ASSERT_OK(s); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(1.0, outputs[0].flat<float>()(0)); + ASSERT_EQ(2.0, outputs[1].flat<float>()(0)); + + s = session->Run( + {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {}, + &outputs); + TF_ASSERT_OK(s); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(2.0, outputs[0].flat<float>()(0)); + ASSERT_EQ(1.0, outputs[1].flat<float>()(0)); + + Tensor value_11(DT_FLOAT, TensorShape({})); + value_11.scalar<float>()() = 11.0; + Tensor value_22(DT_FLOAT, TensorShape({})); + value_22.scalar<float>()() = 22.0; + + // Feed [first_const, second_const] + s = session->Run( + {{first_const->name(), value_11}, {second_const->name(), value_22}}, + {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, + &outputs); + TF_ASSERT_OK(s); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(11.0, outputs[0].flat<float>()(0)); + ASSERT_EQ(22.0, outputs[1].flat<float>()(0)); + + // Feed [second_const, first_const] + s = session->Run( + {{second_const->name(), value_22}, {first_const->name(), value_11}}, + {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, + &outputs); + TF_ASSERT_OK(s); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(11.0, outputs[0].flat<float>()(0)); + ASSERT_EQ(22.0, outputs[1].flat<float>()(0)); + + // Feed [first_const, first_const] + s = session->Run( + run_options, + {{first_const->name(), value_11}, {first_const->name(), value_22}}, + {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, + &outputs, nullptr); + EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once")); +} + +REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc( +ThreadID returns the thread ID that called compute. + +x: int64 +y: int64 +)doc"); + +// The ThreadID kernel returns the thread ID that executed Compute. 
+class ThreadIDOp : public OpKernel { + public: + explicit ThreadIDOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + Tensor* out_tensor = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("y", TensorShape({}), &out_tensor)); + std::hash<std::thread::id> hasher; + out_tensor->scalar<int64>()() = + static_cast<int64>(hasher(std::this_thread::get_id())); + } +}; +REGISTER_KERNEL_BUILDER(Name("ThreadID").Device(DEVICE_CPU), ThreadIDOp); + +TEST(DirectSessionTest, SessionSyncRun) { + Graph g(OpRegistry::Global()); + Tensor vx(DT_INT64, TensorShape({})); + vx.scalar<int64>()() = 17; + Node* x = test::graph::Constant(&g, vx); + Node* y = test::graph::Unary(&g, "ThreadID", x); + GraphDef def; + test::graph::ToGraphDef(&g, &def); + auto sess = CreateSession(); + TF_ASSERT_OK(sess->Create(def)); + std::vector<Tensor> outputs; + RunOptions run_opts; + run_opts.set_inter_op_thread_pool(-1); + auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr); + + std::hash<std::thread::id> hasher; + EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())), + static_cast<int64>(outputs[0].scalar<int64>()())); +} + REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc( Darth promises one return value. @@ -1400,6 +1520,7 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, p = options.config.add_session_inter_op_thread_pool(); if (use_global_pools) p->set_global_name("small pool"); p->set_num_threads(1); + const int kSyncPool = -1; const int kLargePool = 0; const int kSmallPool = 1; @@ -1442,7 +1563,11 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, EXPECT_FLOAT_EQ(1.2, flat(0)); num_done.fetch_add(1); }; - tp->Schedule(fn); + if (tp != nullptr) { + tp->Schedule(fn); + } else { + fn(); + } }; // For blocking states: @@ -1463,9 +1588,10 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5); - // Launch 2 session run calls. Neither will finish until the blocking op is + // Launch a session run call. It will not finish until the blocking op is // unblocked, because it is using all threads in the small pool. add_session_run_call(tp1, y, kSmallPool); + blocking_op_state->AwaitState(1); // Wait for the blocking op to Compute. // These will block on <BlockingOpState>. @@ -1484,10 +1610,15 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, delete tp2; EXPECT_EQ(kUnblockedThreads, num_done.load()); + // Launch a session call using this thread. This will finish as it runs + // synchronously in this thread. + add_session_run_call(nullptr, x, kSyncPool); + // Unblock the blocked op and wait for the blocked functions to finish. 
blocking_op_state->MoveToState(1, 2); delete tp1; - EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1, num_done.load()); + + EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1 + 1, num_done.load()); delete blocking_op_state; blocking_op_state = nullptr; } @@ -1532,7 +1663,7 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) { { std::unique_ptr<Session> session(NewSession(options)); TF_ASSERT_OK(session->Create(def)); - for (int pool_num = -1; pool_num <= 1; pool_num += 2) { + for (int pool_num = -2; pool_num <= 1; pool_num += 3) { RunOptions run_options; run_options.set_inter_op_thread_pool(pool_num); std::vector<Tensor> outputs; diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 70208fb6d1..5e0f0a45f8 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -17,8 +17,20 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { +namespace { + +bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { + bool val; + if (ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) { + return val; + } + return default_val; +} + +} // namespace EagerContext::EagerContext(const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, @@ -34,8 +46,16 @@ EagerContext::EagerContext(const SessionOptions& opts, local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), - async_default_(async) { + async_default_(async), + use_send_tensor_rpc_(false) { InitDeviceMapAndAsync(); + if (opts.config.inter_op_parallelism_threads() > 0) { + runner_ = [this](std::function<void()> closure) { + this->thread_pool_->Schedule(closure); + }; + } else { + runner_ = [](std::function<void()> closure) { closure(); }; + } } #ifndef __ANDROID__ @@ -59,7 +79,9 @@ EagerContext::EagerContext( remote_device_manager_(std::move(remote_device_manager)), server_(std::move(server)), remote_eager_workers_(std::move(remote_eager_workers)), - remote_contexts_(remote_contexts) { + remote_contexts_(remote_contexts), + use_send_tensor_rpc_( + ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false)) { InitDeviceMapAndAsync(); } #endif diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 864f514a19..4a180e074d 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -105,6 +105,8 @@ class EagerContext { EagerExecutor* Executor() { return &executor_; } + std::function<void(std::function<void()>)>* runner() { return &runner_; } + // Sets whether this thread should run in synchronous or asynchronous mode. Status SetAsyncForThread(bool async); @@ -180,6 +182,11 @@ class EagerContext { #ifndef __ANDROID__ Status GetClientAndContextID(Device* device, eager::EagerClient** client, uint64* context_id); + + // If true, then tensors should be shipped across processes via the + // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used + // instead (which in turn use WorkerService.RecvTensor RPCs). + bool UseSendTensorRPC() { return use_send_tensor_rpc_; } #endif private: void InitDeviceMapAndAsync(); @@ -214,6 +221,8 @@ class EagerContext { // session->devices[i].
const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + std::function<void(std::function<void()>)> runner_; + mutex cache_mu_; std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_ GUARDED_BY(cache_mu_); @@ -235,16 +244,18 @@ class EagerContext { const std::unique_ptr<DeviceMgr> remote_device_manager_; +#ifndef __ANDROID__ // The server_ is not const since we release it when the context is destroyed. // Therefore the server_ object is not marked as const (even though it should // be). -#ifndef __ANDROID__ std::unique_ptr<ServerInterface> server_; const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_; const gtl::FlatMap<string, uint64> remote_contexts_; gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>> device_to_client_cache_; + + const bool use_send_tensor_rpc_; #endif }; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 7a2b477845..7ea78b63d9 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -88,6 +88,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i, TF_RETURN_IF_ERROR((*handle)->Device(&handle_device)); const Device* actual_device = handle_device == nullptr ? ctx->HostCPU() : handle_device; + const Device* op_device = + op->Device() == nullptr ? ctx->HostCPU() : op->Device(); if (expected_device != actual_device) { switch (ctx->GetDevicePlacementPolicy()) { @@ -106,8 +108,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i, " cannot compute ", op->Name(), " as input #", i, " was expected to be on ", expected_device->name(), " but is actually on ", - actual_device->name(), " (operation running on ", - op->Device()->name(), ")", + actual_device->name(), " (operation running on ", op_device->name(), + ")", " Tensors can be copied explicitly using .gpu() or .cpu() " "methods," " or transparently copied by using tf.enable_eager_execution(" @@ -118,7 +120,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i, LOG(WARNING) << "before computing " << op->Name() << " input #" << i << " was expected to be on " << expected_device->name() << " but is actually on " << actual_device->name() - << " (operation running on " << op->Device()->name() + << " (operation running on " << op_device->name() << "). This triggers a copy which can be a performance " "bottleneck."; break; @@ -512,7 +514,8 @@ Status EagerLocalExecute(EagerOperation* op, // See WARNING comment in Execute (before kernel->Run) - would be nice to // rework to avoid this subtlety. 
tf_shared_lock l(*ctx->FunctionsMu()); - status = KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); + status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(), + kernel); if (!status.ok()) { delete kernel; return status; @@ -582,6 +585,87 @@ Status EagerLocalExecute(EagerOperation* op, return status; } +std::function<void()> GetRemoteTensorDestructor( + EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id, + uint64 op_id, int output_num) { + return [ctx, eager_client, context_id, op_id, output_num]() { + std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest); + request->set_context_id(context_id); + + auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref(); + handle_to_decref->set_op_id(op_id); + handle_to_decref->set_output_num(output_num); + + if (ctx->Async()) { + tensorflow::uint64 id = ctx->NextId(); + auto* node = + new eager::RemoteExecuteNode(id, std::move(request), eager_client); + ctx->ExecutorAdd(node); + } else { + eager::EnqueueRequest* actual_request = request.release(); + eager::EnqueueResponse* response = new eager::EnqueueResponse; + eager_client->EnqueueAsync( + actual_request, response, + [actual_request, response](const tensorflow::Status& s) { + delete actual_request; + delete response; + }); + } + + return tensorflow::Status::OK(); + }; +} + +// When !ctx->UseSendTensorRPC(), then tensors are shipped between remote +// devices by the receiver invoking the WorkerService.RecvTensor RPC *on the +// sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel). +// +// However, in some configurations the node that has the tensor to be copied +// isn't running a server (WorkerService RPC interface). For such cases, +// this function enables sending tensors using the EagerService.SendTensor RPC +// *on the receiver*. +Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h, + Device* recv_device, TensorHandle** result) { + eager::EagerClient* eager_client; + uint64 context_id; + TF_RETURN_IF_ERROR( + ctx->GetClientAndContextID(recv_device, &eager_client, &context_id)); + + eager::SendTensorRequest request; + eager::SendTensorResponse response; + + request.set_context_id(context_id); + request.set_op_id(ctx->NextId()); + request.set_device_name(recv_device->name()); + + const Tensor* tensor; + TF_RETURN_IF_ERROR(h->Tensor(&tensor)); + tensor->AsProtoTensorContent(request.add_tensors()); + + const tensorflow::uint64 id = request.op_id(); + + // TODO(nareshmodi): support making this call async. 
+ Notification n; + Status status; + eager_client->SendTensorAsync(&request, &response, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + if (!status.ok()) return status; + + std::function<void()> destructor = + GetRemoteTensorDestructor(ctx, eager_client, context_id, id, 0); + + *result = new TensorHandle(id, /*output_num=*/0, /*remote_shape_node_id=*/0, + tensor->dtype(), std::move(destructor), + recv_device, recv_device, ctx); + (*result)->SetRemoteShape(MakeUnique<TensorShape>(tensor->shape())); + + return Status::OK(); +} + Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, int* num_retvals) { #ifdef __ANDROID__ @@ -595,10 +679,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, TF_RETURN_IF_ERROR( ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id)); - eager::EnqueueRequest request; + std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest); eager::EnqueueResponse response; - auto* remote_op = request.add_queue()->mutable_operation(); + request->set_context_id(context_id); + + auto* remote_op = request->add_queue()->mutable_operation(); for (int i = 0; i < op->Inputs().size(); i++) { tensorflow::Device* input_device; @@ -628,8 +714,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, op->Attrs().FillAttrValueMap(remote_op->mutable_attrs()); remote_op->set_device(op->Device()->name()); - request.set_context_id(context_id); - DataTypeVector output_dtypes; TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes)); @@ -651,32 +735,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, for (int i = 0; i < *num_retvals; i++) { // TODO(nareshmodi): Change the callback to instead add the decref to a list // of pending decrefs that we can send as a batch with the next execute. - std::function<void()> callback = [ctx, eager_client, context_id, id, i]() { - eager::EnqueueRequest request; - request.set_context_id(context_id); - - auto* handle_to_decref = request.add_queue()->mutable_handle_to_decref(); - handle_to_decref->set_op_id(id); - handle_to_decref->set_output_num(i); - - if (ctx->Async()) { - tensorflow::uint64 id = ctx->NextId(); - auto* node = new eager::RemoteExecuteNode(id, request, eager_client); - ctx->ExecutorAdd(node); - } else { - Notification n; - eager::EnqueueResponse response; - eager_client->EnqueueAsync( - &request, &response, - [&n](const tensorflow::Status& s) { n.Notify(); }); - n.WaitForNotification(); - } - - return tensorflow::Status::OK(); - }; + std::function<void()> destructor = + GetRemoteTensorDestructor(ctx, eager_client, context_id, id, i); retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id, - output_dtypes[i], std::move(callback), + output_dtypes[i], std::move(destructor), op_device, op_device, op->EagerContext()); } @@ -690,7 +753,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, } // Unable to capture via std::move, so bind instead. 
auto* node = new eager::RemoteExecuteNode( - remote_node_id, request, eager_client, op->Inputs(), + remote_node_id, std::move(request), eager_client, op->Inputs(), std::bind( [](const gtl::InlinedVector<TensorHandle*, 2>& retvals, const Status& status, const eager::EnqueueResponse& response) { @@ -707,7 +770,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, } else { Notification n; Status status; - eager_client->EnqueueAsync(&request, &response, + eager_client->EnqueueAsync(request.get(), &response, [&n, &status](const Status& s) { status = s; n.Notify(); @@ -936,6 +999,8 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, if (sender_is_local && recver_is_local) { return LocalEagerCopyToDevice(h, ctx, recv_device, result); + } else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) { + return EagerRemoteSendTensor(ctx, h, recv_device, result); } else { string wire_id = GetUniqueWireID(); diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index b410ea175b..dae5d1983f 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -41,17 +41,22 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, out->device_ = device; out->kernel_.reset(k); out->flib_ = nullptr; + out->runner_ = nullptr; + out->default_runner_ = [](std::function<void()> f) { f(); }; return s; } // static Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, + std::function<void(std::function<void()>)>* runner, KernelAndDevice* out) { OpKernel* k = nullptr; Status s = flib->CreateKernel(ndef, &k); out->device_ = flib->device(); out->kernel_.reset(k); out->flib_ = flib; + out->runner_ = runner; + out->default_runner_ = [](std::function<void()> f) { f(); }; return s; } @@ -83,10 +88,11 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, if (stats != nullptr) { params.track_allocations = true; } - // TODO(apassos): use a thread pool. - std::function<void(std::function<void()>)> runner = - [](std::function<void()> f) { f(); }; - params.runner = &runner; + if (runner_ == nullptr) { + params.runner = &default_runner_; + } else { + params.runner = runner_; + } ScopedStepContainer step_container(0, [this](const string& name) { device_->resource_manager()->Cleanup(name).IgnoreError(); diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index c41a0972b1..c0b676b285 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -57,6 +57,7 @@ class KernelAndDevice { // the FunctionLibraryRuntime is pushed on to the caller (see locking in // c_api.cc). 
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, + std::function<void(std::function<void()>)>* runner, KernelAndDevice* out); // TODO(ashankar): Remove this static Status InitOp(Device* device, const NodeDef& ndef, @@ -88,6 +89,8 @@ class KernelAndDevice { checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; Rendezvous* rendez_; DataTypeVector output_dtypes_; + std::function<void(std::function<void()>)>* runner_; + std::function<void(std::function<void()>)> default_runner_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index b4349e1dee..6abe98f53c 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -107,8 +107,8 @@ void BM_KernelAndDeviceInit(int iters) { KernelAndDevice k(nullptr); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &k)); + TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), + nullptr, &k)); } } BENCHMARK(BM_KernelAndDeviceInit); @@ -128,8 +128,8 @@ void BM_KernelAndDeviceRun(int iters) { .BuildNodeDef()); TestEnv env; KernelAndDevice kernel(nullptr); - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); + TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), + nullptr, &kernel)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr)); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index f9b9abcc99..85b0b79bce 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -109,6 +109,19 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor, return Status::OK(); } +Status TensorHandle::Shape(tensorflow::TensorShape* shape) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + CHECK(remote_shape_ != nullptr); + *shape = *(remote_shape_.get()); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *shape = tensor_.shape(); + } + return Status::OK(); +} + Status TensorHandle::NumDims(int* num_dims) { if (IsRemote()) { TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 46bc94f875..1bc9c6531a 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -109,6 +109,8 @@ class TensorHandle : public core::RefCounted { tensorflow::Device** device, tensorflow::Device** op_device); + Status Shape(tensorflow::TensorShape* shape); + Status NumDims(int* num_dims); Status Dim(int dim_index, int64* dim); @@ -138,6 +140,12 @@ class TensorHandle : public core::RefCounted { remote_shape_ = std::move(remote_shape); } + bool OnHostCPU() { + mutex_lock ml(ctx_mutex_); + return device_ == nullptr || + (ctx_ == nullptr || ctx_->HostCPU() == device_); + } + private: // If the contents of the Tensor pointed to by this handle are yet to be // computed by an EagerNode, this function will block until that computation is diff --git a/tensorflow/core/common_runtime/executor.cc
b/tensorflow/core/common_runtime/executor.cc index f7f2cdc14f..8096139d90 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1966,17 +1966,9 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, device_context = device_context_map_[node->id()]; } - // Experimental: debugger (tfdb) access to intermediate node completion. - if (item.num_outputs == 0 && impl_->params_.node_outputs_cb != nullptr) { - // If the node has no output, invoke the callback with output slot set to - // -1, signifying that this is a no-output node. - s.Update(impl_->params_.node_outputs_cb(item.node->name(), -1, nullptr, - false, ctx)); - } - for (int i = 0; i < item.num_outputs; ++i) { const TensorValue val = ctx->release_output(i); - if (*ctx->is_output_dead() || val.tensor == nullptr) { + if (val.tensor == nullptr) { // Unless it's a Switch or a Recv, the node must produce a // tensor value at i-th output. if (!IsSwitch(node) && !IsRecv(node)) { @@ -2018,13 +2010,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, LogMemory::RecordTensorOutput(ctx->op_kernel().name(), ctx->step_id(), i, to_log); } - - // Experimental: debugger (tfdb) access to intermediate node - // outputs. - if (impl_->params_.node_outputs_cb != nullptr) { - s.Update(impl_->params_.node_outputs_cb(item.node->name(), i, - out->ref, true, ctx)); - } } else { // NOTE that std::move is used here, so val.tensor goes to // uninitialized state (val.tensor->IsInitialized return false). @@ -2036,12 +2021,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, LogMemory::RecordTensorOutput(ctx->op_kernel().name(), ctx->step_id(), i, *out->val); } - - // Experimental: debugger access to intermediate node outputs. - if (impl_->params_.node_outputs_cb != nullptr) { - s.Update(impl_->params_.node_outputs_cb( - item.node->name(), i, out->val.get(), false, ctx)); - } } } else { s.Update(errors::Internal("Output ", i, " of type ", diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index e5d7b7c53c..cd01b43aea 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -103,7 +103,6 @@ class Executor { const Tensor* tensor, const bool is_ref, OpKernelContext* ctx)> NodeOutputsCallback; - NodeOutputsCallback node_outputs_cb = nullptr; }; typedef std::function<void(const Status&)> DoneCallback; virtual void RunAsync(const Args& args, DoneCallback done) = 0; @@ -139,8 +138,6 @@ struct LocalExecutorParams { // when the executor is deleted. std::function<Status(const NodeDef&, OpKernel**)> create_kernel; std::function<void(OpKernel*)> delete_kernel; - - Executor::Args::NodeOutputsCallback node_outputs_cb; }; ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, std::unique_ptr<const Graph> graph, diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index a93cfa2ec5..54bbe84b57 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -746,6 +746,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, rets_alloc_attrs.push_back(ret_alloc_attrs); } + bool allow_dead_tensors = opts.allow_dead_tensors; + // The ProcFLR sends the arguments to the function from the source_device to // the target_device. So here we receive those arguments. 
Similarly, when the // computation is done and stored in *rets, we send the return values back @@ -756,7 +758,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, device_context, args_alloc_attrs, rendezvous, remote_args, [frame, remote_args, item, source_device, target_device, target_incarnation, rendezvous, device_context, rets, done, exec_args, - rets_alloc_attrs](const Status& status) { + rets_alloc_attrs, allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { s = frame->SetArgs(*remote_args); @@ -769,13 +771,13 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, return; } item->exec->RunAsync( - *exec_args, - [frame, rets, done, source_device, target_device, - target_incarnation, rendezvous, device_context, remote_args, - exec_args, rets_alloc_attrs](const Status& status) { + *exec_args, [frame, rets, done, source_device, target_device, + target_incarnation, rendezvous, device_context, + remote_args, exec_args, rets_alloc_attrs, + allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { - s = frame->ConsumeRetvals(rets); + s = frame->ConsumeRetvals(rets, allow_dead_tensors); } delete frame; if (!s.ok()) { @@ -859,14 +861,15 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, return; } + bool allow_dead_tensors = opts.allow_dead_tensors; item->exec->RunAsync( // Executor args *exec_args, // Done callback. - [frame, rets, done, exec_args](const Status& status) { + [frame, rets, done, exec_args, allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { - s = frame->ConsumeRetvals(rets); + s = frame->ConsumeRetvals(rets, allow_dead_tensors); } delete frame; delete exec_args; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 3cb51b0dbc..3292ef2f62 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/common_runtime/visitable_allocator.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" @@ -224,6 +225,7 @@ class BaseGPUDevice::StreamGroupFactory { int num_d2d_streams = options.experimental().num_dev_to_dev_copy_streams(); + if (num_d2d_streams == 0) num_d2d_streams = 1; if (num_d2d_streams < 1 || num_d2d_streams > 4) { LOG(ERROR) << "Illegal GPUOptions.experimental.num_dev_to_dev_copy_streams=" @@ -856,7 +858,7 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context, static_cast<ConcretePerOpGpuDevice*>(device); DCHECK(concrete_device); const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>( - streams_[stream_id]->compute->implementation()->CudaStreamMemberHack()); + streams_[stream_id]->compute->implementation()->GpuStreamMemberHack()); concrete_device->Reinitialize(context, cuda_stream, tf_gpu_id_, allocator, scratch_[stream_id]); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc index 4898448476..3c1c31aa73 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc @@ -15,11 +15,80 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { +namespace { +// The EventMgr has 1 thread for the polling loop and one to execute +// event callback functions. Issues for reconsideration: +// - Is this the right number of threads? +// - Should EventMgrs be shared between GPUDevices on a multi-GPU machine? +static const int kNumThreads = 2; +} // namespace + +namespace gpu_event_mgr { +class ThreadLabel { + public: + static const char* GetValue() { return value_; } + + // v must be a static const because value_ will capture and use its value + // until reset or thread terminates. 
+ static void SetValue(const char* v) { value_ = v; } + + private: + static thread_local const char* value_; +}; +thread_local const char* ThreadLabel::value_ = ""; + +void WarnIfInCallback(std::function<void()> f) { + const char* label = ThreadLabel::GetValue(); + if (label && !strcmp(label, "gpu_event_mgr")) { + if (f) { + f(); + } else { + LOG(WARNING) << "Executing inside EventMgr callback thread: " + << CurrentStackTrace(); + } + } +} + +void InitThreadpoolLabels(thread::ThreadPool* threadpool) { + static const char* label = "gpu_event_mgr"; + mutex mu; + int init_count = 0; + condition_variable all_initialized; + int exit_count = 0; + condition_variable ready_to_exit; + const int num_threads = threadpool->NumThreads(); + for (int i = 0; i < num_threads; ++i) { + threadpool->Schedule([num_threads, &mu, &init_count, &all_initialized, + &exit_count, &ready_to_exit]() { + gpu_event_mgr::ThreadLabel::SetValue(label); + mutex_lock l(mu); + ++init_count; + if (init_count == num_threads) { + all_initialized.notify_all(); + } + while (init_count < num_threads) { + all_initialized.wait(l); + } + if (++exit_count == num_threads) { + ready_to_exit.notify_all(); + } + }); + } + { + mutex_lock l(mu); + while (exit_count < num_threads) { + ready_to_exit.wait(l); + } + } +} +} // namespace gpu_event_mgr + EventMgr::EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options) : exec_(se), deferred_bytes_threshold_(gpu_options.deferred_deletion_bytes() @@ -31,9 +100,8 @@ EventMgr::EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options) accumulated_stream_(nullptr), accumulated_tensors_(new TensorReferenceVector), accumulated_tensor_bytes_(0), - // threadpool_ has 1 thread for the polling loop, and one to execute - // event callback functions. Maybe we should have more? - threadpool_(Env::Default(), "GPU_Event_Manager", 2) { + threadpool_(Env::Default(), "GPU_Event_Manager", kNumThreads) { + gpu_event_mgr::InitThreadpoolLabels(&threadpool_); StartPollingLoop(); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h index b26f88a201..f0a109cc10 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h @@ -39,6 +39,25 @@ namespace tensorflow { class GPUOptions; +// The callback provided to EventMgr::ThenExecute must not block or take a long +// time. If it does, performance may be impacted and GPU memory may be +// exhausted. This macro is for checking that an EventMgr thread is not +// accidentally entering blocking parts of the code, e.g. the RPC subsystem. +// +// Intended use is something like +// +// void RespondToAnRPC(Params* params) { +// WARN_IF_IN_EVENT_MGR_THREAD; +// if (params->status.ok()) { ... +// +namespace gpu_event_mgr { +// Logs a stack trace if the current execution thread belongs to this EventMgr +// object. If f is not nullptr, executes f instead of logging the stack trace. +void WarnIfInCallback(std::function<void()> f); +} // namespace gpu_event_mgr +#define WARN_IF_IN_EVENT_MGR_THREAD gpu_event_mgr::WarnIfInCallback(nullptr) + // An object to keep track of pending Events in the StreamExecutor streams // and associated Tensors that cannot safely be deleted until the associated // Events are recorded. @@ -74,6 +93,9 @@ class EventMgr { FreeMemory(to_free); } + // Execute func when all pending stream actions have completed.
+ // func must be brief and non-blocking since it executes in the one + // thread used for all such callbacks and also buffer deletions. inline void ThenExecute(se::Stream* stream, std::function<void()> func) { ToFreeVector to_free; { diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc index c5ff6c97a1..d2adf699f5 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include <atomic> #include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -243,6 +244,28 @@ TEST(EventMgr, NonEmptyShutdown) { } } +// Tests that WarnIfInCallback() triggers correctly. +TEST(EventMgr, WarnIfInCallback) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec, GPUOptions()); + TEST_EventMgrHelper th(&em); + std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec)); + CHECK(stream); + stream->Init(); + bool hit = false; + gpu_event_mgr::WarnIfInCallback([&hit] { hit = true; }); + EXPECT_FALSE(hit); + Notification note; + em.ThenExecute(stream.get(), [&hit, ¬e]() { + gpu_event_mgr::WarnIfInCallback([&hit, ¬e] { + hit = true; + note.Notify(); + }); + }); + note.WaitForNotification(); + EXPECT_TRUE(hit); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index 1f0773d387..6781c87f6c 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/status_util.h" namespace tensorflow { @@ -822,10 +823,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef( - errors::InvalidArgument("Cannot assign a device for operation '", - node->name(), "': ", status.error_message()), - *node); + return AttachDef(errors::InvalidArgument( + "Cannot assign a device for operation ", + RichNodeName(node), ": ", status.error_message()), + *node); } // Returns the first device in sorted devices list so we will always @@ -869,10 +870,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef( - errors::InvalidArgument("Cannot assign a device for operation '", - node->name(), "': ", status.error_message()), - *node); + return AttachDef(errors::InvalidArgument( + "Cannot assign a device for operation ", + RichNodeName(node), ": ", status.error_message()), + *node); } int assigned_device = -1; @@ -938,4 +939,22 @@ void Placer::LogDeviceAssignment(const Node* node) const { } } +bool Placer::ClientHandlesErrorFormatting() const { + return options_ != nullptr && + options_->config.experimental().client_handles_error_formatting(); +} + +// Returns the node name in single quotes. If the client handles formatted +// errors, appends a formatting tag which the client will reformat into, for +// example, " (defined at filename:123)". 
+string Placer::RichNodeName(const Node* node) const { + string quoted_name = strings::StrCat("'", node->name(), "'"); + if (ClientHandlesErrorFormatting()) { + string file_and_line = error_format_tag(*node, "${file}:${line}"); + return strings::StrCat(quoted_name, " (defined at ", file_and_line, ")"); + } else { + return quoted_name; + } +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h index 75dce7c7fe..fce87269c5 100644 --- a/tensorflow/core/common_runtime/placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -87,6 +87,8 @@ class Placer { // placement if the SessionOptions entry in 'options_' requests it. void AssignAndLog(int assigned_device, Node* node) const; void LogDeviceAssignment(const Node* node) const; + bool ClientHandlesErrorFormatting() const; + string RichNodeName(const Node* node) const; Graph* const graph_; // Not owned. const DeviceSet* const devices_; // Not owned. diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 07a7724f16..cede899842 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -1142,6 +1142,50 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { EXPECT_TRUE(str_util::StrContains(s.error_message(), "/device:fakegpu:11")); } +// Test that the "Cannot assign a device" error message contains a format tag +// when requested. +TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestDevice", + b.opts().WithName("in").WithDevice("/device:fakegpu:11")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + options.config.mutable_experimental()->set_client_handles_error_formatting( + true); + Status s = Place(&g, &options); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE( + str_util::StrContains(s.error_message(), + "Cannot assign a device for operation 'in'" + " (defined at ^^node:in:${file}:${line}^^)")); +} + +// Test that the "Cannot assign a device" error message does not contain a +// format tag when it shouldn't. +TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestDevice", + b.opts().WithName("in").WithDevice("/device:fakegpu:11")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + options.config.mutable_experimental()->set_client_handles_error_formatting( + false); + Status s = Place(&g, &options); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "Cannot assign a device for operation 'in'")); + EXPECT_FALSE(str_util::StrContains( + s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)")); +} + // Test that placement fails when a node requests an explicit device that is not // supported by the registered kernels if allow_soft_placement is not set.
TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc index 4d83b25ce6..447338e7bd 100644 --- a/tensorflow/core/common_runtime/process_state.cc +++ b/tensorflow/core/common_runtime/process_state.cc @@ -71,7 +71,7 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { return MemDesc(); } -Allocator* ProcessState::GetCPUAllocator(int numa_node) { +VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) { CHECK_GE(numa_node, 0); if (!numa_enabled_) numa_node = 0; mutex_lock lock(mu_); diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h index 0f4ae230bb..2892677333 100644 --- a/tensorflow/core/common_runtime/process_state.h +++ b/tensorflow/core/common_runtime/process_state.h @@ -65,7 +65,7 @@ class ProcessState { // Returns the one CPUAllocator used for the given numa_node. // TEMPORARY: ignores numa_node. - Allocator* GetCPUAllocator(int numa_node); + VisitableAllocator* GetCPUAllocator(int numa_node); typedef std::unordered_map<const void*, MemDesc> MDMap; @@ -87,7 +87,7 @@ class ProcessState { mutex mu_; - std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_); + std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_); virtual ~ProcessState(); diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc index 4a9248171b..8c30beeec2 100644 --- a/tensorflow/core/common_runtime/session.cc +++ b/tensorflow/core/common_runtime/session.cc @@ -53,27 +53,33 @@ Status Session::PRun(const string& handle, Session* NewSession(const SessionOptions& options) { SessionFactory* factory; - const Status s = SessionFactory::GetFactory(options, &factory); + Status s = SessionFactory::GetFactory(options, &factory); if (!s.ok()) { LOG(ERROR) << s; return nullptr; } - return factory->NewSession(options); + Session* out_session; + s = NewSession(options, &out_session); + if (!s.ok()) { + LOG(ERROR) << "Failed to create session: " << s; + return nullptr; + } + return out_session; } Status NewSession(const SessionOptions& options, Session** out_session) { SessionFactory* factory; - const Status s = SessionFactory::GetFactory(options, &factory); + Status s = SessionFactory::GetFactory(options, &factory); if (!s.ok()) { *out_session = nullptr; LOG(ERROR) << s; return s; } - *out_session = factory->NewSession(options); - if (!*out_session) { - return errors::Internal("Failed to create session."); + s = factory->NewSession(options, out_session); + if (!s.ok()) { + *out_session = nullptr; } - return Status::OK(); + return s; } Status Reset(const SessionOptions& options, diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h index df3198a70d..81c172c6ae 100644 --- a/tensorflow/core/common_runtime/session_factory.h +++ b/tensorflow/core/common_runtime/session_factory.h @@ -30,7 +30,12 @@ struct SessionOptions; class SessionFactory { public: - virtual Session* NewSession(const SessionOptions& options) = 0; + // Creates a new session and stores it in *out_session, or fails with an error + // status if the Session could not be created. Caller takes ownership of + // *out_session if this returns Status::OK(). 
+ virtual Status NewSession(const SessionOptions& options, + Session** out_session) = 0; + virtual bool AcceptsOptions(const SessionOptions& options) = 0; // Abort and close all existing sessions, disconnecting their resources from diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc index feaf29c7bb..1fa5aad60c 100644 --- a/tensorflow/core/common_runtime/session_test.cc +++ b/tensorflow/core/common_runtime/session_test.cc @@ -47,8 +47,10 @@ class FakeSessionFactory : public SessionFactory { return str_util::StartsWith(options.target, "fake"); } - Session* NewSession(const SessionOptions& options) override { - return nullptr; + Status NewSession(const SessionOptions& options, + Session** out_session) override { + *out_session = nullptr; + return Status::OK(); } }; class FakeSessionRegistrar { diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 74a87215e1..7406ecf4f8 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -111,7 +111,21 @@ Status ThreadPoolDevice::MakeTensorFromProto( } #ifdef INTEL_MKL -REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocator); +namespace { +class MklCPUAllocatorFactory : public AllocatorFactory { + public: + bool NumaEnabled() override { return false; } + + Allocator* CreateAllocator() override { return new MklCPUAllocator; } + + // Note: Ignores numa_node, for now. + virtual SubAllocator* CreateSubAllocator(int numa_node) { + return new MklSubAllocator; + } +}; + +REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory); +} // namespace #endif } // namespace tensorflow
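The FreeAndMaybeCoalesce() rewrite in bfc_allocator.cc above folds both neighbor checks into early-exit branches: a free successor is merged into the freed chunk, then the result is merged into a free predecessor, and only the surviving chunk is reinserted into a bin. A minimal standalone sketch of that coalescing pattern, using hypothetical types rather than the TensorFlow Chunk/ChunkHandle machinery:

#include <cassert>
#include <cstddef>
#include <iostream>

struct Chunk {
  std::size_t size = 0;
  bool in_use = false;
  Chunk* prev = nullptr;  // physically adjacent predecessor
  Chunk* next = nullptr;  // physically adjacent successor
};

// Frees c and merges it with free physical neighbors; returns the chunk
// that the caller should reinsert into a free bin.
Chunk* FreeAndCoalesce(Chunk* c) {
  assert(c->in_use);
  c->in_use = false;
  // If the next chunk is free, merge it into c and delete it.
  if (c->next != nullptr && !c->next->in_use) {
    Chunk* n = c->next;
    c->size += n->size;
    c->next = n->next;
    if (n->next != nullptr) n->next->prev = c;
    delete n;  // a real allocator would recycle the chunk handle instead
  }
  // If the previous chunk is free, merge c into it and delete c.
  Chunk* coalesced = c;
  if (c->prev != nullptr && !c->prev->in_use) {
    Chunk* p = c->prev;
    p->size += c->size;
    p->next = c->next;
    if (c->next != nullptr) c->next->prev = p;
    delete c;
    coalesced = p;
  }
  return coalesced;
}

int main() {
  // Three physically adjacent chunks: free | in-use | free.
  Chunk* a = new Chunk{256, false};
  Chunk* b = new Chunk{128, true};
  Chunk* d = new Chunk{64, false};
  a->next = b; b->prev = a; b->next = d; d->prev = b;
  Chunk* merged = FreeAndCoalesce(b);
  std::cout << "merged size: " << merged->size << "\n";  // prints 448
  delete merged;
}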
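The DirectSession::RunInternal() change above routes executor closures either to an inter-op thread pool or, when run_options.inter_op_thread_pool() is -1 and a single executor is involved, inline on the caller's thread. A standalone sketch of that dispatch decision, with a toy ThreadPool standing in for tensorflow's (the real pool reuses worker threads rather than spawning one per closure):

#include <functional>
#include <iostream>
#include <thread>
#include <vector>

class ThreadPool {  // toy stand-in: one thread per scheduled closure
 public:
  void Schedule(std::function<void()> fn) {
    workers_.emplace_back(std::move(fn));
  }
  ~ThreadPool() {
    for (auto& t : workers_) t.join();
  }

 private:
  std::vector<std::thread> workers_;
};

// Mirrors the runner logic above: no pool means synchronous execution on
// the caller's thread, as with inter_op_thread_pool = -1.
void RunClosure(ThreadPool* pool, std::function<void()> c) {
  if (pool != nullptr) {
    pool->Schedule(std::move(c));  // asynchronous: some pool thread
  } else {
    c();  // synchronous: caller's thread
  }
}

int main() {
  std::cout << "caller thread: " << std::this_thread::get_id() << "\n";
  RunClosure(nullptr, [] {
    std::cout << "sync closure on:  " << std::this_thread::get_id() << "\n";
  });
  ThreadPool pool;
  RunClosure(&pool, [] {
    std::cout << "async closure on: " << std::this_thread::get_id() << "\n";
  });
}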
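EagerContext gates its SendTensor code path on the TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC environment variable through a small ReadBoolFromEnvVar wrapper. A sketch of the same feature-flag pattern, with std::getenv in place of tensorflow's helper (the accepted spellings below are an assumption, not the TF parsing rules):

#include <cstdlib>
#include <cstring>
#include <iostream>

// Returns default_val when the variable is unset or unparseable, so a bad
// value can never turn the flag on by accident.
bool ReadBoolFromEnv(const char* name, bool default_val) {
  const char* v = std::getenv(name);
  if (v == nullptr) return default_val;
  if (!std::strcmp(v, "1") || !std::strcmp(v, "true")) return true;
  if (!std::strcmp(v, "0") || !std::strcmp(v, "false")) return false;
  return default_val;
}

int main() {
  const bool use_send_tensor_rpc =
      ReadBoolFromEnv("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);
  std::cout << (use_send_tensor_rpc ? "SendTensor RPC" : "_Send/_Recv ops")
            << "\n";
}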
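gpu_event_mgr labels its pool threads through a static thread_local pointer so that WarnIfInCallback() can later detect execution on an EventMgr callback thread. A standalone sketch of that labeling technique, with hypothetical names and a plain warning in place of the stack-trace logging:

#include <cstring>
#include <iostream>

class ThreadLabel {
 public:
  static const char* Get() { return value_; }
  // v must point to storage that outlives the thread (e.g. a string literal),
  // since only the pointer is stored.
  static void Set(const char* v) { value_ = v; }

 private:
  static thread_local const char* value_;
};
thread_local const char* ThreadLabel::value_ = "";

// Warns if the calling thread carries the given label.
void WarnIfLabeled(const char* label) {
  if (std::strcmp(ThreadLabel::Get(), label) == 0) {
    std::cerr << "warning: running on a '" << label << "' thread\n";
  }
}

int main() {
  WarnIfLabeled("gpu_event_mgr");  // no warning: main thread is unlabeled
  ThreadLabel::Set("gpu_event_mgr");
  WarnIfLabeled("gpu_event_mgr");  // warns
}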
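The Placer::RichNodeName() change above embeds a tag such as ^^node:in:${file}:${line}^^ in the error message (the exact shape is visible in the placer_test expectations) for a client that knows each node's definition site to rewrite. A hypothetical sketch of such a client-side rewrite; the tag grammar is inferred from the tests and the lookup table is a stand-in for real graph metadata:

#include <iostream>
#include <map>
#include <string>

// Replaces every "^^node:<name>:<format>^^" tag with the node's recorded
// definition site, or with just the name when no site is known.
std::string ExpandNodeTags(
    std::string msg, const std::map<std::string, std::string>& def_sites) {
  std::size_t begin;
  while ((begin = msg.find("^^node:")) != std::string::npos) {
    std::size_t end = msg.find("^^", begin + 2);
    if (end == std::string::npos) break;
    // Payload looks like "node:<name>:<format>"; only <name> is used here.
    std::string payload = msg.substr(begin + 2, end - begin - 2);
    std::string name = payload.substr(5, payload.find(':', 5) - 5);
    auto it = def_sites.find(name);
    std::string replacement = it == def_sites.end() ? name : it->second;
    msg.replace(begin, end + 2 - begin, replacement);
  }
  return msg;
}

int main() {
  std::map<std::string, std::string> sites = {{"in", "model.py:42"}};
  std::cout << ExpandNodeTags(
                   "Cannot assign a device for operation 'in' "
                   "(defined at ^^node:in:${file}:${line}^^)",
                   sites)
            << "\n";
}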
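Finally, the SessionFactory::NewSession() signature change threads the failure Status out to the caller instead of logging and returning nullptr, as the reworked session.cc shows. A compilable sketch of the Status-plus-out-parameter convention, with a minimal stand-in Status type rather than tensorflow::Status:

#include <iostream>
#include <string>

struct Status {
  std::string msg;  // empty means success
  bool ok() const { return msg.empty(); }
  static Status OK() { return Status{}; }
};

struct Session {};

// On success, *out_session is owned by the caller; on failure, the error
// reaches the caller instead of being flattened into a nullptr.
Status NewSession(bool fail, Session** out_session) {
  *out_session = nullptr;
  if (fail) return Status{"resource exhausted"};
  *out_session = new Session;
  return Status::OK();
}

int main() {
  Session* s = nullptr;
  Status st = NewSession(/*fail=*/false, &s);
  if (!st.ok()) {
    std::cerr << "Failed to create session: " << st.msg << "\n";
    return 1;
  }
  delete s;
}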