aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2017-08-17 11:33:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-17 11:36:50 -0700
commit935ff49201edd7a6297b313fb9545d1299b9a28d (patch)
tree36486015014d33efa99d7fd0875eb1545bd518cb /tensorflow/core/common_runtime/function.cc
parentd94dca2174f0c05dfa03796c3ae31d345813d025 (diff)
Automated g4 rollback of changelist 165521057
PiperOrigin-RevId: 165604864
Diffstat (limited to 'tensorflow/core/common_runtime/function.cc')
-rw-r--r--tensorflow/core/common_runtime/function.cc83
1 files changed, 30 insertions, 53 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 829fba780f..6b529d8f13 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -139,14 +139,15 @@ static Node* AddRet(Graph* g, Endpoint input, int index) {
return ret;
}
+static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
+
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
public:
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
- CustomKernelCreator custom_kernel_creator,
- ProcessFunctionLibraryRuntime* parent);
+ CustomKernelCreator custom_kernel_creator);
~FunctionLibraryRuntimeImpl() override;
@@ -183,13 +184,17 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const FunctionLibraryDefinition* const lib_def_;
GraphOptimizer optimizer_;
const CustomKernelCreator custom_kernel_creator_;
- const string device_name_;
std::function<Status(const string&, const OpDef**)> get_func_sig_;
std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
mutable mutex mu_;
+ // Maps function instantiation to a handle. The key is a
+ // canonicalized representation of the function name and
+ // instantiation attrs. The handle is an index into the items_.
+ std::unordered_map<string, Handle> table_ GUARDED_BY(mu_);
+
// func_graphs_ never shrinks or reorders its members.
std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_);
@@ -203,15 +208,12 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
};
std::vector<Item*> items_;
- ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned.
-
Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
FunctionBody** fbody);
Status CreateItem(Handle handle, Item** item);
Status GetOrCreateItem(Handle handle, Item** item);
Status InstantiateSymbolicGradient(const NameAttrList& func,
FunctionBody** g_body);
- bool IsLocalTarget(const AttrSlice& attrs);
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
@@ -220,19 +222,14 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
- CustomKernelCreator custom_kernel_creator,
- ProcessFunctionLibraryRuntime* parent)
+ CustomKernelCreator custom_kernel_creator)
: device_mgr_(dmgr),
device_(device),
env_(env),
graph_def_version_(graph_def_version),
lib_def_(lib_def),
optimizer_(optimizer_options),
- custom_kernel_creator_(std::move(custom_kernel_creator)),
- device_name_(device_ == nullptr
- ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
- : device_->name()),
- parent_(parent) {
+ custom_kernel_creator_(std::move(custom_kernel_creator)) {
get_func_sig_ = [this](const string& op, const OpDef** sig) {
return lib_def_->LookUpOpDef(op, sig);
};
@@ -297,17 +294,10 @@ class CallOp : public AsyncOpKernel {
};
const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
- LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
- if (local_handle == kInvalidLocalHandle) {
- LOG(ERROR) << "Could not find Handle: " << h
- << " on device: " << device_name_;
- return nullptr;
- }
-
mutex_lock l(mu_);
- CHECK_LE(0, local_handle);
- CHECK_LT(local_handle, func_graphs_.size());
- return func_graphs_[local_handle];
+ CHECK_LE(static_cast<Handle>(0), h);
+ CHECK_LT(h, func_graphs_.size());
+ return func_graphs_[h];
}
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
@@ -403,24 +393,16 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
return Status::OK();
}
-bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) {
- if (device_ == nullptr) return true;
- string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
- if (target.empty()) return true;
- return target == device_->name();
-}
-
Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
AttrSlice attrs,
Handle* handle) {
- if (!IsLocalTarget(attrs)) {
- return parent_->Instantiate(function_name, attrs, handle);
- }
-
const string key = Canonicalize(function_name, attrs);
- *handle = parent_->GetHandle(key);
- if (*handle != kInvalidHandle) {
- return Status::OK();
+ {
+ mutex_lock l(mu_);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
+ if (*handle != kInvalidHandle) {
+ return Status::OK();
+ }
}
Status s;
@@ -449,11 +431,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
{
mutex_lock l(mu_);
- *handle = parent_->GetHandle(key);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
if (*handle != kInvalidHandle) {
delete fbody;
} else {
- *handle = parent_->AddHandle(key, device_name_, func_graphs_.size());
+ *handle = func_graphs_.size();
+ table_.insert({key, *handle});
func_graphs_.push_back(fbody);
items_.resize(func_graphs_.size());
}
@@ -511,14 +494,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
}
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
- LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
{
mutex_lock l(mu_);
- if (local_handle >= items_.size()) {
+ if (handle >= items_.size()) {
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
- *item = items_[local_handle];
+ *item = items_[handle];
if (*item != nullptr) {
(*item)->Ref();
return Status::OK();
@@ -530,9 +512,9 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
{
mutex_lock l(mu_);
- if (items_[local_handle] == nullptr) {
+ if (items_[handle] == nullptr) {
// Install *item in items_.
- items_[local_handle] = *item;
+ items_[handle] = *item;
(*item)->Ref();
}
}
@@ -546,9 +528,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
return done(errors::Cancelled(""));
}
- if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
- return parent_->Run(opts, handle, args, rets, done);
- }
const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame* frame =
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
@@ -637,21 +616,19 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
- CustomKernelCreator custom_kernel_creator,
- ProcessFunctionLibraryRuntime* parent) {
+ CustomKernelCreator custom_kernel_creator) {
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
- std::move(custom_kernel_creator), parent));
+ std::move(custom_kernel_creator)));
}
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
- const OptimizerOptions& optimizer_options,
- ProcessFunctionLibraryRuntime* parent) {
+ const OptimizerOptions& optimizer_options) {
return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
lib_def, optimizer_options,
- GetCustomCreatorSingleton()->Get(), parent);
+ GetCustomCreatorSingleton()->Get());
}
bool RemoveDeadNodes(Graph* g) {