author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-21 23:11:40 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-21 23:14:11 -0700
commit    f83711104b64a108ac43213c92f13827343d09ef (patch)
tree      07258ae0bdc8ec46af93111ed0b530dab3959021 /tensorflow/c/eager/c_api.cc
parent    0e1775355f9d7fe5301bc0d17906453caf970e27 (diff)
Automated g4 rollback of changelist 190001737
PiperOrigin-RevId: 190021164
Diffstat (limited to 'tensorflow/c/eager/c_api.cc')
-rw-r--r--  tensorflow/c/eager/c_api.cc  179
1 file changed, 119 insertions(+), 60 deletions(-)
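The hunks below roll TFE_Context back to holding its own state directly (executor, devices, kernel_cache, and per-thread async and placement-policy maps) instead of delegating through a ctx->context member of type tensorflow::EagerContext. The public C API is untouched by the rollback; the following minimal sketch (not part of this change, error handling trimmed) exercises the placement-policy entry points whose internals the diff rewires, using only calls that appear in the diff plus the standard TF_Status helpers:

#include "tensorflow/c/c_api.h"        // TF_Status, TF_NewStatus, TF_GetCode, TF_OK
#include "tensorflow/c/eager/c_api.h"  // TFE_* declarations

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // Context-wide default: refuse ops whose inputs would need a silent copy.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) return 1;

  // Override for the calling thread only; after this rollback the override is
  // stored in ctx->thread_local_policies, keyed by std::this_thread::get_id().
  TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_ContextDevicePlacementPolicy policy = TFE_ContextGetDevicePlacementPolicy(ctx);
  (void)policy;  // SILENT on this thread; other threads still see EXPLICIT.

  TFE_DeleteContext(ctx, status);  // two-argument form at this revision
  TF_DeleteStatus(status);
  return 0;
}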
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 5d668848ab..a23015c99e 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -71,6 +71,18 @@ std::atomic_int_fast64_t func_id_generator(0);
} // namespace
+TFE_ContextDevicePlacementPolicy PlacementPolicy(
+ bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy) {
+ if (!soft_placement) {
+ return original_policy;
+ }
+ if (original_policy == TFE_DEVICE_PLACEMENT_EXPLICIT ||
+ original_policy == TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) {
+ return TFE_DEVICE_PLACEMENT_SILENT;
+ }
+ return original_policy;
+}
+
extern "C" {
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
@@ -92,7 +104,19 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
- status->status = ctx->context.SetAsyncForThread(async);
+ {
+ tensorflow::mutex_lock l(ctx->async_map_mu);
+ ctx->thread_local_async[std::this_thread::get_id()] = async;
+ }
+ if (async) {
+ ctx->executor.EnableAsync();
+ } else {
+ // TODO(agarwal): Currently we add a wait here to handle cases where a sync
+ // op has a control dependency on an async op, and the latter has not
+ // executed yet. This wait can be removed by storing all the control inputs
+ // and waiting for them when executing ops.
+ status->status = ctx->executor.WaitForAllPendingNodes();
+ }
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
@@ -109,26 +133,34 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
new tensorflow::DeviceMgr(devices));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
- return new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, std::move(device_mgr), r);
+ return new TFE_Context(*opts, std::move(device_mgr), r);
}
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
+ status->status = ctx->executor.WaitForAllPendingNodes();
+ {
+ tensorflow::mutex_lock ml(ctx->cache_mu);
+ tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+ }
+ ctx->rendezvous->Unref();
delete ctx;
}
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
- ctx->context.device_mgr()->ListDeviceAttributes(&list->response);
+ ctx->device_manager->ListDeviceAttributes(&list->response);
return list;
}
-void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
+void TFE_ContextClearCaches(TFE_Context* ctx) {
+ tensorflow::mutex_lock ml(ctx->cache_mu);
+ tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
- ctx->context.SetThreadLocalDevicePlacementPolicy(
- static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
+ tensorflow::mutex_lock ml(ctx->policy_map_mu);
+ ctx->thread_local_policies[std::this_thread::get_id()] = policy;
}
// Note: this function looks up a thread local policy. So it should be called in
@@ -136,20 +168,25 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
- return static_cast<TFE_ContextDevicePlacementPolicy>(
- ctx->context.GetDevicePlacementPolicy());
+ tensorflow::mutex_lock ml(ctx->policy_map_mu);
+ auto policy_map_it =
+ ctx->thread_local_policies.find(std::this_thread::get_id());
+ if (policy_map_it != ctx->thread_local_policies.end()) {
+ return policy_map_it->second;
+ }
+ return ctx->policy;
}
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
- status->status = ctx->context.AsyncWait();
+ status->status = ctx->executor.WaitForAllPendingNodes();
}
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
- status->status = ctx->context.GetStatus();
+ status->status = ctx->executor.status();
}
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
- ctx->context.ClearAsyncError();
+ ctx->executor.ClearError();
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@@ -222,7 +259,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
// nullptr.
tensorflow::Device* src_opd = nullptr;
TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd));
- if (srcd == nullptr) srcd = ctx->context.HostCPU();
+ if (srcd == nullptr) srcd = ctx->devices[0];
bool is_same_device =
(srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
const bool dst_cpu = IsCPU(dstd);
@@ -295,7 +332,8 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
status->status = tensorflow::AttrTypeMapForOp(name, &types);
if (status->status.ok()) return new TFE_Op(ctx, name, types);
if (TF_GetCode(status) == TF_NOT_FOUND) {
- if (ctx->context.FindFunctionByName(name)) {
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ if (ctx->func_lib_def.Find(name) != nullptr) {
status->status = tensorflow::Status::OK();
return new TFE_Op(ctx, name, nullptr);
}
@@ -308,14 +346,20 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
tensorflow::Device* d = nullptr;
if (device_name != nullptr && strlen(device_name) > 0) {
- status->status = op->ctx->context.FindDeviceByName(device_name, &d);
+ auto it = op->ctx->devices_map.find(device_name);
+ if (it == op->ctx->devices_map.end()) {
+ status->status =
+ tensorflow::errors::InvalidArgument(device_name, " unknown device.");
+ return;
+ }
+ d = it->second;
}
op->device = d;
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device =
- (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device;
+ (op->device == nullptr) ? op->ctx->devices[0] : op->device;
return device->name().c_str();
}
@@ -590,7 +634,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
TFE_Context* ctx, TF_Status* status) {
tensorflow::DeviceSet ds;
- for (tensorflow::Device* d : *ctx->context.devices()) {
+ for (tensorflow::Device* d : ctx->devices) {
ds.AddDevice(d);
}
tensorflow::DeviceTypeVector final_devices;
@@ -604,7 +648,7 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
"Could not find valid device for node ", ndef.DebugString());
return nullptr;
}
- for (tensorflow::Device* d : *ctx->context.devices()) {
+ for (tensorflow::Device* d : ctx->devices) {
if (d->device_type() == final_devices[0].type_string()) {
return d;
}
@@ -619,8 +663,9 @@ tensorflow::Status Execute(
const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs,
tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
TFE_TensorHandle** retvals, int num_retvals) {
- if (!ctx->context.SoftPlacement() && device == nullptr) {
- device = ctx->context.HostCPU();
+ if (!ctx->soft_placement && device == nullptr) {
+ // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
+ device = ctx->devices[0];
}
if (device == nullptr) {
@@ -652,18 +697,18 @@ tensorflow::Status Execute(
if (maybe_stats != nullptr) {
maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
maybe_stats->all_start_micros());
- tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
- if (ctx->context.ShouldStoreMetadata()) {
- auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats();
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ if (ctx->should_store_metadata.load()) {
+ auto* step_stats = ctx->run_metadata.mutable_step_stats();
// Lazily initialize the RunMetadata with information about all devices if
// this is the first call.
- while (step_stats->dev_stats_size() < ctx->context.devices()->size()) {
+ while (step_stats->dev_stats_size() < ctx->devices.size()) {
step_stats->add_dev_stats();
}
// Find the current device's index.
int device_idx = 0;
- for (int i = 0; i < ctx->context.devices()->size(); ++i) {
- if (ctx->context.devices()->at(i) == device) {
+ for (int i = 0; i < ctx->devices.size(); ++i) {
+ if (ctx->devices[i] == device) {
device_idx = i;
break;
}
@@ -699,7 +744,7 @@ class ExecuteNode : public tensorflow::EagerNode {
tensorflow::NodeExecStats* maybe_stats,
const tensorflow::DataTypeVector& output_dtypes,
TFE_TensorHandle** retvals, int num_retvals)
- : tensorflow::EagerNode(op->ctx->context.NextId()),
+ : tensorflow::EagerNode(op->ctx->executor.NextId()),
ctx_(op->ctx),
op_device_(op->device),
inputs_(op->inputs),
@@ -755,7 +800,7 @@ class CopyToDeviceNode : public tensorflow::EagerNode {
public:
CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
TFE_Context* ctx)
- : tensorflow::EagerNode(ctx->context.NextId()),
+ : tensorflow::EagerNode(ctx->executor.NextId()),
src_(src),
dstd_(dstd),
ctx_(ctx),
@@ -1018,7 +1063,7 @@ extern "C" {
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
TFE_Context* ctx = op->ctx;
- status->status = ctx->context.GetStatus();
+ status->status = ctx->executor.status();
if (!status->status.ok()) {
return;
}
@@ -1042,7 +1087,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
input_op_device != op->device) {
tensorflow::Device* d =
- input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device;
+ input_op_device == nullptr ? ctx->devices[0] : input_op_device;
VLOG(1) << "Changing device of operation " << op->name << " to "
<< d->name() << " because input #" << i
<< " is a resource in this device.";
@@ -1050,35 +1095,40 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
}
tensorflow::Device* device = op->device;
- if (!ctx->context.SoftPlacement() && device == nullptr) {
- device = ctx->context.HostCPU();
+ if (!ctx->soft_placement && device == nullptr) {
+ // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
+ device = ctx->devices[0];
}
tensorflow::Fprint128 cache_key =
op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
- tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key);
+ tensorflow::KernelAndDevice* kernel;
+ {
+ tensorflow::tf_shared_lock l(ctx->cache_mu);
+ kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+ }
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
- if (ctx->context.SoftPlacement() && device == nullptr) {
+ if (ctx->soft_placement && device == nullptr) {
device = SelectDevice(ndef, ctx, status);
if (!status->status.ok()) {
return;
}
}
CHECK(device != nullptr);
- if (ctx->context.LogDevicePlacement()) {
+ if (ctx->log_device_placement) {
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
- kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous());
+ kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
// Knowledge of the implementation of Init (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
- tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu());
- status->status = tensorflow::KernelAndDevice::Init(
- ndef, ctx->context.func_lib(device), kernel);
+ tensorflow::tf_shared_lock l(ctx->functions_mu);
+ status->status =
+ tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
if (!status->status.ok()) {
delete kernel;
return;
@@ -1086,7 +1136,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// Update output_dtypes inside `kernel`.
const tensorflow::OpDef* op_def = nullptr;
const tensorflow::FunctionDef* function_def =
- ctx->context.FuncLibDef()->Find(ndef.op());
+ ctx->func_lib_def.Find(ndef.op());
if (function_def != nullptr) {
op_def = &(function_def->signature());
}
@@ -1102,7 +1152,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (!status->status.ok()) {
return;
}
- ctx->context.AddKernelToCache(cache_key, kernel);
+ tensorflow::mutex_lock ml(ctx->cache_mu);
+ tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
const int output_dtypes_size = output_dtypes.size();
@@ -1120,11 +1171,11 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// device from the one requested above.
device = kernel->device();
}
- status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(),
- device, op, kernel->kernel());
+ status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device,
+ op, kernel->kernel());
if (!status->status.ok()) return;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
- if (ctx->context.ShouldStoreMetadata()) {
+ if (ctx->should_store_metadata.load()) {
maybe_stats.reset(new tensorflow::NodeExecStats);
maybe_stats->set_node_name(op->name);
maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
@@ -1132,14 +1183,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
// TODO(apassos) track referenced tensors
}
- if (ctx->context.Async()) {
+ if (ctx->Async()) {
// Note that for async mode, execution order will make sure that all
// input handles are ready before executing them.
// TODO(agarwal): Consider executing "cheap" kernels inline for performance.
tensorflow::EagerNode* node =
new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes,
retvals, *num_retvals);
- ctx->context.ExecutorAdd(node);
+ ctx->executor.Add(node);
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
@@ -1155,24 +1206,23 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
- status->status = ctx->context.GetStatus();
+ status->status = ctx->executor.status();
if (!status->status.ok()) {
return nullptr;
}
- tensorflow::Device* dstd = ctx->context.HostCPU();
+ tensorflow::Device* dstd = ctx->devices[0];
if (device_name != nullptr && strlen(device_name) > 0) {
- status->status =
- ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
+ status->status = ctx->device_manager->LookupDevice(device_name, &dstd);
if (!status->status.ok()) return nullptr;
}
- if (ctx->context.Async()) {
+ if (ctx->Async()) {
// Note that `h` may not be currently ready. However execution order will
// make sure that `h` is ready before the copy is actually done.
CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
TFE_TensorHandle* output = node->dst();
// Note that calling Add makes `node` accessible by the EagerExecutor
// thread. So further accesses need to be thread-safe.
- ctx->context.ExecutorAdd(node);
+ ctx->executor.Add(node);
return output;
} else {
TFE_TensorHandle* output = nullptr;
@@ -1190,20 +1240,24 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
- status->status = ctx->context.AddFunctionDef(function_def);
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
- status->status = ctx->context.AddFunctionDef(function->fdef);
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
- ctx->context.SetShouldStoreMetadata(true);
+ ctx->should_store_metadata.store(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
- ctx->context.SetShouldStoreMetadata(false);
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ ctx->should_store_metadata.store(false);
+ ctx->run_metadata.Clear();
}
} // extern "C"
@@ -1232,9 +1286,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);
if (!status->status.ok()) return;
- tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
- status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
- ctx->context.RunMetadataProto()->Clear();
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ status->status = MessageToBuffer(ctx->run_metadata, buf);
+ ctx->run_metadata.Clear();
}
namespace {
@@ -1309,6 +1363,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} // namespace tensorflow
+bool TFE_Context::Async() const {
+ tensorflow::mutex_lock l(async_map_mu);
+ return tensorflow::gtl::FindWithDefault(
+ thread_local_async, std::this_thread::get_id(), async_default);
+}
bool TFE_TensorHandle::IsReady() {
if (node_id == 0) return true;
@@ -1322,7 +1381,7 @@ tensorflow::Status TFE_TensorHandle::WaitReady() {
{
tensorflow::mutex_lock l(ctx_mutex_);
if (ctx_ == nullptr) return tensorflow::Status::OK();
- executor = ctx_->context.Executor();
+ executor = &ctx_->executor;
}
return executor->WaitFor(node_id);
}
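The same pattern applies to the async entry points rewired above: TFE_ContextSetAsyncForThread now records the flag in a per-thread map and toggles ctx->executor, TFE_ContextAsyncWait maps to EagerExecutor::WaitForAllPendingNodes, and TFE_ContextAsyncClearError clears the executor's sticky error. A minimal usage sketch of those calls (again not part of this change; op construction elided, error handling trimmed):

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) return 1;

  // Switch this thread to async dispatch; enqueued ops run on the EagerExecutor.
  TFE_ContextSetAsyncForThread(ctx, /*async=*/1, status);

  // ... build and run ops with TFE_NewOp / TFE_Execute here ...

  // Drain all pending nodes, then surface and clear any deferred error.
  TFE_ContextAsyncWait(ctx, status);
  if (TF_GetCode(status) != TF_OK) {
    TFE_ContextAsyncClearError(ctx);
  }

  TFE_ContextSetAsyncForThread(ctx, /*async=*/0, status);  // back to sync mode
  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
  return 0;
}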