-rw-r--r--  tensorflow/c/eager/BUILD                             2
-rw-r--r--  tensorflow/c/eager/c_api.cc                        194
-rw-r--r--  tensorflow/c/eager/c_api_internal.h                 84
-rw-r--r--  tensorflow/core/common_runtime/eager/BUILD          22
-rw-r--r--  tensorflow/core/common_runtime/eager/context.cc    153
-rw-r--r--  tensorflow/core/common_runtime/eager/context.h     198
6 files changed, 450 insertions, 203 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 841ff48a38..bea5a121b3 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -28,6 +28,7 @@ tf_cuda_library(
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
+ "//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core:core_cpu_internal",
@@ -64,6 +65,7 @@ tf_cuda_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
],
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index a23015c99e..2402a6d044 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -71,18 +71,6 @@ std::atomic_int_fast64_t func_id_generator(0);
} // namespace
-TFE_ContextDevicePlacementPolicy PlacementPolicy(
- bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy) {
- if (!soft_placement) {
- return original_policy;
- }
- if (original_policy == TFE_DEVICE_PLACEMENT_EXPLICIT ||
- original_policy == TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) {
- return TFE_DEVICE_PLACEMENT_SILENT;
- }
- return original_policy;
-}
-
extern "C" {
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
@@ -104,19 +92,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
- {
- tensorflow::mutex_lock l(ctx->async_map_mu);
- ctx->thread_local_async[std::this_thread::get_id()] = async;
- }
- if (async) {
- ctx->executor.EnableAsync();
- } else {
- // TODO(agarwal): Currently we add a wait here to handle cases where a sync
- // op has a control dependency on an async op, and the latter has not
- // executed yet. This wait can be removed by storing all the control inputs
- // and waiting for them when executing ops.
- status->status = ctx->executor.WaitForAllPendingNodes();
- }
+ status->status = ctx->context.SetAsyncForThread(async);
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
@@ -133,34 +109,26 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
new tensorflow::DeviceMgr(devices));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
- return new TFE_Context(*opts, std::move(device_mgr), r);
+ return new TFE_Context(opts->session_options.options, opts->policy,
+ opts->async, std::move(device_mgr), r);
}
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
- status->status = ctx->executor.WaitForAllPendingNodes();
- {
- tensorflow::mutex_lock ml(ctx->cache_mu);
- tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
- }
- ctx->rendezvous->Unref();
delete ctx;
}
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
- ctx->device_manager->ListDeviceAttributes(&list->response);
+ ctx->context.device_mgr()->ListDeviceAttributes(&list->response);
return list;
}
-void TFE_ContextClearCaches(TFE_Context* ctx) {
- tensorflow::mutex_lock ml(ctx->cache_mu);
- tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
-}
+void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
- tensorflow::mutex_lock ml(ctx->policy_map_mu);
- ctx->thread_local_policies[std::this_thread::get_id()] = policy;
+ ctx->context.SetThreadLocalDevicePlacementPolicy(
+ static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
// Note: this function looks up a thread local policy. So it should be called in
@@ -168,25 +136,20 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
- tensorflow::mutex_lock ml(ctx->policy_map_mu);
- auto policy_map_it =
- ctx->thread_local_policies.find(std::this_thread::get_id());
- if (policy_map_it != ctx->thread_local_policies.end()) {
- return policy_map_it->second;
- }
- return ctx->policy;
+ return static_cast<TFE_ContextDevicePlacementPolicy>(
+ ctx->context.GetDevicePlacementPolicy());
}
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
- status->status = ctx->executor.WaitForAllPendingNodes();
+ status->status = ctx->context.AsyncWait();
}
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
- status->status = ctx->executor.status();
+ status->status = ctx->context.GetStatus();
}
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
- ctx->executor.ClearError();
+ ctx->context.ClearAsyncError();
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@@ -259,7 +222,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
// nullptr.
tensorflow::Device* src_opd = nullptr;
TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd));
- if (srcd == nullptr) srcd = ctx->devices[0];
+ if (srcd == nullptr) srcd = ctx->context.HostCPU();
bool is_same_device =
(srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
const bool dst_cpu = IsCPU(dstd);
@@ -332,8 +295,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
status->status = tensorflow::AttrTypeMapForOp(name, &types);
if (status->status.ok()) return new TFE_Op(ctx, name, types);
if (TF_GetCode(status) == TF_NOT_FOUND) {
- tensorflow::mutex_lock l(ctx->functions_mu);
- if (ctx->func_lib_def.Find(name) != nullptr) {
+ if (ctx->context.FindFunctionByName(name)) {
status->status = tensorflow::Status::OK();
return new TFE_Op(ctx, name, nullptr);
}
@@ -346,20 +308,14 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
tensorflow::Device* d = nullptr;
if (device_name != nullptr && strlen(device_name) > 0) {
- auto it = op->ctx->devices_map.find(device_name);
- if (it == op->ctx->devices_map.end()) {
- status->status =
- tensorflow::errors::InvalidArgument(device_name, " unknown device.");
- return;
- }
- d = it->second;
+ status->status = op->ctx->context.FindDeviceByName(device_name, &d);
}
op->device = d;
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device =
- (op->device == nullptr) ? op->ctx->devices[0] : op->device;
+ (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device;
return device->name().c_str();
}
@@ -634,7 +590,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
TFE_Context* ctx, TF_Status* status) {
tensorflow::DeviceSet ds;
- for (tensorflow::Device* d : ctx->devices) {
+ for (tensorflow::Device* d : *ctx->context.devices()) {
ds.AddDevice(d);
}
tensorflow::DeviceTypeVector final_devices;
@@ -648,7 +604,7 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
"Could not find valid device for node ", ndef.DebugString());
return nullptr;
}
- for (tensorflow::Device* d : ctx->devices) {
+ for (tensorflow::Device* d : *ctx->context.devices()) {
if (d->device_type() == final_devices[0].type_string()) {
return d;
}
@@ -663,9 +619,8 @@ tensorflow::Status Execute(
const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs,
tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
TFE_TensorHandle** retvals, int num_retvals) {
- if (!ctx->soft_placement && device == nullptr) {
- // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
- device = ctx->devices[0];
+ if (!ctx->context.SoftPlacement() && device == nullptr) {
+ device = ctx->context.HostCPU();
}
if (device == nullptr) {
@@ -684,8 +639,8 @@ tensorflow::Status Execute(
inputs[i] = *input_tensor;
}
// WARNING: kernel->Run utilizes the FunctionLibraryRuntime
- // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
- // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
+ // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def.
+ // But knowledge of the implementation
// of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
// FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
// This is quite subtle. Re-work things to make this better? (Would it make
@@ -697,18 +652,18 @@ tensorflow::Status Execute(
if (maybe_stats != nullptr) {
maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
maybe_stats->all_start_micros());
- tensorflow::mutex_lock ml(ctx->metadata_mu);
- if (ctx->should_store_metadata.load()) {
- auto* step_stats = ctx->run_metadata.mutable_step_stats();
+ tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
+ if (ctx->context.ShouldStoreMetadata()) {
+ auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats();
// Lazily initialize the RunMetadata with information about all devices if
// this is the first call.
- while (step_stats->dev_stats_size() < ctx->devices.size()) {
+ while (step_stats->dev_stats_size() < ctx->context.devices()->size()) {
step_stats->add_dev_stats();
}
// Find the current device's index.
int device_idx = 0;
- for (int i = 0; i < ctx->devices.size(); ++i) {
- if (ctx->devices[i] == device) {
+ for (int i = 0; i < ctx->context.devices()->size(); ++i) {
+ if (ctx->context.devices()->at(i) == device) {
device_idx = i;
break;
}
@@ -744,7 +699,7 @@ class ExecuteNode : public tensorflow::EagerNode {
tensorflow::NodeExecStats* maybe_stats,
const tensorflow::DataTypeVector& output_dtypes,
TFE_TensorHandle** retvals, int num_retvals)
- : tensorflow::EagerNode(op->ctx->executor.NextId()),
+ : tensorflow::EagerNode(op->ctx->context.NextId()),
ctx_(op->ctx),
op_device_(op->device),
inputs_(op->inputs),
@@ -800,7 +755,7 @@ class CopyToDeviceNode : public tensorflow::EagerNode {
public:
CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
TFE_Context* ctx)
- : tensorflow::EagerNode(ctx->executor.NextId()),
+ : tensorflow::EagerNode(ctx->context.NextId()),
src_(src),
dstd_(dstd),
ctx_(ctx),
@@ -866,8 +821,7 @@ const tensorflow::FunctionDef* OpToFunction(
TFE_Context* ctx = op->ctx;
const tensorflow::OpRegistrationData* op_data;
{
- tensorflow::tf_shared_lock l(ctx->functions_mu);
- status->status = ctx->func_lib_def.LookUp(op->name, &op_data);
+ status->status = ctx->context.FindFunctionOpData(op->name, &op_data);
if (!status->status.ok()) {
return nullptr;
}
@@ -963,10 +917,9 @@ const tensorflow::FunctionDef* OpToFunction(
}
VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();
- tensorflow::mutex_lock l(ctx->functions_mu);
- status->status = ctx->func_lib_def.AddFunctionDef(fdef);
+ ctx->context.AddFunctionDef(fdef);
if (!status->status.ok()) return nullptr;
- const auto ret = ctx->func_lib_def.Find(signature->name());
+ const auto ret = ctx->context.FindFunctionDef(signature->name());
DCHECK(ret != nullptr);
return ret;
}
@@ -985,8 +938,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
const tensorflow::FunctionDef* fdef;
{
- tensorflow::tf_shared_lock l(op->ctx->functions_mu);
- fdef = op->ctx->func_lib_def.Find(op->name);
+ fdef = op->ctx->context.FindFunctionDef(op->name);
}
std::vector<TF_DataType> const_input_types;
std::vector<TF_DataType> arg_input_types;
@@ -1063,7 +1015,7 @@ extern "C" {
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
TFE_Context* ctx = op->ctx;
- status->status = ctx->executor.status();
+ status->status = ctx->context.GetStatus();
if (!status->status.ok()) {
return;
}
@@ -1087,7 +1039,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
input_op_device != op->device) {
tensorflow::Device* d =
- input_op_device == nullptr ? ctx->devices[0] : input_op_device;
+ input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device;
VLOG(1) << "Changing device of operation " << op->name << " to "
<< d->name() << " because input #" << i
<< " is a resource in this device.";
@@ -1095,40 +1047,35 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
}
tensorflow::Device* device = op->device;
- if (!ctx->soft_placement && device == nullptr) {
- // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
- device = ctx->devices[0];
+ if (!ctx->context.SoftPlacement() && device == nullptr) {
+ device = ctx->context.HostCPU();
}
tensorflow::Fprint128 cache_key =
op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
- tensorflow::KernelAndDevice* kernel;
- {
- tensorflow::tf_shared_lock l(ctx->cache_mu);
- kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
- }
+ tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key);
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
- if (ctx->soft_placement && device == nullptr) {
+ if (ctx->context.SoftPlacement() && device == nullptr) {
device = SelectDevice(ndef, ctx, status);
if (!status->status.ok()) {
return;
}
}
CHECK(device != nullptr);
- if (ctx->log_device_placement) {
+ if (ctx->context.LogDevicePlacement()) {
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
- kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
+ kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous());
// Knowledge of the implementation of Init (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
- tensorflow::tf_shared_lock l(ctx->functions_mu);
- status->status =
- tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
+ tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu());
+ status->status = tensorflow::KernelAndDevice::Init(
+ ndef, ctx->context.func_lib(device), kernel);
if (!status->status.ok()) {
delete kernel;
return;
@@ -1136,7 +1083,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// Update output_dtypes inside `kernel`.
const tensorflow::OpDef* op_def = nullptr;
const tensorflow::FunctionDef* function_def =
- ctx->func_lib_def.Find(ndef.op());
+ ctx->context.FuncLibDef()->Find(ndef.op());
if (function_def != nullptr) {
op_def = &(function_def->signature());
}
@@ -1152,8 +1099,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (!status->status.ok()) {
return;
}
- tensorflow::mutex_lock ml(ctx->cache_mu);
- tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
+ ctx->context.AddKernelToCache(cache_key, kernel);
}
const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
const int output_dtypes_size = output_dtypes.size();
@@ -1171,11 +1117,11 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// device from the one requested above.
device = kernel->device();
}
- status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device,
- op, kernel->kernel());
+ status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(),
+ device, op, kernel->kernel());
if (!status->status.ok()) return;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
- if (ctx->should_store_metadata.load()) {
+ if (ctx->context.ShouldStoreMetadata()) {
maybe_stats.reset(new tensorflow::NodeExecStats);
maybe_stats->set_node_name(op->name);
maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
@@ -1183,14 +1129,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
// TODO(apassos) track referenced tensors
}
- if (ctx->Async()) {
+ if (ctx->context.Async()) {
// Note that for async mode, execution order will make sure that all
// input handles are ready before executing them.
// TODO(agarwal): Consider executing "cheap" kernels inline for performance.
tensorflow::EagerNode* node =
new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes,
retvals, *num_retvals);
- ctx->executor.Add(node);
+ ctx->context.ExecutorAdd(node);
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
@@ -1206,23 +1152,24 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
- status->status = ctx->executor.status();
+ status->status = ctx->context.GetStatus();
if (!status->status.ok()) {
return nullptr;
}
- tensorflow::Device* dstd = ctx->devices[0];
+ tensorflow::Device* dstd = ctx->context.HostCPU();
if (device_name != nullptr && strlen(device_name) > 0) {
- status->status = ctx->device_manager->LookupDevice(device_name, &dstd);
+ status->status =
+ ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
if (!status->status.ok()) return nullptr;
}
- if (ctx->Async()) {
+ if (ctx->context.Async()) {
// Note that `h` may not be currently ready. However execution order will
// make sure that `h` is ready before the copy is actually done.
CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
TFE_TensorHandle* output = node->dst();
// Note that calling Add makes `node` accessible by the EagerExecutor
// thread. So further accesses need to be thread-safe.
- ctx->executor.Add(node);
+ ctx->context.ExecutorAdd(node);
return output;
} else {
TFE_TensorHandle* output = nullptr;
@@ -1240,24 +1187,20 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
- tensorflow::mutex_lock l(ctx->functions_mu);
- status->status = ctx->func_lib_def.AddFunctionDef(function_def);
+ status->status = ctx->context.AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
- tensorflow::mutex_lock l(ctx->functions_mu);
- status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
+ status->status = ctx->context.AddFunctionDef(function->fdef);
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
- ctx->should_store_metadata.store(true);
+ ctx->context.SetShouldStoreMetadata(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
- tensorflow::mutex_lock ml(ctx->metadata_mu);
- ctx->should_store_metadata.store(false);
- ctx->run_metadata.Clear();
+ ctx->context.SetShouldStoreMetadata(false);
}
} // extern "C"
@@ -1286,9 +1229,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);
if (!status->status.ok()) return;
- tensorflow::mutex_lock ml(ctx->metadata_mu);
- status->status = MessageToBuffer(ctx->run_metadata, buf);
- ctx->run_metadata.Clear();
+ tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
+ status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
+ ctx->context.RunMetadataProto()->Clear();
}
namespace {
@@ -1363,11 +1306,6 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} // namespace tensorflow
-bool TFE_Context::Async() const {
- tensorflow::mutex_lock l(async_map_mu);
- return tensorflow::gtl::FindWithDefault(
- thread_local_async, std::this_thread::get_id(), async_default);
-}
bool TFE_TensorHandle::IsReady() {
if (node_id == 0) return true;
@@ -1381,7 +1319,7 @@ tensorflow::Status TFE_TensorHandle::WaitReady() {
{
tensorflow::mutex_lock l(ctx_mutex_);
if (ctx_ == nullptr) return tensorflow::Status::OK();
- executor = &ctx_->executor;
+ executor = ctx_->context.Executor();
}
return executor->WaitFor(node_id);
}
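As a point of reference, here is a minimal client-side sketch of the C API entry points whose implementations are rerouted above. The calls are the ones appearing in this diff; the surrounding setup and option values are illustrative only, and the signatures reflect the API at the time of this change.

// Sketch: exercising the rerouted C API functions end to end.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

void EagerContextSmokeTest() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // Context-wide default policy; individual threads may override it below.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) == TF_OK) {
    // Forwards to EagerContext::SetThreadLocalDevicePlacementPolicy.
    TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
                                                   TFE_DEVICE_PLACEMENT_WARN);
    TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
    TF_DeleteDeviceList(devices);
    // Drains pending async nodes via EagerContext::AsyncWait.
    TFE_ContextAsyncWait(ctx, status);
    TFE_DeleteContext(ctx, status);
  }
  TF_DeleteStatus(status);
}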
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index a79f8ddd33..5b29120b40 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -52,85 +53,18 @@ struct TFE_ContextOptions {
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
};
-TFE_ContextDevicePlacementPolicy PlacementPolicy(
- bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy);
-
struct TFE_Context {
- explicit TFE_Context(const TFE_ContextOptions& opts,
+ explicit TFE_Context(const tensorflow::SessionOptions& opts,
+ TFE_ContextDevicePlacementPolicy default_policy,
+ bool async,
std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
tensorflow::Rendezvous* rendezvous)
- : soft_placement(
- opts.session_options.options.config.allow_soft_placement()),
- policy(PlacementPolicy(soft_placement, opts.policy)),
- device_manager(std::move(device_mgr)),
- devices(device_manager->ListDevices()),
- rendezvous(rendezvous),
- pflr(new tensorflow::ProcessFunctionLibraryRuntime(
- device_manager.get(), opts.session_options.options.env,
- TF_GRAPH_DEF_VERSION, &func_lib_def, {})),
- log_device_placement(
- opts.session_options.options.config.log_device_placement()),
- async_default(opts.async) {
- if (async_default) executor.EnableAsync();
-
- for (auto* device : devices) {
- devices_map[tensorflow::StringPiece(device->name())] = device;
- }
- }
-
- const bool soft_placement;
- const TFE_ContextDevicePlacementPolicy policy;
-
- // Note: we cannot use C++11 thread_local here as there is no concept of a
- // thread-local-object-local variable in C++11.
- tensorflow::mutex policy_map_mu;
- std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy>
- thread_local_policies GUARDED_BY(policy_map_mu);
-
- std::unique_ptr<tensorflow::DeviceMgr> device_manager;
- // Devices owned by device_manager
- std::vector<tensorflow::Device*> devices;
- // All devices are not owned.
- tensorflow::gtl::FlatMap<tensorflow::StringPiece, tensorflow::Device*,
- tensorflow::StringPieceHasher>
- devices_map;
- tensorflow::Rendezvous* const rendezvous;
-
- tensorflow::mutex functions_mu;
- tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
- tensorflow::OpRegistry::Global(), {}};
-
- // One FunctionLibraryRuntime per device.
- // func_libs[i] is the FunctionLibraryRuntime corresponding to
- // session->devices[i].
- const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
-
- tensorflow::mutex cache_mu;
- std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
- tensorflow::Fprint128Hasher>
- kernel_cache GUARDED_BY(cache_mu);
-
- tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
- return pflr->GetFLR(d->name());
- }
+ : context(opts,
+ static_cast<tensorflow::ContextDevicePlacementPolicy>(
+ default_policy),
+ async, std::move(device_mgr), rendezvous) {}
- // Whether we should compute RunMetadata.
- std::atomic<bool> should_store_metadata{false};
- tensorflow::mutex metadata_mu;
- tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
- const bool log_device_placement;
- // EagerExecutor for async execution.
- tensorflow::EagerExecutor executor;
-
- // True if running in asynchronous mode.
- bool Async() const;
-
- // True if the default value for execution mode is async. Note that this value
- // can be overridden per thread based on `thread_local_async` overrides.
- const bool async_default;
- mutable tensorflow::mutex async_map_mu;
- std::unordered_map<std::thread::id, bool> thread_local_async
- GUARDED_BY(async_map_mu);
+ tensorflow::EagerContext context;
};
struct TFE_TensorHandle : public tensorflow::core::RefCounted {
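The C layer now converts between TFE_ContextDevicePlacementPolicy and tensorflow::ContextDevicePlacementPolicy with static_cast, and the new header repeats in a comment that the two enums must be kept in sync. A defensive compile-time guard along the following lines could sit next to those casts; it is a sketch that is not part of this commit and assumes the C enumerators keep the same numeric values 0 through 3.

// Hypothetical guard (not in this change): the C API enum and
// tensorflow::ContextDevicePlacementPolicy are cast back and forth, so their
// values must match.
static_assert(static_cast<int>(TFE_DEVICE_PLACEMENT_EXPLICIT) ==
                  static_cast<int>(tensorflow::DEVICE_PLACEMENT_EXPLICIT),
              "placement policy enums out of sync");
static_assert(static_cast<int>(TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) ==
                  static_cast<int>(tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32),
              "placement policy enums out of sync");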
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 8ba560bef8..de10b10b7e 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -33,6 +33,28 @@ tf_cuda_library(
)
tf_cuda_library(
+ name = "context",
+ srcs = [
+ "context.cc",
+ ],
+ hdrs = [
+ "context.h",
+ ],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":eager_executor",
+ ":kernel_and_device",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+)
+
+tf_cuda_library(
name = "kernel_and_device",
srcs = [
"kernel_and_device.cc",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
new file mode 100644
index 0000000000..0566329f18
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -0,0 +1,153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+
+namespace tensorflow {
+
+ContextDevicePlacementPolicy PlacementPolicy(
+ bool soft_placement, ContextDevicePlacementPolicy original_policy) {
+ if (!soft_placement) {
+ return original_policy;
+ }
+ if (original_policy == DEVICE_PLACEMENT_EXPLICIT ||
+ original_policy == DEVICE_PLACEMENT_SILENT_FOR_INT32) {
+ return DEVICE_PLACEMENT_SILENT;
+ }
+ return original_policy;
+}
+
+EagerContext::EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy,
+ bool async, std::unique_ptr<DeviceMgr> device_mgr,
+ Rendezvous* rendezvous)
+ : soft_placement_(opts.config.allow_soft_placement()),
+ policy_(PlacementPolicy(soft_placement_, default_policy)),
+ device_manager_(std::move(device_mgr)),
+ devices_(device_manager_->ListDevices()),
+ rendezvous_(rendezvous),
+ pflr_(new ProcessFunctionLibraryRuntime(device_manager_.get(), opts.env,
+ TF_GRAPH_DEF_VERSION,
+ &func_lib_def_, {})),
+ log_device_placement_(opts.config.log_device_placement()),
+ async_default_(async) {
+ if (async_default_) {
+ executor_.EnableAsync();
+ }
+
+ for (auto* device : devices_) {
+ devices_map_[device->name()] = device;
+ }
+}
+
+bool EagerContext::Async() const {
+ mutex_lock l(async_map_mu_);
+ return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(),
+ async_default_);
+}
+
+Status EagerContext::SetAsyncForThread(bool async) {
+ {
+ tensorflow::mutex_lock l(async_map_mu_);
+ thread_local_async_[std::this_thread::get_id()] = async;
+ }
+ if (async) {
+ executor_.EnableAsync();
+ } else {
+ // TODO(agarwal): Currently we add a wait here to handle cases where a
+ // sync op has a control dependency on an async op, and the latter has not
+ // executed yet. This wait can be removed by storing all the control
+ // inputs and waiting for them when executing ops.
+ return executor_.WaitForAllPendingNodes();
+ }
+ return Status::OK();
+}
+
+void EagerContext::ClearCaches() {
+ mutex_lock ml(cache_mu_);
+ gtl::STLDeleteValues(&kernel_cache_);
+}
+
+void EagerContext::SetThreadLocalDevicePlacementPolicy(
+ ContextDevicePlacementPolicy policy) {
+ mutex_lock ml(policy_map_mu_);
+ thread_local_policies_[std::this_thread::get_id()] = policy;
+}
+
+ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
+ mutex_lock ml(policy_map_mu_);
+ auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id());
+ if (policy_map_it != thread_local_policies_.end()) {
+ return policy_map_it->second;
+ }
+ return policy_;
+}
+
+EagerContext::~EagerContext() {
+ executor_.WaitForAllPendingNodes().IgnoreError();
+ ClearCaches();
+ rendezvous_->Unref();
+}
+
+bool EagerContext::FindFunctionByName(const string& name) {
+ mutex_lock l(functions_mu_);
+ return func_lib_def_.Find(name) != nullptr;
+}
+
+Status EagerContext::FindFunctionOpData(
+ const string& name, const tensorflow::OpRegistrationData** op_data) {
+ mutex_lock l(functions_mu_);
+ return func_lib_def_.LookUp(name, op_data);
+}
+
+const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
+ mutex_lock l(functions_mu_);
+ return func_lib_def_.Find(name);
+}
+
+Status EagerContext::FindDeviceByName(const string& name, Device** result) {
+ auto it = devices_map_.find(name);
+ if (it == devices_map_.end()) {
+ return errors::InvalidArgument(name, " unknown device.");
+ }
+ *result = it->second;
+ return Status::OK();
+}
+
+Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
+ mutex_lock l(functions_mu_);
+ return func_lib_def_.AddFunctionDef(fdef);
+}
+
+KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) {
+ tf_shared_lock l(cache_mu_);
+ return gtl::FindPtrOrNull(kernel_cache_, cache_key);
+}
+
+void EagerContext::AddKernelToCache(Fprint128 cache_key,
+ KernelAndDevice* kernel) {
+ mutex_lock ml(cache_mu_);
+ gtl::InsertOrUpdate(&kernel_cache_, cache_key, kernel);
+}
+
+void EagerContext::SetShouldStoreMetadata(bool value) {
+ should_store_metadata_.store(value);
+ if (!value) {
+ mutex_lock ml(metadata_mu_);
+ run_metadata_.Clear();
+ }
+}
+
+} // namespace tensorflow
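The cache and function-library accessors above replace the open-coded locking that used to live in c_api.cc. A condensed sketch of the lookup-or-create pattern TFE_Execute now follows, with error handling trimmed; the helper name is invented for illustration.

// Sketch of the caller-side pattern (mirrors TFE_Execute; simplified).
tensorflow::KernelAndDevice* GetOrCreateKernel(
    tensorflow::EagerContext* ctx, const tensorflow::NodeDef& ndef,
    tensorflow::Device* device, tensorflow::Fprint128 cache_key,
    tensorflow::Status* status) {
  // Shared-lock lookup happens inside GetCachedKernel.
  tensorflow::KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
  if (kernel != nullptr) return kernel;
  kernel = new tensorflow::KernelAndDevice(ctx->GetRendezvous());
  // KernelAndDevice::Init reaches the function library through the per-device
  // FunctionLibraryRuntime, hence the shared lock on FunctionsMu().
  tensorflow::tf_shared_lock l(*ctx->FunctionsMu());
  *status =
      tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
  if (!status->ok()) {
    delete kernel;
    return nullptr;
  }
  ctx->AddKernelToCache(cache_key, kernel);
  return kernel;
}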
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
new file mode 100644
index 0000000000..bc97219dae
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -0,0 +1,198 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+// Note: there's a copy enum in eager/c_api.h. It should be kept in sync.
+enum ContextDevicePlacementPolicy {
+ // Running operations with input tensors on the wrong device will fail. When
+ // soft placement is enabled acts like TFE_DEVICE_PLACEMENT_SILENT.
+ DEVICE_PLACEMENT_EXPLICIT = 0,
+ // Copy the tensor to the right device but log a warning.
+ DEVICE_PLACEMENT_WARN = 1,
+ // Silently copy the tensor, which has a performance cost since the
+ // operation will be blocked till the copy completes.
+ DEVICE_PLACEMENT_SILENT = 2,
+ // Default placement policy which silently copies int32 tensors but not other
+ // dtypes. When soft placement is enabled acts like
+ // TFE_DEVICE_PLACEMENT_SILENT.
+ DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
+};
+
+ContextDevicePlacementPolicy PlacementPolicy(
+ bool soft_placement, ContextDevicePlacementPolicy original_policy);
+
+class EagerContext {
+ public:
+ explicit EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ std::unique_ptr<DeviceMgr> device_mgr,
+ Rendezvous* rendezvous);
+
+ ~EagerContext();
+
+ // Returns the function library runtime for the given device.
+ FunctionLibraryRuntime* func_lib(Device* d) const {
+ return pflr_->GetFLR(d->name());
+ }
+
+ // True if running in asynchronous mode.
+ bool Async() const;
+
+ EagerExecutor* Executor() { return &executor_; }
+
+ // Sets whether this thread should run in synchronous or asynchronous mode.
+ Status SetAsyncForThread(bool async);
+
+ // TODO(apassos) make this return a constant reference
+ gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
+ return &devices_map_;
+ }
+
+ // TODO(apassos) make this return a constant reference
+ std::vector<Device*>* devices() { return &devices_; }
+
+ // Clears the kernel caches.
+ void ClearCaches();
+
+ // Sets the device placement policy for the current thread.
+ void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
+
+ // Returns the device placement policy for the current thread.
+ ContextDevicePlacementPolicy GetDevicePlacementPolicy();
+
+ Status AsyncWait() { return executor_.WaitForAllPendingNodes(); }
+
+ Status GetStatus() { return executor_.status(); }
+
+ void ClearAsyncError() { executor_.ClearError(); }
+
+ bool FindFunctionByName(const string& name);
+
+ Status FindFunctionOpData(const string& name,
+ const tensorflow::OpRegistrationData** op_data);
+
+ const FunctionDef* FindFunctionDef(const string& name);
+
+ Status FindDeviceByName(const string& name, Device** result);
+
+ Device* HostCPU() { return devices_[0]; }
+
+ bool SoftPlacement() { return soft_placement_; }
+
+ uint64 NextId() { return executor_.NextId(); }
+
+ void ExecutorAdd(EagerNode* node) { executor_.Add(node); }
+
+ Status AddFunctionDef(const FunctionDef& fdef);
+
+ KernelAndDevice* GetCachedKernel(Fprint128 cache_key);
+
+ void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
+
+ bool LogDevicePlacement() { return log_device_placement_; }
+
+ Rendezvous* GetRendezvous() { return rendezvous_; }
+
+ mutex* FunctionsMu() { return &functions_mu_; }
+
+ tensorflow::DeviceMgr* device_mgr() { return device_manager_.get(); }
+
+ // TODO(apassos) remove the need for this
+ void ReleaseDeviceMgr() { device_manager_.release(); }
+
+ // TODO(apassos) clean up RunMetadata storage.
+ mutex* MetadataMu() { return &metadata_mu_; }
+ bool ShouldStoreMetadata() { return should_store_metadata_.load(); }
+ void SetShouldStoreMetadata(bool value);
+ RunMetadata* RunMetadataProto() { return &run_metadata_; }
+
+ FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
+
+ private:
+ const bool soft_placement_;
+ const ContextDevicePlacementPolicy policy_;
+
+ // Note: we cannot use C++11 thread_local here as there is no concept of a
+ // thread-local-object-local variable in C++11.
+ mutex policy_map_mu_;
+ std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
+ thread_local_policies_ GUARDED_BY(policy_map_mu_);
+
+ std::unique_ptr<DeviceMgr> device_manager_;
+ // Devices owned by device_manager
+ std::vector<Device*> devices_;
+ // All devices are not owned.
+ gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
+ Rendezvous* const rendezvous_;
+
+ mutex functions_mu_;
+ FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
+ OpRegistry::Global(), {}};
+
+ // One FunctionLibraryRuntime per device.
+ // func_libs[i] is the FunctionLibraryRuntime corresponding to
+ // session->devices[i].
+ const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+
+ mutex cache_mu_;
+ std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_
+ GUARDED_BY(cache_mu_);
+
+ // Whether we should compute RunMetadata.
+ std::atomic<bool> should_store_metadata_{false};
+ mutex metadata_mu_;
+ RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
+ const bool log_device_placement_;
+ // EagerExecutor for async execution.
+ EagerExecutor executor_;
+
+ // True if the default value for execution mode is async. Note that this value
+ // can be overridden per thread based on `thread_local_async` overrides.
+ const bool async_default_;
+ mutable mutex async_map_mu_;
+ std::unordered_map<std::thread::id, bool> thread_local_async_
+ GUARDED_BY(async_map_mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
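Finally, a sketch of driving the new class directly from C++, the way TFE_NewContext does in c_api.cc above. The device-name prefix and the ownership comments restate what the surrounding code does and are assumptions of this sketch rather than requirements imposed by the header.

// Sketch: constructing an EagerContext as TFE_NewContext does.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"

void BuildEagerContext() {
  tensorflow::SessionOptions options;
  std::vector<tensorflow::Device*> devices;
  tensorflow::Status s = tensorflow::DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices);
  if (!s.ok()) return;
  // The DeviceMgr owns the devices; the EagerContext takes over the DeviceMgr
  // and Unrefs the Rendezvous in its destructor.
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DeviceMgr(devices));
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  tensorflow::EagerContext ctx(
      options, tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32,
      /*async=*/false, std::move(device_mgr), rendezvous);
}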