diff options
author | 2018-03-21 23:11:40 -0700 | |
---|---|---|
committer | 2018-03-21 23:14:11 -0700 | |
commit | f83711104b64a108ac43213c92f13827343d09ef (patch) | |
tree | 07258ae0bdc8ec46af93111ed0b530dab3959021 /tensorflow/c/eager/c_api.cc | |
parent | 0e1775355f9d7fe5301bc0d17906453caf970e27 (diff) |
Automated g4 rollback of changelist 190001737
PiperOrigin-RevId: 190021164
Diffstat (limited to 'tensorflow/c/eager/c_api.cc')
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 179 |
1 files changed, 119 insertions, 60 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5d668848ab..a23015c99e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -71,6 +71,18 @@ std::atomic_int_fast64_t func_id_generator(0); } // namespace +TFE_ContextDevicePlacementPolicy PlacementPolicy( + bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy) { + if (!soft_placement) { + return original_policy; + } + if (original_policy == TFE_DEVICE_PLACEMENT_EXPLICIT || + original_policy == TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) { + return TFE_DEVICE_PLACEMENT_SILENT; + } + return original_policy; +} + extern "C" { TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } @@ -92,7 +104,19 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(async); + { + tensorflow::mutex_lock l(ctx->async_map_mu); + ctx->thread_local_async[std::this_thread::get_id()] = async; + } + if (async) { + ctx->executor.EnableAsync(); + } else { + // TODO(agarwal): Currently we add a wait here to handle cases where a sync + // op has a control dependency on an async op, and the latter has not + // executed yet. This wait can be removed by storing all the control inputs + // and waiting for them when executing ops. + status->status = ctx->executor.WaitForAllPendingNodes(); + } } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -109,26 +133,34 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::DeviceMgr(devices)); tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - return new TFE_Context(opts->session_options.options, opts->policy, - opts->async, std::move(device_mgr), r); + return new TFE_Context(*opts, std::move(device_mgr), r); } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->executor.WaitForAllPendingNodes(); + { + tensorflow::mutex_lock ml(ctx->cache_mu); + tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); + } + ctx->rendezvous->Unref(); delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.device_mgr()->ListDeviceAttributes(&list->response); + ctx->device_manager->ListDeviceAttributes(&list->response); return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +void TFE_ContextClearCaches(TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->cache_mu); + tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); +} void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - ctx->context.SetThreadLocalDevicePlacementPolicy( - static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); + tensorflow::mutex_lock ml(ctx->policy_map_mu); + ctx->thread_local_policies[std::this_thread::get_id()] = policy; } // Note: this function looks up a thread local policy. So it should be called in @@ -136,20 +168,25 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - return static_cast<TFE_ContextDevicePlacementPolicy>( - ctx->context.GetDevicePlacementPolicy()); + tensorflow::mutex_lock ml(ctx->policy_map_mu); + auto policy_map_it = + ctx->thread_local_policies.find(std::this_thread::get_id()); + if (policy_map_it != ctx->thread_local_policies.end()) { + return policy_map_it->second; + } + return ctx->policy; } void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.AsyncWait(); + status->status = ctx->executor.WaitForAllPendingNodes(); } void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.GetStatus(); + status->status = ctx->executor.status(); } void TFE_ContextAsyncClearError(TFE_Context* ctx) { - ctx->context.ClearAsyncError(); + ctx->executor.ClearError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { @@ -222,7 +259,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, // nullptr. tensorflow::Device* src_opd = nullptr; TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd)); - if (srcd == nullptr) srcd = ctx->context.HostCPU(); + if (srcd == nullptr) srcd = ctx->devices[0]; bool is_same_device = (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); const bool dst_cpu = IsCPU(dstd); @@ -295,7 +332,8 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, status->status = tensorflow::AttrTypeMapForOp(name, &types); if (status->status.ok()) return new TFE_Op(ctx, name, types); if (TF_GetCode(status) == TF_NOT_FOUND) { - if (ctx->context.FindFunctionByName(name)) { + tensorflow::mutex_lock l(ctx->functions_mu); + if (ctx->func_lib_def.Find(name) != nullptr) { status->status = tensorflow::Status::OK(); return new TFE_Op(ctx, name, nullptr); } @@ -308,14 +346,20 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { - status->status = op->ctx->context.FindDeviceByName(device_name, &d); + auto it = op->ctx->devices_map.find(device_name); + if (it == op->ctx->devices_map.end()) { + status->status = + tensorflow::errors::InvalidArgument(device_name, " unknown device."); + return; + } + d = it->second; } op->device = d; } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { tensorflow::Device* device = - (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device; + (op->device == nullptr) ? op->ctx->devices[0] : op->device; return device->name().c_str(); } @@ -590,7 +634,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, TFE_Context* ctx, TF_Status* status) { tensorflow::DeviceSet ds; - for (tensorflow::Device* d : *ctx->context.devices()) { + for (tensorflow::Device* d : ctx->devices) { ds.AddDevice(d); } tensorflow::DeviceTypeVector final_devices; @@ -604,7 +648,7 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, "Could not find valid device for node ", ndef.DebugString()); return nullptr; } - for (tensorflow::Device* d : *ctx->context.devices()) { + for (tensorflow::Device* d : ctx->devices) { if (d->device_type() == final_devices[0].type_string()) { return d; } @@ -619,8 +663,9 @@ tensorflow::Status Execute( const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs, tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, TFE_TensorHandle** retvals, int num_retvals) { - if (!ctx->context.SoftPlacement() && device == nullptr) { - device = ctx->context.HostCPU(); + if (!ctx->soft_placement && device == nullptr) { + // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU + device = ctx->devices[0]; } if (device == nullptr) { @@ -652,18 +697,18 @@ tensorflow::Status Execute( if (maybe_stats != nullptr) { maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); - if (ctx->context.ShouldStoreMetadata()) { - auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats(); + tensorflow::mutex_lock ml(ctx->metadata_mu); + if (ctx->should_store_metadata.load()) { + auto* step_stats = ctx->run_metadata.mutable_step_stats(); // Lazily initialize the RunMetadata with information about all devices if // this is the first call. - while (step_stats->dev_stats_size() < ctx->context.devices()->size()) { + while (step_stats->dev_stats_size() < ctx->devices.size()) { step_stats->add_dev_stats(); } // Find the current device's index. int device_idx = 0; - for (int i = 0; i < ctx->context.devices()->size(); ++i) { - if (ctx->context.devices()->at(i) == device) { + for (int i = 0; i < ctx->devices.size(); ++i) { + if (ctx->devices[i] == device) { device_idx = i; break; } @@ -699,7 +744,7 @@ class ExecuteNode : public tensorflow::EagerNode { tensorflow::NodeExecStats* maybe_stats, const tensorflow::DataTypeVector& output_dtypes, TFE_TensorHandle** retvals, int num_retvals) - : tensorflow::EagerNode(op->ctx->context.NextId()), + : tensorflow::EagerNode(op->ctx->executor.NextId()), ctx_(op->ctx), op_device_(op->device), inputs_(op->inputs), @@ -755,7 +800,7 @@ class CopyToDeviceNode : public tensorflow::EagerNode { public: CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd, TFE_Context* ctx) - : tensorflow::EagerNode(ctx->context.NextId()), + : tensorflow::EagerNode(ctx->executor.NextId()), src_(src), dstd_(dstd), ctx_(ctx), @@ -1018,7 +1063,7 @@ extern "C" { void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { TFE_Context* ctx = op->ctx; - status->status = ctx->context.GetStatus(); + status->status = ctx->executor.status(); if (!status->status.ok()) { return; } @@ -1042,7 +1087,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && input_op_device != op->device) { tensorflow::Device* d = - input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device; + input_op_device == nullptr ? ctx->devices[0] : input_op_device; VLOG(1) << "Changing device of operation " << op->name << " to " << d->name() << " because input #" << i << " is a resource in this device."; @@ -1050,35 +1095,40 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } } tensorflow::Device* device = op->device; - if (!ctx->context.SoftPlacement() && device == nullptr) { - device = ctx->context.HostCPU(); + if (!ctx->soft_placement && device == nullptr) { + // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU + device = ctx->devices[0]; } tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); - tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key); + tensorflow::KernelAndDevice* kernel; + { + tensorflow::tf_shared_lock l(ctx->cache_mu); + kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + } if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - if (ctx->context.SoftPlacement() && device == nullptr) { + if (ctx->soft_placement && device == nullptr) { device = SelectDevice(ndef, ctx, status); if (!status->status.ok()) { return; } } CHECK(device != nullptr); - if (ctx->context.LogDevicePlacement()) { + if (ctx->log_device_placement) { LOG(INFO) << "Executing op " << ndef.op() << " in device " << device->name(); } - kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous()); + kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); // Knowledge of the implementation of Init (and in-turn // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def // will be accessed, so grab on to the lock. // See WARNING comment in Execute (before kernel->Run) - would be nice to // rework to avoid this subtlety. - tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu()); - status->status = tensorflow::KernelAndDevice::Init( - ndef, ctx->context.func_lib(device), kernel); + tensorflow::tf_shared_lock l(ctx->functions_mu); + status->status = + tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); if (!status->status.ok()) { delete kernel; return; @@ -1086,7 +1136,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // Update output_dtypes inside `kernel`. const tensorflow::OpDef* op_def = nullptr; const tensorflow::FunctionDef* function_def = - ctx->context.FuncLibDef()->Find(ndef.op()); + ctx->func_lib_def.Find(ndef.op()); if (function_def != nullptr) { op_def = &(function_def->signature()); } @@ -1102,7 +1152,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, if (!status->status.ok()) { return; } - ctx->context.AddKernelToCache(cache_key, kernel); + tensorflow::mutex_lock ml(ctx->cache_mu); + tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); const int output_dtypes_size = output_dtypes.size(); @@ -1120,11 +1171,11 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // device from the one requested above. device = kernel->device(); } - status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(), - device, op, kernel->kernel()); + status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device, + op, kernel->kernel()); if (!status->status.ok()) return; std::unique_ptr<tensorflow::NodeExecStats> maybe_stats; - if (ctx->context.ShouldStoreMetadata()) { + if (ctx->should_store_metadata.load()) { maybe_stats.reset(new tensorflow::NodeExecStats); maybe_stats->set_node_name(op->name); maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); @@ -1132,14 +1183,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); // TODO(apassos) track referenced tensors } - if (ctx->context.Async()) { + if (ctx->Async()) { // Note that for async mode, execution order will make sure that all // input handles are ready before executing them. // TODO(agarwal): Consider executing "cheap" kernels inline for performance. tensorflow::EagerNode* node = new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes, retvals, *num_retvals); - ctx->context.ExecutorAdd(node); + ctx->executor.Add(node); } else { // Execute checks if retvals[i] is nullptr or not to figure if it needs to // allocate it. @@ -1155,24 +1206,23 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - status->status = ctx->context.GetStatus(); + status->status = ctx->executor.status(); if (!status->status.ok()) { return nullptr; } - tensorflow::Device* dstd = ctx->context.HostCPU(); + tensorflow::Device* dstd = ctx->devices[0]; if (device_name != nullptr && strlen(device_name) > 0) { - status->status = - ctx->context.device_mgr()->LookupDevice(device_name, &dstd); + status->status = ctx->device_manager->LookupDevice(device_name, &dstd); if (!status->status.ok()) return nullptr; } - if (ctx->context.Async()) { + if (ctx->Async()) { // Note that `h` may not be currently ready. However execution order will // make sure that `h` is ready before the copy is actually done. CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); TFE_TensorHandle* output = node->dst(); // Note that calling Add makes `node` accessible by the EagerExecutor // thread. So further accesses need to be thread-safe. - ctx->context.ExecutorAdd(node); + ctx->executor.Add(node); return output; } else { TFE_TensorHandle* output = nullptr; @@ -1190,20 +1240,24 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - status->status = ctx->context.AddFunctionDef(function_def); + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = ctx->context.AddFunctionDef(function->fdef); + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(true); + ctx->should_store_metadata.store(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(false); + tensorflow::mutex_lock ml(ctx->metadata_mu); + ctx->should_store_metadata.store(false); + ctx->run_metadata.Clear(); } } // extern "C" @@ -1232,9 +1286,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); - status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.RunMetadataProto()->Clear(); + tensorflow::mutex_lock ml(ctx->metadata_mu); + status->status = MessageToBuffer(ctx->run_metadata, buf); + ctx->run_metadata.Clear(); } namespace { @@ -1309,6 +1363,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } // namespace tensorflow +bool TFE_Context::Async() const { + tensorflow::mutex_lock l(async_map_mu); + return tensorflow::gtl::FindWithDefault( + thread_local_async, std::this_thread::get_id(), async_default); +} bool TFE_TensorHandle::IsReady() { if (node_id == 0) return true; @@ -1322,7 +1381,7 @@ tensorflow::Status TFE_TensorHandle::WaitReady() { { tensorflow::mutex_lock l(ctx_mutex_); if (ctx_ == nullptr) return tensorflow::Status::OK(); - executor = ctx_->context.Executor(); + executor = &ctx_->executor; } return executor->WaitFor(node_id); } |