aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-01-04 14:48:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-04 14:51:51 -0800
commit19bbc31eee8b81bce6eb08b7ada539943fac6014 (patch)
tree079f8dcc4b87783f3c356011e5708e11a33814d4 /tensorflow
parentb639608a6da140f720636582022a575d7c8a7650 (diff)
Add `FunctionLibraryRuntime::InstantiateOptions` struct.
This new struct allows optional arguments to be passed to the `FunctionLibraryRuntime::Instantiate()` API. The new struct is now used to configure the target device for a function instantiation (instead of an attr). PiperOrigin-RevId: 180848930
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc7
-rw-r--r--tensorflow/core/common_runtime/function.cc51
-rw-r--r--tensorflow/core/common_runtime/function_test.cc16
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc25
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h6
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc59
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc24
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime.h9
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc48
-rw-r--r--tensorflow/core/framework/function.cc9
-rw-r--r--tensorflow/core/framework/function.h45
-rw-r--r--tensorflow/core/kernels/function_ops.cc12
-rw-r--r--tensorflow/python/debug/wrappers/framework.py6
-rw-r--r--tensorflow/python/debug/wrappers/framework_test.py4
14 files changed, 150 insertions, 171 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index c9a3537c70..742835c964 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -83,11 +83,8 @@ class FunctionBufferingResource : public ResourceBase {
return Status::OK();
}
AttrValueMap attr_values = func_.attr();
- AttrValue v;
- v.set_s(target_device_);
- AddAttr("_target", v, &attr_values);
-
- return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_);
+ return lib_->Instantiate(func_.name(), AttrSlice(&attr_values),
+ {target_device_}, &handle_);
}
// Returns true if we've got to the end of the sequence and exhausted the
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 51d7f98f72..286266a485 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -152,6 +152,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
~FunctionLibraryRuntimeImpl() override;
Status Instantiate(const string& function_name, AttrSlice attrs,
+ const InstantiateOptions& options,
Handle* handle) override;
Status ReleaseHandle(Handle handle) override;
@@ -223,7 +224,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
Status GetOrCreateItem(Handle handle, Item** item);
Status InstantiateSymbolicGradient(const NameAttrList& func,
FunctionBody** g_body);
- bool IsLocalTarget(const AttrSlice& attrs);
+ bool IsLocalTarget(const InstantiateOptions& options);
AttrValueMap FixAttrs(const AttrSlice& attrs);
void RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
@@ -352,7 +353,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
// Try to instantiate this function for the func/attr. Maybe it's
// cached already.
Handle handle;
- TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
+ TF_RETURN_IF_ERROR(
+ Instantiate(ndef.op(), AttrSlice(&ndef.attr()), {}, &handle));
const FunctionBody* fbody = GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
@@ -411,7 +413,7 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
// f is a user-defined function.
Handle f_handle;
TF_RETURN_IF_ERROR(
- Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle));
+ Instantiate(func.name(), AttrSlice(&func.attr()), {}, &f_handle));
const FunctionBody* f_body = GetFunctionBody(f_handle);
CHECK_NOTNULL(f_body);
*g_body = SymbolicGradient(*f_body);
@@ -419,42 +421,25 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
return Status::OK();
}
-bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) {
+bool FunctionLibraryRuntimeImpl::IsLocalTarget(
+ const InstantiateOptions& options) {
if (device_ == nullptr) return true;
- string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
- if (target.empty()) return true;
+ if (options.target.empty()) return true;
Device* target_device;
- if (!device_mgr_->LookupDevice(target, &target_device).ok()) {
+ if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
return false;
}
return target_device == device_;
}
-AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) {
- AttrValueMap value_map;
- for (auto it : attrs) {
- value_map[it.first] = it.second;
- }
- if (attrs.Find("_target") != nullptr) {
- return value_map;
- }
- AttrValue v;
- v.set_s(device_name_);
- AddAttr("_target", v, &value_map);
- return value_map;
-}
-
-Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
- AttrSlice attrs,
- Handle* handle) {
- AttrValueMap value_map = FixAttrs(attrs);
- AttrSlice new_attrs(&value_map);
-
- if (!IsLocalTarget(new_attrs)) {
- return parent_->Instantiate(function_name, new_attrs, handle);
+Status FunctionLibraryRuntimeImpl::Instantiate(
+ const string& function_name, AttrSlice attrs,
+ const InstantiateOptions& options, Handle* handle) {
+ if (!IsLocalTarget(options)) {
+ return parent_->Instantiate(function_name, attrs, options, handle);
}
- const string key = Canonicalize(function_name, new_attrs);
+ const string key = Canonicalize(function_name, attrs, options);
*handle = parent_->GetHandle(key);
if (*handle != kInvalidHandle) {
return Status::OK();
@@ -463,7 +448,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
Status s;
FunctionBody* fbody = nullptr;
if (function_name == kGradientOp) {
- const AttrValue* f = new_attrs.Find(kFuncAttr);
+ const AttrValue* f = attrs.Find(kFuncAttr);
if (f == nullptr) {
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
}
@@ -473,7 +458,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
}
const string grad = lib_def_->FindGradient(func.name());
if (!grad.empty()) {
- return Instantiate(grad, AttrSlice(&func.attr()), handle);
+ return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
}
TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody));
} else {
@@ -481,7 +466,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
if (fdef == nullptr) {
return errors::NotFound("Function ", function_name, " is not defined.");
}
- TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, new_attrs, &fbody));
+ TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody));
}
{
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index d4181ff48c..2dacacea7b 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -191,11 +191,14 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
FunctionLibraryRuntime::Handle* handle) {
- Status status = flr->Instantiate(name, attrs, handle);
- if (!status.ok()) {
- return status;
- }
- return Status::OK();
+ return flr->Instantiate(name, attrs, handle);
+ }
+
+ Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
+ test::function::Attrs attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
+ FunctionLibraryRuntime::Handle* handle) {
+ return flr->Instantiate(name, attrs, options, handle);
}
Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
@@ -1088,8 +1091,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Handle handle;
- TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}},
- &handle));
+ TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 53a14121d4..12947e284a 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -88,16 +88,6 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
std::move(custom_kernel_creator), nullptr /* cluster_flr */) {}
/* static */
-string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
- const AttrSlice& attrs) {
- const AttrValue* value;
- if (!attrs.Find("_target", &value).ok()) {
- return "";
- }
- return DeviceNameUtils::CanonicalizeDeviceName(value->s());
-}
-
-/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, int64 src_incarnation,
@@ -240,22 +230,23 @@ string ProcessFunctionLibraryRuntime::GetDeviceName(
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle) {
*handle = kInvalidHandle;
- string target = ObtainFunctionTarget(attrs);
- FunctionLibraryRuntime* flr = GetFLR(target);
+ FunctionLibraryRuntime* flr = GetFLR(options.target);
if (flr != nullptr) {
- return flr->Instantiate(function_name, attrs, handle);
+ return flr->Instantiate(function_name, attrs, options, handle);
}
if (parent_ == nullptr) {
return errors::Internal(
- "Currently don't support instantiating functions on device: ", target);
+ "Currently don't support instantiating functions on device: ",
+ options.target);
}
FunctionLibraryRuntime::Handle cluster_handle;
- TF_RETURN_IF_ERROR(
- parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle));
+ TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
+ options, &cluster_handle));
string function_key = Canonicalize(function_name, attrs);
- *handle = AddHandle(function_key, target, cluster_handle);
+ *handle = AddHandle(function_key, options.target, cluster_handle);
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 3aa7b87286..38003b7726 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -53,11 +53,6 @@ class ProcessFunctionLibraryRuntime {
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator);
- // Given a list of attrs on a function, extracts the "_target" attribute which
- // indicates which device to run the function on. If it can't find the _target
- // attribute, returns "". Canonicalizes the device name.
- static string ObtainFunctionTarget(const AttrSlice& attrs);
-
// Sends `tensors_to_send` from `source_device` to `target_device` using
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
// Rendezvous. `device_context` should be the DeviceContext of the device
@@ -121,6 +116,7 @@ class ProcessFunctionLibraryRuntime {
// Allows for function_name to be instantiated on different devices
// as specified in attrs.
Status Instantiate(const string& function_name, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle);
// Delegates to the local FLR that owns state corresponding to `handle` and
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 270e46dfe9..f11b7a851f 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -49,10 +49,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
- test::function::Attrs attrs, const std::vector<Tensor>& args,
- std::vector<Tensor*> rets) {
+ test::function::Attrs attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
- Status status = proc_flr_->Instantiate(name, attrs, &handle);
+ Status status =
+ proc_flr_->Instantiate(name, attrs, instantiate_opts, &handle);
if (!status.ok()) {
return status;
}
@@ -142,21 +144,6 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
rendezvous_->Unref();
}
-TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
- AttrSlice empty_attrs;
- string target =
- ProcessFunctionLibraryRuntime::ObtainFunctionTarget(empty_attrs);
- EXPECT_EQ("", target);
-
- AttrValueMap attr_values;
- AttrValue v;
- v.set_s("/job:a/replica:0/task:0/cpu:1");
- AddAttr("_target", v, &attr_values);
- AttrSlice attrs(&attr_values);
- target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
- EXPECT_EQ("/job:a/replica:0/task:0/device:CPU:1", target);
-}
-
TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) {
Init({});
int64 incarnation;
@@ -178,10 +165,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
opts.remote_execution = true;
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
- TF_CHECK_OK(
- Run("XTimesTwo", opts,
- {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
- {&y}));
+ TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
+ {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
rendezvous_->Unref();
}
@@ -193,8 +178,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("FindDevice", opts,
- {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
+ TF_CHECK_OK(
+ Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:0"}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
@@ -209,15 +194,11 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(
- Run("XTimesTwo", opts,
- {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
- {&y}));
+ TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
+ {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
- TF_CHECK_OK(
- Run("XTimesFour", opts,
- {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
- {&y}));
+ TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}},
+ {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
rendezvous_->Unref();
}
@@ -229,13 +210,13 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("FindDevice", opts,
- {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
+ TF_CHECK_OK(
+ Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
- TF_CHECK_OK(Run("FindDevice", opts,
- {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
+ TF_CHECK_OK(
+ Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
@@ -249,11 +230,13 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {},
+ {"/job:a/replica:0/task:0/device:CPU:0"}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
- TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {},
+ {"/job:a/replica:0/task:0/device:CPU:1"}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
index d84b69d06b..3a8d591236 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
@@ -26,10 +26,10 @@ namespace tensorflow {
/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
- const OpDef& sig, AttrSlice attrs, GraphDef* g,
+ const OpDef& sig, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
std::vector<string>* send_keys, std::vector<string>* recv_keys) {
- const string& target =
- ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+ const string& target = options.target;
// Construct recv nodes for each input argument.
int i = 0;
for (const auto& in : sig.input_arg()) {
@@ -119,16 +119,16 @@ ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
Status ClusterFunctionLibraryRuntime::Instantiate(
const string& function_name, const FunctionLibraryDefinition& lib_def,
- AttrSlice attrs, FunctionLibraryRuntime::LocalHandle* handle) {
- const string& target =
- ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
- WorkerInterface* wi = worker_session_->worker_cache->CreateWorker(target);
+ AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
+ FunctionLibraryRuntime::LocalHandle* handle) {
+ WorkerInterface* wi =
+ worker_session_->worker_cache->CreateWorker(options.target);
if (wi == nullptr) {
std::vector<string> workers;
worker_session_->worker_cache->ListWorkers(&workers);
return errors::InvalidArgument(
- "Could not find worker with target: ", target,
+ "Could not find worker with target: ", options.target,
" Available workers: ", str_util::Join(workers, ", "));
}
@@ -137,8 +137,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
const OpDef& sig = fdef->signature();
GraphDef gdef;
std::vector<string> send_keys, recv_keys;
- TF_RETURN_IF_ERROR(
- ConstructFunctionGraph(sig, attrs, &gdef, &send_keys, &recv_keys));
+ TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef,
+ &send_keys, &recv_keys));
*gdef.mutable_library() = lib_def.ToProto();
RegisterGraphRequest req;
@@ -152,8 +152,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
mutex_lock l(mu_);
*handle = function_data_.size();
- function_data_.push_back(
- FunctionData(resp.graph_handle(), target, wi, send_keys, recv_keys));
+ function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi,
+ send_keys, recv_keys));
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
index dd4ea68f57..3deb80dff7 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
@@ -34,6 +34,7 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle) override;
void Run(const FunctionLibraryRuntime::Options& opts,
@@ -42,10 +43,10 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
FunctionLibraryRuntime::DoneCallback done) override;
private:
- static Status ConstructFunctionGraph(const OpDef& sig, AttrSlice attrs,
- GraphDef* g,
- std::vector<string>* send_keys,
- std::vector<string>* recv_keys);
+ static Status ConstructFunctionGraph(
+ const OpDef& sig, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
+ std::vector<string>* send_keys, std::vector<string>* recv_keys);
friend class ClusterFunctionLibraryRuntimeTest;
mutable mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
index 6dd8b9ec73..98512bce18 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
@@ -47,30 +47,31 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
new ClusterFunctionLibraryRuntime(worker_session_.get()));
}
- Status ConstructFunctionGraphHelper(const OpDef& sig,
- test::function::Attrs attrs, GraphDef* g,
- std::vector<string>* send_keys,
- std::vector<string>* recv_keys) {
+ Status ConstructFunctionGraphHelper(
+ const OpDef& sig, test::function::Attrs attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
+ std::vector<string>* send_keys, std::vector<string>* recv_keys) {
return ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
- sig, attrs, g, send_keys, recv_keys);
+ sig, attrs, options, g, send_keys, recv_keys);
}
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def,
test::function::Attrs attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* local_handle) {
- return cluster_flr_->Instantiate(function_name, lib_def, attrs,
+ return cluster_flr_->Instantiate(function_name, lib_def, attrs, options,
local_handle);
}
- Status InstantiateAndRun(const string& function_name,
- const FunctionLibraryDefinition& lib_def,
- test::function::Attrs attrs,
- const std::vector<Tensor>& args,
- std::vector<Tensor*> rets) {
+ Status InstantiateAndRun(
+ const string& function_name, const FunctionLibraryDefinition& lib_def,
+ test::function::Attrs attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
FunctionLibraryRuntime::LocalHandle handle;
- TF_RETURN_IF_ERROR(
- cluster_flr_->Instantiate(function_name, lib_def, attrs, &handle));
+ TF_RETURN_IF_ERROR(cluster_flr_->Instantiate(function_name, lib_def, attrs,
+ options, &handle));
Notification done;
FunctionLibraryRuntime::Options opts;
@@ -103,9 +104,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) {
GraphDef actual;
std::vector<string> send_keys, recv_keys;
TF_CHECK_OK(ConstructFunctionGraphHelper(
- test::function::Swap().signature(),
- {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, &actual,
- &send_keys, &recv_keys));
+ test::function::Swap().signature(), {{"T", DT_FLOAT}},
+ {"/job:a/replica:0/task:0/device:CPU:0"}, &actual, &send_keys,
+ &recv_keys));
GraphDef expected;
protobuf::TextFormat::ParseFromString(R"(
node {
@@ -205,7 +206,7 @@ node {
attr {
key: "_target"
value {
- s: "/job:a/replica:0/task:0/cpu:0"
+ s: "/job:a/replica:0/task:0/device:CPU:0"
}
}
}
@@ -309,9 +310,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, DISABLED_InstantiateAndRun) {
Tensor y;
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- TF_EXPECT_OK(InstantiateAndRun(
- "XTimesTwoInt32", lib_def,
- {{"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, {x}, {&y}));
+ TF_EXPECT_OK(InstantiateAndRun("XTimesTwoInt32", lib_def, {},
+ {"/job:localhost/replica:0/task:1/cpu:0"}, {x},
+ {&y}));
test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({2, 4, 6, 8}));
}
@@ -324,10 +325,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest,
Tensor y1, y2;
auto x1 = test::AsTensor<float>({1, 2, 3, 4});
auto x2 = test::AsTensor<float>({4, 3, 2, 1});
- TF_EXPECT_OK(InstantiateAndRun(
- "Swap", lib_def,
- {{"T", DT_FLOAT}, {"_target", "/job:localhost/replica:0/task:1/cpu:0"}},
- {x1, x2}, {&y1, &y2}));
+ TF_EXPECT_OK(InstantiateAndRun("Swap", lib_def, {{"T", DT_FLOAT}},
+ {"/job:localhost/replica:0/task:1/cpu:0"},
+ {x1, x2}, {&y1, &y2}));
test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({4, 3, 2, 1}));
test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2, 3, 4}));
}
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index d757e962e5..783015481a 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -795,12 +795,17 @@ uint64 FunctionDefHash(const FunctionDef& fdef) {
return h;
}
-string Canonicalize(const string& funcname, AttrSlice attrs) {
+string Canonicalize(const string& funcname, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options) {
std::vector<string> entries;
- entries.reserve(attrs.size());
+ entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1));
for (auto p : attrs) {
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
}
+ if (!options.target.empty()) {
+ entries.push_back(
+ strings::StrCat("_target", "=", str_util::CEscape(options.target)));
+ }
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 1a579ab631..e5d0e49dbb 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -234,15 +234,6 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
// same.
uint64 FunctionDefHash(const FunctionDef& fdef);
-// Returns a canonicalized string for the instantiation of the
-// function of the given "name" and attributes "attrs".
-//
-// The returned string is guaranteed to be stable within one address
-// space. But it may be change as the implementation
-// evolves. Therefore, it should not be persisted or compared across
-// address spaces.
-string Canonicalize(const string& funcname, AttrSlice attrs);
-
class CallFrameInterface {
public:
virtual ~CallFrameInterface() {}
@@ -418,9 +409,23 @@ class FunctionLibraryRuntime {
//
// Returns OK and fills in "handle" if the instantiation succeeds.
// Otherwise returns an error and "handle" is undefined.
+ struct InstantiateOptions {
+ // The canonical device name of the device on which the function
+ // should be instantiated. If empty, the function will be
+ // instantiated on the local device.
+ string target;
+
+ // TODO(b/70352992): Add an API for allowing a different
+ // FunctionLibraryDefinition to be overlaid on this runtime's library.
+ };
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
+ const InstantiateOptions& options,
Handle* handle) = 0;
+ Status Instantiate(const string& function_name, AttrSlice attrs,
+ Handle* handle) {
+ return Instantiate(function_name, attrs, {}, handle);
+ }
// Releases state associated with the handle.
virtual Status ReleaseHandle(Handle handle) = 0;
@@ -502,6 +507,19 @@ class FunctionLibraryRuntime {
typedef uint64 LocalHandle;
};
+// Returns a canonicalized string for the instantiation of the
+// function of the given "name", attributes "attrs", and "options".
+//
+// The returned string is guaranteed to be stable within one address
+// space. But it may be change as the implementation
+// evolves. Therefore, it should not be persisted or compared across
+// address spaces.
+string Canonicalize(const string& funcname, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options);
+inline string Canonicalize(const string& funcname, AttrSlice attrs) {
+ return Canonicalize(funcname, attrs, {});
+}
+
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
@@ -514,10 +532,11 @@ class DistributedFunctionLibraryRuntime {
virtual ~DistributedFunctionLibraryRuntime() {}
// The _target attr in attrs determines where the function is instantiated.
- virtual Status Instantiate(const string& function_name,
- const FunctionLibraryDefinition& lib_def,
- AttrSlice attrs,
- FunctionLibraryRuntime::LocalHandle* handle) = 0;
+ virtual Status Instantiate(
+ const string& function_name, const FunctionLibraryDefinition& lib_def,
+ AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
+ FunctionLibraryRuntime::LocalHandle* handle) = 0;
// opts.runner isn't used for execution.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index f469f41e06..facac10f66 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -296,21 +296,19 @@ class RemoteCallOp : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
const Tensor* target;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
- AttrValueMap attr_values = func_.attr();
- AttrValue v;
const string& target_device =
DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()());
- v.set_s(target_device);
- AddAttr("_target", v, &attr_values);
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library is provided."),
done);
+ AttrValueMap attr_values = func_.attr();
FunctionLibraryRuntime::Handle handle;
- OP_REQUIRES_OK_ASYNC(
- ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle),
- done);
+ OP_REQUIRES_OK_ASYNC(ctx,
+ lib->Instantiate(func_.name(), AttrSlice(&attr_values),
+ {target_device}, &handle),
+ done);
OpInputList arguments;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 909150eb6a..00c38eac10 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -154,7 +154,8 @@ class OnSessionInitRequest(object):
sess: A tensorflow Session object.
"""
- _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
+ _check_type(sess, (session.SessionInterface,
+ monitored_session.MonitoredSession))
self.session = sess
@@ -358,7 +359,8 @@ class BaseDebugWrapperSession(session.SessionInterface):
NotImplementedError: If a non-DirectSession sess object is received.
"""
- _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
+ _check_type(sess, (session.SessionInterface,
+ monitored_session.MonitoredSession))
# The session being wrapped.
self._sess = sess
diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py
index 73e08ce7d5..5240e0d0ad 100644
--- a/tensorflow/python/debug/wrappers/framework_test.py
+++ b/tensorflow/python/debug/wrappers/framework_test.py
@@ -271,9 +271,9 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testSessionInitInvalidSessionType(self):
"""Attempt to wrap a non-Session-type object should cause an exception."""
- wrapper = TestDebugWrapperSessionBadAction(self._sess)
+ sess = "not a session"
with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
- TestDebugWrapperSessionBadAction(wrapper)
+ TestDebugWrapperSessionBadAction(sess)
def testSessionInitBadActionValue(self):
with self.assertRaisesRegexp(