diff options
-rw-r--r-- | tensorflow/c/eager/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 194 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_internal.h | 84 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/BUILD | 22 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.cc | 153 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 198 |
6 files changed, 450 insertions, 203 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 841ff48a38..bea5a121b3 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -28,6 +28,7 @@ tf_cuda_library( "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core:core_cpu_internal", @@ -64,6 +65,7 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:kernel_and_device", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index a23015c99e..2402a6d044 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -71,18 +71,6 @@ std::atomic_int_fast64_t func_id_generator(0); } // namespace -TFE_ContextDevicePlacementPolicy PlacementPolicy( - bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy) { - if (!soft_placement) { - return original_policy; - } - if (original_policy == TFE_DEVICE_PLACEMENT_EXPLICIT || - original_policy == TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) { - return TFE_DEVICE_PLACEMENT_SILENT; - } - return original_policy; -} - extern "C" { TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } @@ -104,19 +92,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { - { - tensorflow::mutex_lock l(ctx->async_map_mu); - ctx->thread_local_async[std::this_thread::get_id()] = async; - } - if (async) { - ctx->executor.EnableAsync(); - } else { - // TODO(agarwal): Currently we add a wait here to handle cases where a sync - // op has a control dependency on an async op, and the latter has not - // executed yet. This wait can be removed by storing all the control inputs - // and waiting for them when executing ops. - status->status = ctx->executor.WaitForAllPendingNodes(); - } + status->status = ctx->context.SetAsyncForThread(async); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -133,34 +109,26 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::DeviceMgr(devices)); tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - return new TFE_Context(*opts, std::move(device_mgr), r); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, std::move(device_mgr), r); } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->executor.WaitForAllPendingNodes(); - { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); - } - ctx->rendezvous->Unref(); delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->device_manager->ListDeviceAttributes(&list->response); + ctx->context.device_mgr()->ListDeviceAttributes(&list->response); return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); -} +void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - ctx->thread_local_policies[std::this_thread::get_id()] = policy; + ctx->context.SetThreadLocalDevicePlacementPolicy( + static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); } // Note: this function looks up a thread local policy. So it should be called in @@ -168,25 +136,20 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - auto policy_map_it = - ctx->thread_local_policies.find(std::this_thread::get_id()); - if (policy_map_it != ctx->thread_local_policies.end()) { - return policy_map_it->second; - } - return ctx->policy; + return static_cast<TFE_ContextDevicePlacementPolicy>( + ctx->context.GetDevicePlacementPolicy()); } void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->executor.WaitForAllPendingNodes(); + status->status = ctx->context.AsyncWait(); } void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->executor.status(); + status->status = ctx->context.GetStatus(); } void TFE_ContextAsyncClearError(TFE_Context* ctx) { - ctx->executor.ClearError(); + ctx->context.ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { @@ -259,7 +222,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, // nullptr. tensorflow::Device* src_opd = nullptr; TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd)); - if (srcd == nullptr) srcd = ctx->devices[0]; + if (srcd == nullptr) srcd = ctx->context.HostCPU(); bool is_same_device = (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); const bool dst_cpu = IsCPU(dstd); @@ -332,8 +295,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, status->status = tensorflow::AttrTypeMapForOp(name, &types); if (status->status.ok()) return new TFE_Op(ctx, name, types); if (TF_GetCode(status) == TF_NOT_FOUND) { - tensorflow::mutex_lock l(ctx->functions_mu); - if (ctx->func_lib_def.Find(name) != nullptr) { + if (ctx->context.FindFunctionByName(name)) { status->status = tensorflow::Status::OK(); return new TFE_Op(ctx, name, nullptr); } @@ -346,20 +308,14 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { - auto it = op->ctx->devices_map.find(device_name); - if (it == op->ctx->devices_map.end()) { - status->status = - tensorflow::errors::InvalidArgument(device_name, " unknown device."); - return; - } - d = it->second; + status->status = op->ctx->context.FindDeviceByName(device_name, &d); } op->device = d; } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { tensorflow::Device* device = - (op->device == nullptr) ? op->ctx->devices[0] : op->device; + (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device; return device->name().c_str(); } @@ -634,7 +590,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, TFE_Context* ctx, TF_Status* status) { tensorflow::DeviceSet ds; - for (tensorflow::Device* d : ctx->devices) { + for (tensorflow::Device* d : *ctx->context.devices()) { ds.AddDevice(d); } tensorflow::DeviceTypeVector final_devices; @@ -648,7 +604,7 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, "Could not find valid device for node ", ndef.DebugString()); return nullptr; } - for (tensorflow::Device* d : ctx->devices) { + for (tensorflow::Device* d : *ctx->context.devices()) { if (d->device_type() == final_devices[0].type_string()) { return d; } @@ -663,9 +619,8 @@ tensorflow::Status Execute( const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs, tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, TFE_TensorHandle** retvals, int num_retvals) { - if (!ctx->soft_placement && device == nullptr) { - // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU - device = ctx->devices[0]; + if (!ctx->context.SoftPlacement() && device == nullptr) { + device = ctx->context.HostCPU(); } if (device == nullptr) { @@ -684,8 +639,8 @@ tensorflow::Status Execute( inputs[i] = *input_tensor; } // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, - // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation + // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def. + // But knowledge of the implementation // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. // This is quite subtle. Re-work things to make this better? (Would it make @@ -697,18 +652,18 @@ tensorflow::Status Execute( if (maybe_stats != nullptr) { maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(ctx->metadata_mu); - if (ctx->should_store_metadata.load()) { - auto* step_stats = ctx->run_metadata.mutable_step_stats(); + tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); + if (ctx->context.ShouldStoreMetadata()) { + auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats(); // Lazily initialize the RunMetadata with information about all devices if // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices.size()) { + while (step_stats->dev_stats_size() < ctx->context.devices()->size()) { step_stats->add_dev_stats(); } // Find the current device's index. int device_idx = 0; - for (int i = 0; i < ctx->devices.size(); ++i) { - if (ctx->devices[i] == device) { + for (int i = 0; i < ctx->context.devices()->size(); ++i) { + if (ctx->context.devices()->at(i) == device) { device_idx = i; break; } @@ -744,7 +699,7 @@ class ExecuteNode : public tensorflow::EagerNode { tensorflow::NodeExecStats* maybe_stats, const tensorflow::DataTypeVector& output_dtypes, TFE_TensorHandle** retvals, int num_retvals) - : tensorflow::EagerNode(op->ctx->executor.NextId()), + : tensorflow::EagerNode(op->ctx->context.NextId()), ctx_(op->ctx), op_device_(op->device), inputs_(op->inputs), @@ -800,7 +755,7 @@ class CopyToDeviceNode : public tensorflow::EagerNode { public: CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd, TFE_Context* ctx) - : tensorflow::EagerNode(ctx->executor.NextId()), + : tensorflow::EagerNode(ctx->context.NextId()), src_(src), dstd_(dstd), ctx_(ctx), @@ -866,8 +821,7 @@ const tensorflow::FunctionDef* OpToFunction( TFE_Context* ctx = op->ctx; const tensorflow::OpRegistrationData* op_data; { - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.LookUp(op->name, &op_data); + status->status = ctx->context.FindFunctionOpData(op->name, &op_data); if (!status->status.ok()) { return nullptr; } @@ -963,10 +917,9 @@ const tensorflow::FunctionDef* OpToFunction( } VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(fdef); + ctx->context.AddFunctionDef(fdef); if (!status->status.ok()) return nullptr; - const auto ret = ctx->func_lib_def.Find(signature->name()); + const auto ret = ctx->context.FindFunctionDef(signature->name()); DCHECK(ret != nullptr); return ret; } @@ -985,8 +938,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) { const tensorflow::FunctionDef* fdef; { - tensorflow::tf_shared_lock l(op->ctx->functions_mu); - fdef = op->ctx->func_lib_def.Find(op->name); + fdef = op->ctx->context.FindFunctionDef(op->name); } std::vector<TF_DataType> const_input_types; std::vector<TF_DataType> arg_input_types; @@ -1063,7 +1015,7 @@ extern "C" { void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { TFE_Context* ctx = op->ctx; - status->status = ctx->executor.status(); + status->status = ctx->context.GetStatus(); if (!status->status.ok()) { return; } @@ -1087,7 +1039,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && input_op_device != op->device) { tensorflow::Device* d = - input_op_device == nullptr ? ctx->devices[0] : input_op_device; + input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device; VLOG(1) << "Changing device of operation " << op->name << " to " << d->name() << " because input #" << i << " is a resource in this device."; @@ -1095,40 +1047,35 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } } tensorflow::Device* device = op->device; - if (!ctx->soft_placement && device == nullptr) { - // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU - device = ctx->devices[0]; + if (!ctx->context.SoftPlacement() && device == nullptr) { + device = ctx->context.HostCPU(); } tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); - tensorflow::KernelAndDevice* kernel; - { - tensorflow::tf_shared_lock l(ctx->cache_mu); - kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); - } + tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key); if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - if (ctx->soft_placement && device == nullptr) { + if (ctx->context.SoftPlacement() && device == nullptr) { device = SelectDevice(ndef, ctx, status); if (!status->status.ok()) { return; } } CHECK(device != nullptr); - if (ctx->log_device_placement) { + if (ctx->context.LogDevicePlacement()) { LOG(INFO) << "Executing op " << ndef.op() << " in device " << device->name(); } - kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); + kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous()); // Knowledge of the implementation of Init (and in-turn // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def // will be accessed, so grab on to the lock. // See WARNING comment in Execute (before kernel->Run) - would be nice to // rework to avoid this subtlety. - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = - tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); + tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu()); + status->status = tensorflow::KernelAndDevice::Init( + ndef, ctx->context.func_lib(device), kernel); if (!status->status.ok()) { delete kernel; return; @@ -1136,7 +1083,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // Update output_dtypes inside `kernel`. const tensorflow::OpDef* op_def = nullptr; const tensorflow::FunctionDef* function_def = - ctx->func_lib_def.Find(ndef.op()); + ctx->context.FuncLibDef()->Find(ndef.op()); if (function_def != nullptr) { op_def = &(function_def->signature()); } @@ -1152,8 +1099,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, if (!status->status.ok()) { return; } - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); + ctx->context.AddKernelToCache(cache_key, kernel); } const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); const int output_dtypes_size = output_dtypes.size(); @@ -1171,11 +1117,11 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // device from the one requested above. device = kernel->device(); } - status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device, - op, kernel->kernel()); + status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(), + device, op, kernel->kernel()); if (!status->status.ok()) return; std::unique_ptr<tensorflow::NodeExecStats> maybe_stats; - if (ctx->should_store_metadata.load()) { + if (ctx->context.ShouldStoreMetadata()) { maybe_stats.reset(new tensorflow::NodeExecStats); maybe_stats->set_node_name(op->name); maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); @@ -1183,14 +1129,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); // TODO(apassos) track referenced tensors } - if (ctx->Async()) { + if (ctx->context.Async()) { // Note that for async mode, execution order will make sure that all // input handles are ready before executing them. // TODO(agarwal): Consider executing "cheap" kernels inline for performance. tensorflow::EagerNode* node = new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes, retvals, *num_retvals); - ctx->executor.Add(node); + ctx->context.ExecutorAdd(node); } else { // Execute checks if retvals[i] is nullptr or not to figure if it needs to // allocate it. @@ -1206,23 +1152,24 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - status->status = ctx->executor.status(); + status->status = ctx->context.GetStatus(); if (!status->status.ok()) { return nullptr; } - tensorflow::Device* dstd = ctx->devices[0]; + tensorflow::Device* dstd = ctx->context.HostCPU(); if (device_name != nullptr && strlen(device_name) > 0) { - status->status = ctx->device_manager->LookupDevice(device_name, &dstd); + status->status = + ctx->context.device_mgr()->LookupDevice(device_name, &dstd); if (!status->status.ok()) return nullptr; } - if (ctx->Async()) { + if (ctx->context.Async()) { // Note that `h` may not be currently ready. However execution order will // make sure that `h` is ready before the copy is actually done. CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); TFE_TensorHandle* output = node->dst(); // Note that calling Add makes `node` accessible by the EagerExecutor // thread. So further accesses need to be thread-safe. - ctx->executor.Add(node); + ctx->context.ExecutorAdd(node); return output; } else { TFE_TensorHandle* output = nullptr; @@ -1240,24 +1187,20 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function_def); + status->status = ctx->context.AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); + status->status = ctx->context.AddFunctionDef(function->fdef); } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->should_store_metadata.store(true); + ctx->context.SetShouldStoreMetadata(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - ctx->should_store_metadata.store(false); - ctx->run_metadata.Clear(); + ctx->context.SetShouldStoreMetadata(false); } } // extern "C" @@ -1286,9 +1229,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(ctx->metadata_mu); - status->status = MessageToBuffer(ctx->run_metadata, buf); - ctx->run_metadata.Clear(); + tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); + status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); + ctx->context.RunMetadataProto()->Clear(); } namespace { @@ -1363,11 +1306,6 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } // namespace tensorflow -bool TFE_Context::Async() const { - tensorflow::mutex_lock l(async_map_mu); - return tensorflow::gtl::FindWithDefault( - thread_local_async, std::this_thread::get_id(), async_default); -} bool TFE_TensorHandle::IsReady() { if (node_id == 0) return true; @@ -1381,7 +1319,7 @@ tensorflow::Status TFE_TensorHandle::WaitReady() { { tensorflow::mutex_lock l(ctx_mutex_); if (ctx_ == nullptr) return tensorflow::Status::OK(); - executor = &ctx_->executor; + executor = ctx_->context.Executor(); } return executor->WaitFor(node_id); } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a79f8ddd33..5b29120b40 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/function.h" @@ -52,85 +53,18 @@ struct TFE_ContextOptions { TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; }; -TFE_ContextDevicePlacementPolicy PlacementPolicy( - bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy); - struct TFE_Context { - explicit TFE_Context(const TFE_ContextOptions& opts, + explicit TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, + bool async, std::unique_ptr<tensorflow::DeviceMgr> device_mgr, tensorflow::Rendezvous* rendezvous) - : soft_placement( - opts.session_options.options.config.allow_soft_placement()), - policy(PlacementPolicy(soft_placement, opts.policy)), - device_manager(std::move(device_mgr)), - devices(device_manager->ListDevices()), - rendezvous(rendezvous), - pflr(new tensorflow::ProcessFunctionLibraryRuntime( - device_manager.get(), opts.session_options.options.env, - TF_GRAPH_DEF_VERSION, &func_lib_def, {})), - log_device_placement( - opts.session_options.options.config.log_device_placement()), - async_default(opts.async) { - if (async_default) executor.EnableAsync(); - - for (auto* device : devices) { - devices_map[tensorflow::StringPiece(device->name())] = device; - } - } - - const bool soft_placement; - const TFE_ContextDevicePlacementPolicy policy; - - // Note: we cannot use C++11 thread_local here as there is no concept of a - // thread-local-object-local variable in C++11. - tensorflow::mutex policy_map_mu; - std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy> - thread_local_policies GUARDED_BY(policy_map_mu); - - std::unique_ptr<tensorflow::DeviceMgr> device_manager; - // Devices owned by device_manager - std::vector<tensorflow::Device*> devices; - // All devices are not owned. - tensorflow::gtl::FlatMap<tensorflow::StringPiece, tensorflow::Device*, - tensorflow::StringPieceHasher> - devices_map; - tensorflow::Rendezvous* const rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr; - - tensorflow::mutex cache_mu; - std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*, - tensorflow::Fprint128Hasher> - kernel_cache GUARDED_BY(cache_mu); - - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { - return pflr->GetFLR(d->name()); - } + : context(opts, + static_cast<tensorflow::ContextDevicePlacementPolicy>( + default_policy), + async, std::move(device_mgr), rendezvous) {} - // Whether we should compute RunMetadata. - std::atomic<bool> should_store_metadata{false}; - tensorflow::mutex metadata_mu; - tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); - const bool log_device_placement; - // EagerExecutor for async execution. - tensorflow::EagerExecutor executor; - - // True if running in asynchronous mode. - bool Async() const; - - // True if the default value for execution mode is async. Note that this value - // can be overridden per thread based on `thread_local_async` overrides. - const bool async_default; - mutable tensorflow::mutex async_map_mu; - std::unordered_map<std::thread::id, bool> thread_local_async - GUARDED_BY(async_map_mu); + tensorflow::EagerContext context; }; struct TFE_TensorHandle : public tensorflow::core::RefCounted { diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 8ba560bef8..de10b10b7e 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -33,6 +33,28 @@ tf_cuda_library( ) tf_cuda_library( + name = "context", + srcs = [ + "context.cc", + ], + hdrs = [ + "context.h", + ], + visibility = ["//tensorflow:internal"], + deps = [ + ":eager_executor", + ":kernel_and_device", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + ], +) + +tf_cuda_library( name = "kernel_and_device", srcs = [ "kernel_and_device.cc", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc new file mode 100644 index 0000000000..0566329f18 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/eager/context.h" + +namespace tensorflow { + +ContextDevicePlacementPolicy PlacementPolicy( + bool soft_placement, ContextDevicePlacementPolicy original_policy) { + if (!soft_placement) { + return original_policy; + } + if (original_policy == DEVICE_PLACEMENT_EXPLICIT || + original_policy == DEVICE_PLACEMENT_SILENT_FOR_INT32) { + return DEVICE_PLACEMENT_SILENT; + } + return original_policy; +} + +EagerContext::EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, + bool async, std::unique_ptr<DeviceMgr> device_mgr, + Rendezvous* rendezvous) + : soft_placement_(opts.config.allow_soft_placement()), + policy_(PlacementPolicy(soft_placement_, default_policy)), + device_manager_(std::move(device_mgr)), + devices_(device_manager_->ListDevices()), + rendezvous_(rendezvous), + pflr_(new ProcessFunctionLibraryRuntime(device_manager_.get(), opts.env, + TF_GRAPH_DEF_VERSION, + &func_lib_def_, {})), + log_device_placement_(opts.config.log_device_placement()), + async_default_(async) { + if (async_default_) { + executor_.EnableAsync(); + } + + for (auto* device : devices_) { + devices_map_[device->name()] = device; + } +} + +bool EagerContext::Async() const { + mutex_lock l(async_map_mu_); + return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(), + async_default_); +} + +Status EagerContext::SetAsyncForThread(bool async) { + { + tensorflow::mutex_lock l(async_map_mu_); + thread_local_async_[std::this_thread::get_id()] = async; + } + if (async) { + executor_.EnableAsync(); + } else { + // TODO(agarwal): Currently we add a wait here to handle cases where a + // sync op has a control dependency on an async op, and the latter has not + // executed yet. This wait can be removed by storing all the control + // inputs and waiting for them when executing ops. + return executor_.WaitForAllPendingNodes(); + } + return Status::OK(); +} + +void EagerContext::ClearCaches() { + mutex_lock ml(cache_mu_); + gtl::STLDeleteValues(&kernel_cache_); +} + +void EagerContext::SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) { + mutex_lock ml(policy_map_mu_); + thread_local_policies_[std::this_thread::get_id()] = policy; +} + +ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() { + mutex_lock ml(policy_map_mu_); + auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id()); + if (policy_map_it != thread_local_policies_.end()) { + return policy_map_it->second; + } + return policy_; +} + +EagerContext::~EagerContext() { + executor_.WaitForAllPendingNodes().IgnoreError(); + ClearCaches(); + rendezvous_->Unref(); +} + +bool EagerContext::FindFunctionByName(const string& name) { + mutex_lock l(functions_mu_); + return func_lib_def_.Find(name) != nullptr; +} + +Status EagerContext::FindFunctionOpData( + const string& name, const tensorflow::OpRegistrationData** op_data) { + mutex_lock l(functions_mu_); + return func_lib_def_.LookUp(name, op_data); +} + +const FunctionDef* EagerContext::FindFunctionDef(const string& name) { + mutex_lock l(functions_mu_); + return func_lib_def_.Find(name); +} + +Status EagerContext::FindDeviceByName(const string& name, Device** result) { + auto it = devices_map_.find(name); + if (it == devices_map_.end()) { + return errors::InvalidArgument(name, " unknown device."); + } + *result = it->second; + return Status::OK(); +} + +Status EagerContext::AddFunctionDef(const FunctionDef& fdef) { + mutex_lock l(functions_mu_); + return func_lib_def_.AddFunctionDef(fdef); +} + +KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) { + tf_shared_lock l(cache_mu_); + return gtl::FindPtrOrNull(kernel_cache_, cache_key); +} + +void EagerContext::AddKernelToCache(Fprint128 cache_key, + KernelAndDevice* kernel) { + mutex_lock ml(cache_mu_); + gtl::InsertOrUpdate(&kernel_cache_, cache_key, kernel); +} + +void EagerContext::SetShouldStoreMetadata(bool value) { + should_store_metadata_.store(value); + if (!value) { + mutex_lock ml(metadata_mu_); + run_metadata_.Clear(); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h new file mode 100644 index 0000000000..bc97219dae --- /dev/null +++ b/tensorflow/core/common_runtime/eager/context.h @@ -0,0 +1,198 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ + +#include <algorithm> +#include <cstddef> +#include <map> +#include <memory> +#include <queue> +#include <string> +#include <vector> + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +// Note: there's a copy enum in eager/c_api.h. It should be kept in sync. +enum ContextDevicePlacementPolicy { + // Running operations with input tensors on the wrong device will fail. When + // soft placement is enabled acts like TFE_DEVICE_PLACEMENT_SILENT. + DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the + // operation will be blocked till the copy completes. + DEVICE_PLACEMENT_SILENT = 2, + // Default placement policy which silently copies int32 tensors but not other + // dtypes. When soft placement is enabled acts like + // TFE_DEVICE_PLACEMENT_SILENT. + DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, +}; + +ContextDevicePlacementPolicy PlacementPolicy( + bool soft_placement, ContextDevicePlacementPolicy original_policy); + +class EagerContext { + public: + explicit EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + std::unique_ptr<DeviceMgr> device_mgr, + Rendezvous* rendezvous); + + ~EagerContext(); + + // Returns the function library runtime for the given device. + FunctionLibraryRuntime* func_lib(Device* d) const { + return pflr_->GetFLR(d->name()); + } + + // True if running in asynchronous mode. + bool Async() const; + + EagerExecutor* Executor() { return &executor_; } + + // Sets whether this thread should run in synchronous or asynchronous mode. + Status SetAsyncForThread(bool async); + + // TODO(apassos) make this return a constant reference + gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() { + return &devices_map_; + } + + // TODO(apassos) make this return a constant reference + std::vector<Device*>* devices() { return &devices_; } + + // Clears the kernel caches. + void ClearCaches(); + + // Sets the device placement policy for the current thread. + void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy); + + // Returns the device placement policy for the current thread. + ContextDevicePlacementPolicy GetDevicePlacementPolicy(); + + Status AsyncWait() { return executor_.WaitForAllPendingNodes(); } + + Status GetStatus() { return executor_.status(); } + + void ClearAsyncError() { executor_.ClearError(); } + + bool FindFunctionByName(const string& name); + + Status FindFunctionOpData(const string& name, + const tensorflow::OpRegistrationData** op_data); + + const FunctionDef* FindFunctionDef(const string& name); + + Status FindDeviceByName(const string& name, Device** result); + + Device* HostCPU() { return devices_[0]; } + + bool SoftPlacement() { return soft_placement_; } + + uint64 NextId() { return executor_.NextId(); } + + void ExecutorAdd(EagerNode* node) { executor_.Add(node); } + + Status AddFunctionDef(const FunctionDef& fdef); + + KernelAndDevice* GetCachedKernel(Fprint128 cache_key); + + void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); + + bool LogDevicePlacement() { return log_device_placement_; } + + Rendezvous* GetRendezvous() { return rendezvous_; } + + mutex* FunctionsMu() { return &functions_mu_; } + + tensorflow::DeviceMgr* device_mgr() { return device_manager_.get(); } + + // TODO(apassos) remove the need for this + void ReleaseDeviceMgr() { device_manager_.release(); } + + // TODO(apassos) clean up RunMetadata storage. + mutex* MetadataMu() { return &metadata_mu_; } + bool ShouldStoreMetadata() { return should_store_metadata_.load(); } + void SetShouldStoreMetadata(bool value); + RunMetadata* RunMetadataProto() { return &run_metadata_; } + + FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; } + + private: + const bool soft_placement_; + const ContextDevicePlacementPolicy policy_; + + // Note: we cannot use C++11 thread_local here as there is no concept of a + // thread-local-object-local variable in C++11. + mutex policy_map_mu_; + std::unordered_map<std::thread::id, ContextDevicePlacementPolicy> + thread_local_policies_ GUARDED_BY(policy_map_mu_); + + std::unique_ptr<DeviceMgr> device_manager_; + // Devices owned by device_manager + std::vector<Device*> devices_; + // All devices are not owned. + gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_; + Rendezvous* const rendezvous_; + + mutex functions_mu_; + FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){ + OpRegistry::Global(), {}}; + + // One FunctionLibraryRuntime per device. + // func_libs[i] is the FunctionLibraryRuntime corresponding to + // session->devices[i]. + const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + + mutex cache_mu_; + std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_ + GUARDED_BY(cache_mu_); + + // Whether we should compute RunMetadata. + std::atomic<bool> should_store_metadata_{false}; + mutex metadata_mu_; + RunMetadata run_metadata_ GUARDED_BY(metadata_mu_); + const bool log_device_placement_; + // EagerExecutor for async execution. + EagerExecutor executor_; + + // True if the default value for execution mode is async. Note that this value + // can be overridden per thread based on `thread_local_async` overrides. + const bool async_default_; + mutable mutex async_map_mu_; + std::unordered_map<std::thread::id, bool> thread_local_async_ + GUARDED_BY(async_map_mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ |