Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--  tensorflow/core/common_runtime/bfc_allocator.cc | 52
-rw-r--r--  tensorflow/core/common_runtime/bfc_allocator.h | 37
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 45
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc | 139
-rw-r--r--  tensorflow/core/common_runtime/eager/context.cc | 26
-rw-r--r--  tensorflow/core/common_runtime/eager/context.h | 13
-rw-r--r--  tensorflow/core/common_runtime/eager/execute.cc | 133
-rw-r--r--  tensorflow/core/common_runtime/eager/kernel_and_device.cc | 14
-rw-r--r--  tensorflow/core/common_runtime/eager/kernel_and_device.h | 3
-rw-r--r--  tensorflow/core/common_runtime/eager/kernel_and_device_test.cc | 8
-rw-r--r--  tensorflow/core/common_runtime/eager/tensor_handle.cc | 13
-rw-r--r--  tensorflow/core/common_runtime/eager/tensor_handle.h | 8
-rw-r--r--  tensorflow/core/common_runtime/executor.cc | 23
-rw-r--r--  tensorflow/core/common_runtime/executor.h | 3
-rw-r--r--  tensorflow/core/common_runtime/function.cc | 19
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.cc | 4
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc | 74
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.h | 22
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc | 23
-rw-r--r--  tensorflow/core/common_runtime/placer.cc | 35
-rw-r--r--  tensorflow/core/common_runtime/placer.h | 2
-rw-r--r--  tensorflow/core/common_runtime/placer_test.cc | 44
-rw-r--r--  tensorflow/core/common_runtime/process_state.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/process_state.h | 4
-rw-r--r--  tensorflow/core/common_runtime/session.cc | 20
-rw-r--r--  tensorflow/core/common_runtime/session_factory.h | 7
-rw-r--r--  tensorflow/core/common_runtime/session_test.cc | 6
-rw-r--r--  tensorflow/core/common_runtime/threadpool_device.cc | 16
28 files changed, 622 insertions(+), 173 deletions(-)
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 9cda17867b..3bf0532491 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -155,10 +155,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
region_manager_.set_handle(c->ptr, h);
- // TODO(vrv): Try to merge this new region with an existing region,
- // if the address space is contiguous, to avoid fragmentation
- // across regions.
-
// Insert the chunk into the right bin.
InsertFreeChunkIntoBin(h);
@@ -465,49 +461,33 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
Chunk* c = ChunkFromHandle(h);
CHECK(c->in_use() && (c->bin_num == kInvalidBinNum));
- // Mark the chunk as no longer in use
+ // Mark the chunk as no longer in use.
c->allocation_id = -1;
// Updates the stats.
stats_.bytes_in_use -= c->size;
- // This chunk is no longer in-use, consider coalescing the chunk
- // with adjacent chunks.
- ChunkHandle chunk_to_reassign = h;
-
- // If the next chunk is free, coalesce the two
- if (c->next != kInvalidChunkHandle) {
- Chunk* cnext = ChunkFromHandle(c->next);
- if (!cnext->in_use()) {
- // VLOG(8) << "Chunk at " << cnext->ptr << " merging with c " <<
- // c->ptr;
-
- chunk_to_reassign = h;
+ ChunkHandle coalesced_chunk = h;
- // Deletes c->next
- RemoveFreeChunkFromBin(c->next);
- Merge(h, ChunkFromHandle(h)->next);
- }
+ // If the next chunk is free, merge it into c and delete it.
+ if (c->next != kInvalidChunkHandle && !ChunkFromHandle(c->next)->in_use()) {
+ // VLOG(8) << "Merging c->next " << ChunkFromHandle(c->next)->ptr
+ // << " with c " << c->ptr;
+ RemoveFreeChunkFromBin(c->next);
+ Merge(h, c->next);
}
- // If the previous chunk is free, coalesce the two
- c = ChunkFromHandle(h);
- if (c->prev != kInvalidChunkHandle) {
- Chunk* cprev = ChunkFromHandle(c->prev);
- if (!cprev->in_use()) {
- // VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev "
- // << cprev->ptr;
-
- chunk_to_reassign = c->prev;
+ // If the previous chunk is free, merge c into it and delete c.
+ if (c->prev != kInvalidChunkHandle && !ChunkFromHandle(c->prev)->in_use()) {
+ // VLOG(8) << "Merging c " << c->ptr << " into c->prev "
+ // << ChunkFromHandle(c->prev)->ptr;
- // Deletes c
- RemoveFreeChunkFromBin(c->prev);
- Merge(ChunkFromHandle(h)->prev, h);
- c = ChunkFromHandle(h);
- }
+ coalesced_chunk = c->prev;
+ RemoveFreeChunkFromBin(c->prev);
+ Merge(c->prev, h);
}
- InsertFreeChunkIntoBin(chunk_to_reassign);
+ InsertFreeChunkIntoBin(coalesced_chunk);
}
void BFCAllocator::AddAllocVisitor(Visitor visitor) {
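For readers skimming the hunk above, here is a standalone sketch of the restructured FreeAndMaybeCoalesce logic: absorb a free successor into the freed chunk, then let a free predecessor absorb it. The Chunk struct and helper functions below are illustrative only, not the real BFCAllocator/ChunkHandle API.

#include <cstddef>

// Illustrative chunk in a doubly-linked list ordered by base address.
struct Chunk {
  std::size_t size = 0;
  bool in_use = false;
  Chunk* prev = nullptr;  // physically preceding chunk
  Chunk* next = nullptr;  // physically following chunk
};

// Merges b into a (requires a->next == b); returns the surviving chunk.
Chunk* Merge(Chunk* a, Chunk* b) {
  a->size += b->size;
  a->next = b->next;
  if (b->next != nullptr) b->next->prev = a;
  delete b;
  return a;
}

// Mirrors the simplified FreeAndMaybeCoalesce: after marking c free, coalesce
// with a free next chunk, then with a free previous chunk, and return the
// chunk that should be reinserted into a free bin.
Chunk* FreeAndCoalesce(Chunk* c) {
  c->in_use = false;
  if (c->next != nullptr && !c->next->in_use) c = Merge(c, c->next);
  if (c->prev != nullptr && !c->prev->in_use) c = Merge(c->prev, c);
  return c;
}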
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 52aedb1e9c..580e61e2ea 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -88,11 +88,20 @@ class BFCAllocator : public VisitableAllocator {
static const int kInvalidBinNum = -1;
static const int kNumBins = 21;
- // Chunks point to memory. Their prev/next pointers form a
- // doubly-linked list of addresses sorted by base address that
- // must be contiguous. Chunks contain information about whether
- // they are in use or whether they are free, and contain a pointer
- // to the bin they are in.
+ // A Chunk points to a piece of memory that's either entirely free or entirely
+ // in use by one user memory allocation.
+ //
+ // An AllocationRegion's memory is split up into one or more disjoint Chunks,
+ // which together cover the whole region without gaps. Chunks participate in
+ // a doubly-linked list, and the prev/next pointers point to the physically
+ // adjacent chunks.
+ //
+ // Since a chunk cannot be partially in use, we may need to split a free chunk
+ // in order to service a user allocation. We always merge adjacent free
+ // chunks.
+ //
+ // Chunks contain information about whether they are in use or whether they
+ // are free, and contain a pointer to the bin they are in.
struct Chunk {
size_t size = 0; // Full size of buffer.
@@ -177,8 +186,12 @@ class BFCAllocator : public VisitableAllocator {
static const size_t kMinAllocationBits = 8;
static const size_t kMinAllocationSize = 1 << kMinAllocationBits;
- // AllocationRegion maps pointers to ChunkHandles for a single
- // contiguous memory region.
+ // BFCAllocator allocates memory into a collection of disjoint
+ // AllocationRegions. Each AllocationRegion corresponds to one call to
+ // SubAllocator::Alloc().
+ //
+ // An AllocationRegion contains one or more Chunks, covering all of its
+ // memory. Its primary job is to map pointers to ChunkHandles.
//
// This class is thread-compatible.
class AllocationRegion {
@@ -191,18 +204,14 @@ class BFCAllocator : public VisitableAllocator {
DCHECK_EQ(0, memory_size % kMinAllocationSize);
const size_t n_handles =
(memory_size + kMinAllocationSize - 1) / kMinAllocationSize;
- handles_ = new ChunkHandle[n_handles];
+ handles_.reset(new ChunkHandle[n_handles]);
for (size_t i = 0; i < n_handles; i++) {
handles_[i] = kInvalidChunkHandle;
}
}
- AllocationRegion() {}
-
- ~AllocationRegion() { delete[] handles_; }
-
+ AllocationRegion() = default;
AllocationRegion(AllocationRegion&& other) { Swap(other); }
-
AllocationRegion& operator=(AllocationRegion&& other) {
Swap(other);
return *this;
@@ -241,7 +250,7 @@ class BFCAllocator : public VisitableAllocator {
// Array of size "memory_size / kMinAllocationSize". It is
// indexed by (p-base) / kMinAllocationSize, contains ChunkHandle
// for the memory allocation represented by "p"
- ChunkHandle* handles_ = nullptr;
+ std::unique_ptr<ChunkHandle[]> handles_;
TF_DISALLOW_COPY_AND_ASSIGN(AllocationRegion);
};
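The comment in the hunk above describes how a region maps a pointer back to the chunk that covers it. A minimal sketch of that index computation, assuming the 256-byte kMinAllocationSize implied by kMinAllocationBits = 8; the function name is illustrative, not an AllocationRegion member.

#include <cstddef>
#include <cstdint>

constexpr std::size_t kMinAllocationSize = 256;  // 1 << kMinAllocationBits (8)

// Index into the handles_ array for a pointer p inside a region starting at
// base: (p - base) / kMinAllocationSize, as described in the comment above.
std::size_t HandleIndexFor(const void* base, const void* p) {
  return (reinterpret_cast<std::uintptr_t>(p) -
          reinterpret_cast<std::uintptr_t>(base)) /
         kMinAllocationSize;
}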
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index f903faf1bd..d1fd930d25 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -146,18 +146,15 @@ class DirectSessionFactory : public SessionFactory {
return options.target.empty();
}
- Session* NewSession(const SessionOptions& options) override {
+ Status NewSession(const SessionOptions& options,
+ Session** out_session) override {
// Must do this before the CPU allocator is created.
if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true);
}
std::vector<Device*> devices;
- const Status s = DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices);
- if (!s.ok()) {
- LOG(ERROR) << s;
- return nullptr;
- }
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this);
@@ -165,7 +162,8 @@ class DirectSessionFactory : public SessionFactory {
mutex_lock l(sessions_lock_);
sessions_.push_back(session);
}
- return session;
+ *out_session = session;
+ return Status::OK();
}
Status Reset(const SessionOptions& options,
@@ -237,7 +235,11 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
// safe given the reasoning above.
c();
#else
- pool->Schedule(std::move(c));
+ if (pool != nullptr) {
+ pool->Schedule(std::move(c));
+ } else {
+ c();
+ }
#endif // __ANDROID__
}
@@ -524,8 +526,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- if (run_options.inter_op_thread_pool() < 0 ||
- run_options.inter_op_thread_pool() >= thread_pools_.size()) {
+ if (run_options.inter_op_thread_pool() < -1 ||
+ run_options.inter_op_thread_pool() >=
+ static_cast<int32>(thread_pools_.size())) {
run_state.executors_done.Notify();
delete barrier;
return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
@@ -550,7 +553,19 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
thread::ThreadPool* pool =
- thread_pools_[run_options.inter_op_thread_pool()].first;
+ run_options.inter_op_thread_pool() >= 0
+ ? thread_pools_[run_options.inter_op_thread_pool()].first
+ : nullptr;
+
+ if (pool == nullptr) {
+ // We allow using the caller thread only when having a single executor
+ // specified.
+ if (executors_and_keys->items.size() > 1) {
+ pool = thread_pools_[0].first;
+ } else {
+ VLOG(1) << "Executing Session::Run() synchronously!";
+ }
+ }
Executor::Args::Runner default_runner = [this,
pool](Executor::Args::Closure c) {
@@ -702,7 +717,8 @@ Status DirectSession::Run(const RunOptions& run_options,
// Receive outputs.
if (outputs) {
std::vector<Tensor> sorted_outputs;
- const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
+ const Status s = call_frame.ConsumeRetvals(
+ &sorted_outputs, /* allow_dead_tensors = */ false);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
@@ -1188,12 +1204,11 @@ Status DirectSession::CreateExecutors(
delete kernel;
}
};
- params.node_outputs_cb = node_outputs_callback_;
optimizer.Optimize(lib, options_.env, device, &iter->second,
/*shape_map=*/nullptr);
- // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
+ // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
const DebugOptions& debug_options =
options.callable_options.run_options().debug_options();
if (!debug_options.debug_tensor_watch_opts().empty()) {
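A hedged usage sketch of the new synchronous path (the helper name is illustrative): passing -1 as the inter-op thread pool makes DirectSession run the step on the caller's thread, provided a single executor is involved; otherwise it falls back to pool 0, as coded above.

#include <string>
#include <vector>
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"

// Runs one fetch with inter_op_thread_pool = -1, i.e. on the calling thread.
tensorflow::Status RunSynchronously(tensorflow::Session* session,
                                    const std::string& fetch,
                                    std::vector<tensorflow::Tensor>* outputs) {
  tensorflow::RunOptions run_options;
  run_options.set_inter_op_thread_pool(-1);  // -1 => use the caller thread
  return session->Run(run_options, /*inputs=*/{}, {fetch},
                      /*target_node_names=*/{}, outputs,
                      /*run_metadata=*/nullptr);
}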
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 142d613129..4b51b20bb1 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <string>
+#include <thread>
#include <unordered_map>
#include <vector>
@@ -896,6 +897,125 @@ TEST(DirectSessionTest, FetchMultipleTimes) {
}
}
+TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+ RunOptions run_options;
+ run_options.set_inter_op_thread_pool(-1);
+
+ Tensor first_value(DT_FLOAT, TensorShape({}));
+ first_value.scalar<float>()() = 1.0;
+ Node* first_const = test::graph::Constant(&g, first_value);
+ Node* first_identity = test::graph::Identity(&g, first_const);
+
+ Tensor second_value(DT_FLOAT, TensorShape({}));
+ second_value.scalar<float>()() = 2.0;
+ Node* second_const = test::graph::Constant(&g, second_value);
+ Node* second_identity = test::graph::Identity(&g, second_const);
+
+ test::graph::ToGraphDef(&g, &def);
+
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def));
+
+ std::vector<Tensor> outputs;
+
+ // Fetch without feeding.
+ Status s = session->Run(
+ run_options, {},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs, nullptr);
+ TF_ASSERT_OK(s);
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(2.0, outputs[1].flat<float>()(0));
+
+ s = session->Run(
+ {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
+ &outputs);
+ TF_ASSERT_OK(s);
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
+
+ Tensor value_11(DT_FLOAT, TensorShape({}));
+ value_11.scalar<float>()() = 11.0;
+ Tensor value_22(DT_FLOAT, TensorShape({}));
+ value_22.scalar<float>()() = 22.0;
+
+ // Feed [first_const, second_const]
+ s = session->Run(
+ {{first_const->name(), value_11}, {second_const->name(), value_22}},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ TF_ASSERT_OK(s);
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
+
+ // Feed [second_const, first_const]
+ s = session->Run(
+ {{second_const->name(), value_22}, {first_const->name(), value_11}},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ TF_ASSERT_OK(s);
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
+
+ // Feed [first_const, first_const]
+ s = session->Run(
+ run_options,
+ {{first_const->name(), value_11}, {first_const->name(), value_22}},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs, nullptr);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
+}
+
+REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc(
+ThreadID returns the thread ID that called compute.
+
+x: int64
+y: int64
+)doc");
+
+// The ThreadID kernel returns the thread ID that executed Compute.
+class ThreadIDOp : public OpKernel {
+ public:
+ explicit ThreadIDOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* out_tensor = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output("y", TensorShape({}), &out_tensor));
+ std::hash<std::thread::id> hasher;
+ out_tensor->scalar<int64>()() =
+ static_cast<int64>(hasher(std::this_thread::get_id()));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("ThreadID").Device(DEVICE_CPU), ThreadIDOp);
+
+TEST(DirectSessionTest, SessionSyncRun) {
+ Graph g(OpRegistry::Global());
+ Tensor vx(DT_INT64, TensorShape({}));
+ vx.scalar<int64>()() = 17;
+ Node* x = test::graph::Constant(&g, vx);
+ Node* y = test::graph::Unary(&g, "ThreadID", x);
+ GraphDef def;
+ test::graph::ToGraphDef(&g, &def);
+ auto sess = CreateSession();
+ TF_ASSERT_OK(sess->Create(def));
+ std::vector<Tensor> outputs;
+ RunOptions run_opts;
+ run_opts.set_inter_op_thread_pool(-1);
+ auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
+
+ std::hash<std::thread::id> hasher;
+ EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())),
+ static_cast<int64>(outputs[0].scalar<int64>()()));
+}
+
REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
Darth promises one return value.
@@ -1400,6 +1520,7 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
p = options.config.add_session_inter_op_thread_pool();
if (use_global_pools) p->set_global_name("small pool");
p->set_num_threads(1);
+ const int kSyncPool = -1;
const int kLargePool = 0;
const int kSmallPool = 1;
@@ -1442,7 +1563,11 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
EXPECT_FLOAT_EQ(1.2, flat(0));
num_done.fetch_add(1);
};
- tp->Schedule(fn);
+ if (tp != nullptr) {
+ tp->Schedule(fn);
+ } else {
+ fn();
+ }
};
// For blocking states:
@@ -1463,9 +1588,10 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
- // Launch 2 session run calls. Neither will finish until the blocking op is
+ // Launch a session run call. It will not finish until the blocking op is
// unblocked, because it is using all threads in the small pool.
add_session_run_call(tp1, y, kSmallPool);
+
blocking_op_state->AwaitState(1); // Wait for the blocking op to Compute.
// These will block on <BlockingOpState>.
@@ -1484,10 +1610,15 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
delete tp2;
EXPECT_EQ(kUnblockedThreads, num_done.load());
+ // Launch a session call using this thread. This will finish as it runs
+ // synchronously in this thread.
+ add_session_run_call(nullptr, x, kSyncPool);
+
// Unblock the blocked op and wait for the blocked functions to finish.
blocking_op_state->MoveToState(1, 2);
delete tp1;
- EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1, num_done.load());
+
+ EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1 + 1, num_done.load());
delete blocking_op_state;
blocking_op_state = nullptr;
}
@@ -1532,7 +1663,7 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
{
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
- for (int pool_num = -1; pool_num <= 1; pool_num += 2) {
+ for (int pool_num = -2; pool_num <= 1; pool_num += 3) {
RunOptions run_options;
run_options.set_inter_op_thread_pool(pool_num);
std::vector<Tensor> outputs;
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 70208fb6d1..5e0f0a45f8 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -17,8 +17,20 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
+namespace {
+
+bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
+ bool val;
+ if (ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
+ return val;
+ }
+ return default_val;
+}
+
+} // namespace
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
@@ -34,8 +46,16 @@ EagerContext::EagerContext(const SessionOptions& opts,
local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
- async_default_(async) {
+ async_default_(async),
+ use_send_tensor_rpc_(false) {
InitDeviceMapAndAsync();
+ if (opts.config.inter_op_parallelism_threads() > 0) {
+ runner_ = [this](std::function<void()> closure) {
+ this->thread_pool_->Schedule(closure);
+ };
+ } else {
+ runner_ = [](std::function<void()> closure) { closure(); };
+ }
}
#ifndef __ANDROID__
@@ -59,7 +79,9 @@ EagerContext::EagerContext(
remote_device_manager_(std::move(remote_device_manager)),
server_(std::move(server)),
remote_eager_workers_(std::move(remote_eager_workers)),
- remote_contexts_(remote_contexts) {
+ remote_contexts_(remote_contexts),
+ use_send_tensor_rpc_(
+ ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false)) {
InitDeviceMapAndAsync();
}
#endif
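The constructor change above picks between scheduling eager kernel closures on the context's thread pool and running them inline. Below is a standalone sketch of that runner pattern; the factory function is illustrative and not part of EagerContext.

#include <functional>
#include <utility>
#include "tensorflow/core/lib/core/threadpool.h"

// Returns a runner that schedules closures on `pool`, or runs them inline
// when no pool is supplied, mirroring the selection in the constructor above.
std::function<void(std::function<void()>)> MakeRunner(
    tensorflow::thread::ThreadPool* pool) {
  if (pool != nullptr) {
    return [pool](std::function<void()> closure) {
      pool->Schedule(std::move(closure));
    };
  }
  return [](std::function<void()> closure) { closure(); };
}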
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 864f514a19..4a180e074d 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -105,6 +105,8 @@ class EagerContext {
EagerExecutor* Executor() { return &executor_; }
+ std::function<void(std::function<void()>)>* runner() { return &runner_; }
+
// Sets whether this thread should run in synchronous or asynchronous mode.
Status SetAsyncForThread(bool async);
@@ -180,6 +182,11 @@ class EagerContext {
#ifndef __ANDROID__
Status GetClientAndContextID(Device* device, eager::EagerClient** client,
uint64* context_id);
+
+ // If true, then tensors should be shipped across processes via the
+ // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
+ // instead (which in turn use WorkerService.RecvTensor RPCs).
+ bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
#endif
private:
void InitDeviceMapAndAsync();
@@ -214,6 +221,8 @@ class EagerContext {
// session->devices[i].
const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ std::function<void(std::function<void()>)> runner_;
+
mutex cache_mu_;
std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_
GUARDED_BY(cache_mu_);
@@ -235,16 +244,18 @@ class EagerContext {
const std::unique_ptr<DeviceMgr> remote_device_manager_;
+#ifndef __ANDROID__
// The server_ is not const since we release it when the context is destroyed.
// Therefore the server_ object is not marked as const (even though it should
// be).
-#ifndef __ANDROID__
std::unique_ptr<ServerInterface> server_;
const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
const gtl::FlatMap<string, uint64> remote_contexts_;
gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
device_to_client_cache_;
+
+ const bool use_send_tensor_rpc_;
#endif
};
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 7a2b477845..7ea78b63d9 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -88,6 +88,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
TF_RETURN_IF_ERROR((*handle)->Device(&handle_device));
const Device* actual_device =
handle_device == nullptr ? ctx->HostCPU() : handle_device;
+ const Device* op_device =
+ op->Device() == nullptr ? ctx->HostCPU() : op->Device();
if (expected_device != actual_device) {
switch (ctx->GetDevicePlacementPolicy()) {
@@ -106,8 +108,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
" cannot compute ",
op->Name(), " as input #", i, " was expected to be on ",
expected_device->name(), " but is actually on ",
- actual_device->name(), " (operation running on ",
- op->Device()->name(), ")",
+ actual_device->name(), " (operation running on ", op_device->name(),
+ ")",
" Tensors can be copied explicitly using .gpu() or .cpu() "
"methods,"
" or transparently copied by using tf.enable_eager_execution("
@@ -118,7 +120,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
- << " (operation running on " << op->Device()->name()
+ << " (operation running on " << op_device->name()
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
@@ -512,7 +514,8 @@ Status EagerLocalExecute(EagerOperation* op,
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
tf_shared_lock l(*ctx->FunctionsMu());
- status = KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
+ status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(),
+ kernel);
if (!status.ok()) {
delete kernel;
return status;
@@ -582,6 +585,87 @@ Status EagerLocalExecute(EagerOperation* op,
return status;
}
+std::function<void()> GetRemoteTensorDestructor(
+ EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
+ uint64 op_id, int output_num) {
+ return [ctx, eager_client, context_id, op_id, output_num]() {
+ std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
+ request->set_context_id(context_id);
+
+ auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
+ handle_to_decref->set_op_id(op_id);
+ handle_to_decref->set_output_num(output_num);
+
+ if (ctx->Async()) {
+ tensorflow::uint64 id = ctx->NextId();
+ auto* node =
+ new eager::RemoteExecuteNode(id, std::move(request), eager_client);
+ ctx->ExecutorAdd(node);
+ } else {
+ eager::EnqueueRequest* actual_request = request.release();
+ eager::EnqueueResponse* response = new eager::EnqueueResponse;
+ eager_client->EnqueueAsync(
+ actual_request, response,
+ [actual_request, response](const tensorflow::Status& s) {
+ delete actual_request;
+ delete response;
+ });
+ }
+
+ return tensorflow::Status::OK();
+ };
+}
+
+// When !ctx->UseSendTensorRPC(), then tensors are shipped between remote
+// devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
+// sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel).
+//
+// However, in some configurations the node that has the tensor to be copied
+// isn't running a server (WorkerService RPC interface). For such cases,
+// this function enables sending tensors using the EagerService.SendTensor RPC
+// *on the receiver*.
+Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
+ Device* recv_device, TensorHandle** result) {
+ eager::EagerClient* eager_client;
+ uint64 context_id;
+ TF_RETURN_IF_ERROR(
+ ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));
+
+ eager::SendTensorRequest request;
+ eager::SendTensorResponse response;
+
+ request.set_context_id(context_id);
+ request.set_op_id(ctx->NextId());
+ request.set_device_name(recv_device->name());
+
+ const Tensor* tensor;
+ TF_RETURN_IF_ERROR(h->Tensor(&tensor));
+ tensor->AsProtoTensorContent(request.add_tensors());
+
+ const tensorflow::uint64 id = request.op_id();
+
+ // TODO(nareshmodi): support making this call async.
+ Notification n;
+ Status status;
+ eager_client->SendTensorAsync(&request, &response,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ if (!status.ok()) return status;
+
+ std::function<void()> destructor =
+ GetRemoteTensorDestructor(ctx, eager_client, context_id, id, 0);
+
+ *result = new TensorHandle(id, /*output_num=*/0, /*remote_shape_node_id=*/0,
+ tensor->dtype(), std::move(destructor),
+ recv_device, recv_device, ctx);
+ (*result)->SetRemoteShape(MakeUnique<TensorShape>(tensor->shape()));
+
+ return Status::OK();
+}
+
Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
int* num_retvals) {
#ifdef __ANDROID__
@@ -595,10 +679,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
TF_RETURN_IF_ERROR(
ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
- eager::EnqueueRequest request;
+ std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
eager::EnqueueResponse response;
- auto* remote_op = request.add_queue()->mutable_operation();
+ request->set_context_id(context_id);
+
+ auto* remote_op = request->add_queue()->mutable_operation();
for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::Device* input_device;
@@ -628,8 +714,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
remote_op->set_device(op->Device()->name());
- request.set_context_id(context_id);
-
DataTypeVector output_dtypes;
TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
@@ -651,32 +735,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
for (int i = 0; i < *num_retvals; i++) {
// TODO(nareshmodi): Change the callback to instead add the decref to a list
// of pending decrefs that we can send as a batch with the next execute.
- std::function<void()> callback = [ctx, eager_client, context_id, id, i]() {
- eager::EnqueueRequest request;
- request.set_context_id(context_id);
-
- auto* handle_to_decref = request.add_queue()->mutable_handle_to_decref();
- handle_to_decref->set_op_id(id);
- handle_to_decref->set_output_num(i);
-
- if (ctx->Async()) {
- tensorflow::uint64 id = ctx->NextId();
- auto* node = new eager::RemoteExecuteNode(id, request, eager_client);
- ctx->ExecutorAdd(node);
- } else {
- Notification n;
- eager::EnqueueResponse response;
- eager_client->EnqueueAsync(
- &request, &response,
- [&n](const tensorflow::Status& s) { n.Notify(); });
- n.WaitForNotification();
- }
-
- return tensorflow::Status::OK();
- };
+ std::function<void()> destructor =
+ GetRemoteTensorDestructor(ctx, eager_client, context_id, id, i);
retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id,
- output_dtypes[i], std::move(callback),
+ output_dtypes[i], std::move(destructor),
op_device, op_device, op->EagerContext());
}
@@ -690,7 +753,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
}
// Unable to capture via std::move, so bind instead.
auto* node = new eager::RemoteExecuteNode(
- remote_node_id, request, eager_client, op->Inputs(),
+ remote_node_id, std::move(request), eager_client, op->Inputs(),
std::bind(
[](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
const Status& status, const eager::EnqueueResponse& response) {
@@ -707,7 +770,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
} else {
Notification n;
Status status;
- eager_client->EnqueueAsync(&request, &response,
+ eager_client->EnqueueAsync(request.get(), &response,
[&n, &status](const Status& s) {
status = s;
n.Notify();
@@ -936,6 +999,8 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
if (sender_is_local && recver_is_local) {
return LocalEagerCopyToDevice(h, ctx, recv_device, result);
+ } else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
+ return EagerRemoteSendTensor(ctx, h, recv_device, result);
} else {
string wire_id = GetUniqueWireID();
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index b410ea175b..dae5d1983f 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -41,17 +41,22 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
out->device_ = device;
out->kernel_.reset(k);
out->flib_ = nullptr;
+ out->runner_ = nullptr;
+ out->default_runner_ = [](std::function<void()> f) { f(); };
return s;
}
// static
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+ std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = flib->CreateKernel(ndef, &k);
out->device_ = flib->device();
out->kernel_.reset(k);
out->flib_ = flib;
+ out->runner_ = runner;
+ out->default_runner_ = [](std::function<void()> f) { f(); };
return s;
}
@@ -83,10 +88,11 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
if (stats != nullptr) {
params.track_allocations = true;
}
- // TODO(apassos): use a thread pool.
- std::function<void(std::function<void()>)> runner =
- [](std::function<void()> f) { f(); };
- params.runner = &runner;
+ if (runner_ == nullptr) {
+ params.runner = &default_runner_;
+ } else {
+ params.runner = runner_;
+ }
ScopedStepContainer step_container(0, [this](const string& name) {
device_->resource_manager()->Cleanup(name).IgnoreError();
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index c41a0972b1..c0b676b285 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -57,6 +57,7 @@ class KernelAndDevice {
// the FunctionLibraryRuntime is pushed on to the caller (see locking in
// c_api.cc).
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+ std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out);
// TODO(ashankar): Remove this
static Status InitOp(Device* device, const NodeDef& ndef,
@@ -88,6 +89,8 @@ class KernelAndDevice {
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
Rendezvous* rendez_;
DataTypeVector output_dtypes_;
+ std::function<void(std::function<void()>)>* runner_;
+ std::function<void(std::function<void()>)> default_runner_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
index b4349e1dee..6abe98f53c 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -107,8 +107,8 @@ void BM_KernelAndDeviceInit(int iters) {
KernelAndDevice k(nullptr);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
- TF_CHECK_OK(
- KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
+ TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
+ nullptr, &k));
}
}
BENCHMARK(BM_KernelAndDeviceInit);
@@ -128,8 +128,8 @@ void BM_KernelAndDeviceRun(int iters) {
.BuildNodeDef());
TestEnv env;
KernelAndDevice kernel(nullptr);
- TF_CHECK_OK(
- KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
+ TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
+ nullptr, &kernel));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr));
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index f9b9abcc99..85b0b79bce 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -109,6 +109,19 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
return Status::OK();
}
+Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ CHECK(remote_shape_ != nullptr);
+ *shape = *(remote_shape_.get());
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *shape = tensor_.shape();
+ }
+ return Status::OK();
+}
+
Status TensorHandle::NumDims(int* num_dims) {
if (IsRemote()) {
TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
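A brief, hedged usage sketch for the new TensorHandle::Shape() accessor added above (the helper function is illustrative): like NumDims() and Dim(), it blocks until the local tensor is ready or, for remote handles, until the remote shape has been received.

#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

tensorflow::Status LogHandleShape(tensorflow::TensorHandle* handle) {
  tensorflow::TensorShape shape;
  TF_RETURN_IF_ERROR(handle->Shape(&shape));
  LOG(INFO) << "TensorHandle shape: " << shape.DebugString();
  return tensorflow::Status::OK();
}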
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 46bc94f875..1bc9c6531a 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -109,6 +109,8 @@ class TensorHandle : public core::RefCounted {
tensorflow::Device** device,
tensorflow::Device** op_device);
+ Status Shape(tensorflow::TensorShape* shape);
+
Status NumDims(int* num_dims);
Status Dim(int dim_index, int64* dim);
@@ -138,6 +140,12 @@ class TensorHandle : public core::RefCounted {
remote_shape_ = std::move(remote_shape);
}
+ bool OnHostCPU() {
+ mutex_lock ml(ctx_mutex_);
+ return device_ == nullptr ||
+ (ctx_ == nullptr || ctx_->HostCPU() == device_);
+ }
+
private:
// If the contents of the Tensor pointed to by this handle is yet to be
// computed by a EagerNode, this function will block till that compuatation is
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index f7f2cdc14f..8096139d90 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1966,17 +1966,9 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
device_context = device_context_map_[node->id()];
}
- // Experimental: debugger (tfdb) access to intermediate node completion.
- if (item.num_outputs == 0 && impl_->params_.node_outputs_cb != nullptr) {
- // If the node has no output, invoke the callback with output slot set to
- // -1, signifying that this is a no-output node.
- s.Update(impl_->params_.node_outputs_cb(item.node->name(), -1, nullptr,
- false, ctx));
- }
-
for (int i = 0; i < item.num_outputs; ++i) {
const TensorValue val = ctx->release_output(i);
- if (*ctx->is_output_dead() || val.tensor == nullptr) {
+ if (val.tensor == nullptr) {
// Unless it's a Switch or a Recv, the node must produce a
// tensor value at i-th output.
if (!IsSwitch(node) && !IsRecv(node)) {
@@ -2018,13 +2010,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
ctx->step_id(), i, to_log);
}
-
- // Experimental: debugger (tfdb) access to intermediate node
- // outputs.
- if (impl_->params_.node_outputs_cb != nullptr) {
- s.Update(impl_->params_.node_outputs_cb(item.node->name(), i,
- out->ref, true, ctx));
- }
} else {
// NOTE that std::move is used here, so val.tensor goes to
// uninitialized state (val.tensor->IsInitialized return false).
@@ -2036,12 +2021,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
ctx->step_id(), i, *out->val);
}
-
- // Experimental: debugger access to intermediate node outputs.
- if (impl_->params_.node_outputs_cb != nullptr) {
- s.Update(impl_->params_.node_outputs_cb(
- item.node->name(), i, out->val.get(), false, ctx));
- }
}
} else {
s.Update(errors::Internal("Output ", i, " of type ",
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index e5d7b7c53c..cd01b43aea 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -103,7 +103,6 @@ class Executor {
const Tensor* tensor, const bool is_ref,
OpKernelContext* ctx)>
NodeOutputsCallback;
- NodeOutputsCallback node_outputs_cb = nullptr;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void RunAsync(const Args& args, DoneCallback done) = 0;
@@ -139,8 +138,6 @@ struct LocalExecutorParams {
// when the executor is deleted.
std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
std::function<void(OpKernel*)> delete_kernel;
-
- Executor::Args::NodeOutputsCallback node_outputs_cb;
};
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index a93cfa2ec5..54bbe84b57 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -746,6 +746,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
rets_alloc_attrs.push_back(ret_alloc_attrs);
}
+ bool allow_dead_tensors = opts.allow_dead_tensors;
+
// The ProcFLR sends the arguments to the function from the source_device to
// the target_device. So here we receive those arguments. Similarly, when the
// computation is done and stored in *rets, we send the return values back
@@ -756,7 +758,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
device_context, args_alloc_attrs, rendezvous, remote_args,
[frame, remote_args, item, source_device, target_device,
target_incarnation, rendezvous, device_context, rets, done, exec_args,
- rets_alloc_attrs](const Status& status) {
+ rets_alloc_attrs, allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->SetArgs(*remote_args);
@@ -769,13 +771,13 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
item->exec->RunAsync(
- *exec_args,
- [frame, rets, done, source_device, target_device,
- target_incarnation, rendezvous, device_context, remote_args,
- exec_args, rets_alloc_attrs](const Status& status) {
+ *exec_args, [frame, rets, done, source_device, target_device,
+ target_incarnation, rendezvous, device_context,
+ remote_args, exec_args, rets_alloc_attrs,
+ allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
+ s = frame->ConsumeRetvals(rets, allow_dead_tensors);
}
delete frame;
if (!s.ok()) {
@@ -859,14 +861,15 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
return;
}
+ bool allow_dead_tensors = opts.allow_dead_tensors;
item->exec->RunAsync(
// Executor args
*exec_args,
// Done callback.
- [frame, rets, done, exec_args](const Status& status) {
+ [frame, rets, done, exec_args, allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
+ s = frame->ConsumeRetvals(rets, allow_dead_tensors);
}
delete frame;
delete exec_args;
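A hedged sketch of how a caller opts into the new behavior; the wrapper function is illustrative, and other Options fields (rendezvous, step_id, etc.) are elided. Setting allow_dead_tensors lets ConsumeRetvals hand back uninitialized (dead) return values instead of failing, which the changes above thread through both Run paths.

#include <utility>
#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.h"

void RunToleratingDeadOutputs(
    tensorflow::FunctionLibraryRuntime* flib,
    tensorflow::FunctionLibraryRuntime::Handle handle,
    const std::vector<tensorflow::Tensor>& args,
    std::vector<tensorflow::Tensor>* rets,
    tensorflow::FunctionLibraryRuntime::DoneCallback done) {
  tensorflow::FunctionLibraryRuntime::Options opts;
  opts.allow_dead_tensors = true;  // dead retvals come back uninitialized
  flib->Run(opts, handle, args, rets, std::move(done));
}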
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 3cb51b0dbc..3292ef2f62 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -224,6 +225,7 @@ class BaseGPUDevice::StreamGroupFactory {
int num_d2d_streams =
options.experimental().num_dev_to_dev_copy_streams();
+ if (num_d2d_streams == 0) num_d2d_streams = 1;
if (num_d2d_streams < 1 || num_d2d_streams > 4) {
LOG(ERROR)
<< "Illegal GPUOptions.experimental.num_dev_to_dev_copy_streams="
@@ -856,7 +858,7 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
static_cast<ConcretePerOpGpuDevice*>(device);
DCHECK(concrete_device);
const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
- streams_[stream_id]->compute->implementation()->CudaStreamMemberHack());
+ streams_[stream_id]->compute->implementation()->GpuStreamMemberHack());
concrete_device->Reinitialize(context, cuda_stream, tf_gpu_id_, allocator,
scratch_[stream_id]);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
index 4898448476..3c1c31aa73 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -15,11 +15,80 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
+namespace {
+// The EventMgr has 1 thread for the polling loop and one to execute
+// event callback functions. Issues for reconsideration:
+// - Is this the right number of threads?
+// - Should EventMgrs be shared between GPUDevices on a multi-GPU machine?
+static const int kNumThreads = 2;
+} // namespace
+
+namespace gpu_event_mgr {
+class ThreadLabel {
+ public:
+ static const char* GetValue() { return value_; }
+
+ // v must be a static const because value_ will capture and use its value
+ // until reset or thread terminates.
+ static void SetValue(const char* v) { value_ = v; }
+
+ private:
+ static thread_local const char* value_;
+};
+thread_local const char* ThreadLabel::value_ = "";
+
+void WarnIfInCallback(std::function<void()> f) {
+ const char* label = ThreadLabel::GetValue();
+ if (label && !strcmp(label, "gpu_event_mgr")) {
+ if (f) {
+ f();
+ } else {
+ LOG(WARNING) << "Executing inside EventMgr callback thread: "
+ << CurrentStackTrace();
+ }
+ }
+}
+
+void InitThreadpoolLabels(thread::ThreadPool* threadpool) {
+ static const char* label = "gpu_event_mgr";
+ mutex mu;
+ int init_count = 0;
+ condition_variable all_initialized;
+ int exit_count = 0;
+ condition_variable ready_to_exit;
+ const int num_threads = threadpool->NumThreads();
+ for (int i = 0; i < num_threads; ++i) {
+ threadpool->Schedule([num_threads, &mu, &init_count, &all_initialized,
+ &exit_count, &ready_to_exit]() {
+ gpu_event_mgr::ThreadLabel::SetValue(label);
+ mutex_lock l(mu);
+ ++init_count;
+ if (init_count == num_threads) {
+ all_initialized.notify_all();
+ }
+ while (init_count < num_threads) {
+ all_initialized.wait(l);
+ }
+ if (++exit_count == num_threads) {
+ ready_to_exit.notify_all();
+ }
+ });
+ }
+ {
+ mutex_lock l(mu);
+ while (exit_count < num_threads) {
+ ready_to_exit.wait(l);
+ }
+ }
+}
+} // namespace gpu_event_mgr
+
EventMgr::EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options)
: exec_(se),
deferred_bytes_threshold_(gpu_options.deferred_deletion_bytes()
@@ -31,9 +100,8 @@ EventMgr::EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options)
accumulated_stream_(nullptr),
accumulated_tensors_(new TensorReferenceVector),
accumulated_tensor_bytes_(0),
- // threadpool_ has 1 thread for the polling loop, and one to execute
- // event callback functions. Maybe we should have more?
- threadpool_(Env::Default(), "GPU_Event_Manager", 2) {
+ threadpool_(Env::Default(), "GPU_Event_Manager", kNumThreads) {
+ gpu_event_mgr::InitThreadpoolLabels(&threadpool_);
StartPollingLoop();
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
index b26f88a201..f0a109cc10 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -39,6 +39,25 @@ namespace tensorflow {
class GPUOptions;
+// The callback provided to EventMgr::ThenExecute must not block or take a long
+// time. If it does, performance may be impacted and GPU memory may be
+// exhausted. This macro is for checking that an EventMgr thread is not
+// accidentally entering blocking parts of the code, e.g. the RPC subsystem.
+//
+// Intended use is something like
+//
+// void RespondToAnRPC(Params* params) {
+// WARN_IF_IN_EVENT_MGR_THREAD;
+// if (params->status.ok()) { ...
+//
+namespace gpu_event_mgr {
+// Logs a stack trace if current execution thread belongs to this EventMgr
+// object. If f is not nullptr, executes instead of logging the stack
+// trace.
+void WarnIfInCallback(std::function<void()> f);
+} // namespace gpu_event_mgr
+#define WARN_IF_IN_EVENT_MGR_THREAD gpu_event_mgr::WarnIfInCallback(nullptr)
+
// An object to keep track of pending Events in the StreamExecutor streams
// and associated Tensors that cannot safely be deleted until the associated
// Events are recorded.
@@ -74,6 +93,9 @@ class EventMgr {
FreeMemory(to_free);
}
+ // Execute func when all pending stream actions have completed.
+ // func must be brief and non-blocking since it executes in the one
+ // thread used for all such callbacks and also buffer deletions.
inline void ThenExecute(se::Stream* stream, std::function<void()> func) {
ToFreeVector to_free;
{
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index c5ff6c97a1..d2adf699f5 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <atomic>
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
@@ -243,6 +244,28 @@ TEST(EventMgr, NonEmptyShutdown) {
}
}
+// Tests that WarnIfInCallback() triggers correctly.
+TEST(EventMgr, WarnIfInCallback) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec, GPUOptions());
+ TEST_EventMgrHelper th(&em);
+ std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
+ CHECK(stream);
+ stream->Init();
+ bool hit = false;
+ gpu_event_mgr::WarnIfInCallback([&hit] { hit = true; });
+ EXPECT_FALSE(hit);
+ Notification note;
+ em.ThenExecute(stream.get(), [&hit, &note]() {
+ gpu_event_mgr::WarnIfInCallback([&hit, &note] {
+ hit = true;
+ note.Notify();
+ });
+ });
+ note.WaitForNotification();
+ EXPECT_TRUE(hit);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 1f0773d387..6781c87f6c 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/status_util.h"
namespace tensorflow {
@@ -822,10 +823,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(
- errors::InvalidArgument("Cannot assign a device for operation '",
- node->name(), "': ", status.error_message()),
- *node);
+ return AttachDef(errors::InvalidArgument(
+ "Cannot assign a device for operation ",
+ RichNodeName(node), ": ", status.error_message()),
+ *node);
}
// Returns the first device in sorted devices list so we will always
@@ -869,10 +870,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(
- errors::InvalidArgument("Cannot assign a device for operation '",
- node->name(), "': ", status.error_message()),
- *node);
+ return AttachDef(errors::InvalidArgument(
+ "Cannot assign a device for operation ",
+ RichNodeName(node), ": ", status.error_message()),
+ *node);
}
int assigned_device = -1;
@@ -938,4 +939,22 @@ void Placer::LogDeviceAssignment(const Node* node) const {
}
}
+bool Placer::ClientHandlesErrorFormatting() const {
+ return options_ != nullptr &&
+ options_->config.experimental().client_handles_error_formatting();
+}
+
+// Returns the node name in single quotes. If the client handles formatted
+// errors, appends a formatting tag which the client will reformat into, for
+// example, " (defined at filename:123)".
+string Placer::RichNodeName(const Node* node) const {
+ string quoted_name = strings::StrCat("'", node->name(), "'");
+ if (ClientHandlesErrorFormatting()) {
+ string file_and_line = error_format_tag(*node, "${file}:${line}");
+ return strings::StrCat(quoted_name, " (defined at ", file_and_line, ")");
+ } else {
+ return quoted_name;
+ }
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h
index 75dce7c7fe..fce87269c5 100644
--- a/tensorflow/core/common_runtime/placer.h
+++ b/tensorflow/core/common_runtime/placer.h
@@ -87,6 +87,8 @@ class Placer {
// placement if the SessionOptions entry in 'options_' requests it.
void AssignAndLog(int assigned_device, Node* node) const;
void LogDeviceAssignment(const Node* node) const;
+ bool ClientHandlesErrorFormatting() const;
+ string RichNodeName(const Node* node) const;
Graph* const graph_; // Not owned.
const DeviceSet* const devices_; // Not owned.
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 07a7724f16..cede899842 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -1142,6 +1142,50 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
EXPECT_TRUE(str_util::StrContains(s.error_message(), "/device:fakegpu:11"));
}
+// Test that the "Cannot assign a device" error message contains a format tag
+// when requested.
+TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestDevice",
+ b.opts().WithName("in").WithDevice("/device:fakegpu:11"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.mutable_experimental()->set_client_handles_error_formatting(
+ true);
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ str_util::StrContains(s.error_message(),
+ "Cannot assign a device for operation 'in'"
+ " (defined at ^^node:in:${file}:${line}^^)"));
+}
+
+// Test that the "Cannot assign a device" error message does not contain a
+// format tag when it is not requested.
+TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestDevice",
+ b.opts().WithName("in").WithDevice("/device:fakegpu:11"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.mutable_experimental()->set_client_handles_error_formatting(
+ false);
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "Cannot assign a device for operation 'in'"));
+ EXPECT_FALSE(str_util::StrContains(
+ s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)"));
+}
+
// Test that placement fails when a node requests an explicit device that is not
// supported by the registered kernels if allow_soft_placement is no set.
TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
index 4d83b25ce6..447338e7bd 100644
--- a/tensorflow/core/common_runtime/process_state.cc
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -71,7 +71,7 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
return MemDesc();
}
-Allocator* ProcessState::GetCPUAllocator(int numa_node) {
+VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
CHECK_GE(numa_node, 0);
if (!numa_enabled_) numa_node = 0;
mutex_lock lock(mu_);
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 0f4ae230bb..2892677333 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -65,7 +65,7 @@ class ProcessState {
// Returns the one CPUAllocator used for the given numa_node.
// TEMPORARY: ignores numa_node.
- Allocator* GetCPUAllocator(int numa_node);
+ VisitableAllocator* GetCPUAllocator(int numa_node);
typedef std::unordered_map<const void*, MemDesc> MDMap;
@@ -87,7 +87,7 @@ class ProcessState {
mutex mu_;
- std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_);
virtual ~ProcessState();
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
index 4a9248171b..8c30beeec2 100644
--- a/tensorflow/core/common_runtime/session.cc
+++ b/tensorflow/core/common_runtime/session.cc
@@ -53,27 +53,33 @@ Status Session::PRun(const string& handle,
Session* NewSession(const SessionOptions& options) {
SessionFactory* factory;
- const Status s = SessionFactory::GetFactory(options, &factory);
+ Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
LOG(ERROR) << s;
return nullptr;
}
- return factory->NewSession(options);
+ Session* out_session;
+ s = NewSession(options, &out_session);
+ if (!s.ok()) {
+ LOG(ERROR) << "Failed to create session: " << s;
+ return nullptr;
+ }
+ return out_session;
}
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
- const Status s = SessionFactory::GetFactory(options, &factory);
+ Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
*out_session = nullptr;
LOG(ERROR) << s;
return s;
}
- *out_session = factory->NewSession(options);
- if (!*out_session) {
- return errors::Internal("Failed to create session.");
+ s = factory->NewSession(options, out_session);
+ if (!s.ok()) {
+ *out_session = nullptr;
}
- return Status::OK();
+ return s;
}
Status Reset(const SessionOptions& options,
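With the factory now reporting failures through Status, the status-returning NewSession overload becomes the preferred entry point. A hedged usage sketch follows; the wrapper name is illustrative.

#include <memory>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"

// On success the caller owns the session; on failure the error is propagated
// instead of being collapsed into a nullptr as with the older overload.
tensorflow::Status MakeSession(const tensorflow::SessionOptions& options,
                               std::unique_ptr<tensorflow::Session>* out) {
  tensorflow::Session* session = nullptr;
  TF_RETURN_IF_ERROR(tensorflow::NewSession(options, &session));
  out->reset(session);
  return tensorflow::Status::OK();
}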
diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h
index df3198a70d..81c172c6ae 100644
--- a/tensorflow/core/common_runtime/session_factory.h
+++ b/tensorflow/core/common_runtime/session_factory.h
@@ -30,7 +30,12 @@ struct SessionOptions;
class SessionFactory {
public:
- virtual Session* NewSession(const SessionOptions& options) = 0;
+ // Creates a new session and stores it in *out_session, or fails with an error
+ // status if the Session could not be created. Caller takes ownership of
+ // *out_session if this returns Status::OK().
+ virtual Status NewSession(const SessionOptions& options,
+ Session** out_session) = 0;
+
virtual bool AcceptsOptions(const SessionOptions& options) = 0;
// Abort and close all existing sessions, disconnecting their resources from
diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc
index feaf29c7bb..1fa5aad60c 100644
--- a/tensorflow/core/common_runtime/session_test.cc
+++ b/tensorflow/core/common_runtime/session_test.cc
@@ -47,8 +47,10 @@ class FakeSessionFactory : public SessionFactory {
return str_util::StartsWith(options.target, "fake");
}
- Session* NewSession(const SessionOptions& options) override {
- return nullptr;
+ Status NewSession(const SessionOptions& options,
+ Session** out_session) override {
+ *out_session = nullptr;
+ return Status::OK();
}
};
class FakeSessionRegistrar {
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 74a87215e1..7406ecf4f8 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -111,7 +111,21 @@ Status ThreadPoolDevice::MakeTensorFromProto(
}
#ifdef INTEL_MKL
-REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocator);
+namespace {
+class MklCPUAllocatorFactory : public AllocatorFactory {
+ public:
+ bool NumaEnabled() override { return false; }
+
+ Allocator* CreateAllocator() override { return new MklCPUAllocator; }
+
+ // Note: Ignores numa_node, for now.
+ virtual SubAllocator* CreateSubAllocator(int numa_node) {
+ return new MklSubAllocator;
+ }
+};
+
+REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory);
+} // namespace
#endif
} // namespace tensorflow