aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2018-01-05 14:55:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-05 14:58:51 -0800
commitcc94f3888eb008fd78a2d95e30ecf76fbce0e0af (patch)
treecc126be226c801b8fc374b5ff4050deace0c098d
parentc152d4bbac33d9dc10f7278226aaeb42fce799b5 (diff)
Automated g4 rollback of changelist 180848930
PiperOrigin-RevId: 180979141
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc7
-rw-r--r--tensorflow/core/common_runtime/function.cc51
-rw-r--r--tensorflow/core/common_runtime/function_test.cc16
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc25
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h6
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc59
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc24
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime.h9
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc48
-rw-r--r--tensorflow/core/framework/function.cc9
-rw-r--r--tensorflow/core/framework/function.h45
-rw-r--r--tensorflow/core/kernels/function_ops.cc12
-rw-r--r--tensorflow/python/debug/wrappers/framework.py6
-rw-r--r--tensorflow/python/debug/wrappers/framework_test.py4
14 files changed, 171 insertions, 150 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 742835c964..c9a3537c70 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -83,8 +83,11 @@ class FunctionBufferingResource : public ResourceBase {
return Status::OK();
}
AttrValueMap attr_values = func_.attr();
- return lib_->Instantiate(func_.name(), AttrSlice(&attr_values),
- {target_device_}, &handle_);
+ AttrValue v;
+ v.set_s(target_device_);
+ AddAttr("_target", v, &attr_values);
+
+ return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_);
}
// Returns true if we've got to the end of the sequence and exhausted the
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 286266a485..51d7f98f72 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -152,7 +152,6 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
~FunctionLibraryRuntimeImpl() override;
Status Instantiate(const string& function_name, AttrSlice attrs,
- const InstantiateOptions& options,
Handle* handle) override;
Status ReleaseHandle(Handle handle) override;
@@ -224,7 +223,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
Status GetOrCreateItem(Handle handle, Item** item);
Status InstantiateSymbolicGradient(const NameAttrList& func,
FunctionBody** g_body);
- bool IsLocalTarget(const InstantiateOptions& options);
+ bool IsLocalTarget(const AttrSlice& attrs);
AttrValueMap FixAttrs(const AttrSlice& attrs);
void RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
@@ -353,8 +352,7 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
// Try to instantiate this function for the func/attr. Maybe it's
// cached already.
Handle handle;
- TF_RETURN_IF_ERROR(
- Instantiate(ndef.op(), AttrSlice(&ndef.attr()), {}, &handle));
+ TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
const FunctionBody* fbody = GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
@@ -413,7 +411,7 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
// f is a user-defined function.
Handle f_handle;
TF_RETURN_IF_ERROR(
- Instantiate(func.name(), AttrSlice(&func.attr()), {}, &f_handle));
+ Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle));
const FunctionBody* f_body = GetFunctionBody(f_handle);
CHECK_NOTNULL(f_body);
*g_body = SymbolicGradient(*f_body);
@@ -421,25 +419,42 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
return Status::OK();
}
-bool FunctionLibraryRuntimeImpl::IsLocalTarget(
- const InstantiateOptions& options) {
+bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) {
if (device_ == nullptr) return true;
- if (options.target.empty()) return true;
+ string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+ if (target.empty()) return true;
Device* target_device;
- if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
+ if (!device_mgr_->LookupDevice(target, &target_device).ok()) {
return false;
}
return target_device == device_;
}
-Status FunctionLibraryRuntimeImpl::Instantiate(
- const string& function_name, AttrSlice attrs,
- const InstantiateOptions& options, Handle* handle) {
- if (!IsLocalTarget(options)) {
- return parent_->Instantiate(function_name, attrs, options, handle);
+AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) {
+ AttrValueMap value_map;
+ for (auto it : attrs) {
+ value_map[it.first] = it.second;
+ }
+ if (attrs.Find("_target") != nullptr) {
+ return value_map;
+ }
+ AttrValue v;
+ v.set_s(device_name_);
+ AddAttr("_target", v, &value_map);
+ return value_map;
+}
+
+Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
+ AttrSlice attrs,
+ Handle* handle) {
+ AttrValueMap value_map = FixAttrs(attrs);
+ AttrSlice new_attrs(&value_map);
+
+ if (!IsLocalTarget(new_attrs)) {
+ return parent_->Instantiate(function_name, new_attrs, handle);
}
- const string key = Canonicalize(function_name, attrs, options);
+ const string key = Canonicalize(function_name, new_attrs);
*handle = parent_->GetHandle(key);
if (*handle != kInvalidHandle) {
return Status::OK();
@@ -448,7 +463,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
Status s;
FunctionBody* fbody = nullptr;
if (function_name == kGradientOp) {
- const AttrValue* f = attrs.Find(kFuncAttr);
+ const AttrValue* f = new_attrs.Find(kFuncAttr);
if (f == nullptr) {
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
}
@@ -458,7 +473,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
}
const string grad = lib_def_->FindGradient(func.name());
if (!grad.empty()) {
- return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
+ return Instantiate(grad, AttrSlice(&func.attr()), handle);
}
TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody));
} else {
@@ -466,7 +481,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
if (fdef == nullptr) {
return errors::NotFound("Function ", function_name, " is not defined.");
}
- TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody));
+ TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, new_attrs, &fbody));
}
{
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 2dacacea7b..d4181ff48c 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -191,14 +191,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
FunctionLibraryRuntime::Handle* handle) {
- return flr->Instantiate(name, attrs, handle);
- }
-
- Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
- test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
- FunctionLibraryRuntime::Handle* handle) {
- return flr->Instantiate(name, attrs, options, handle);
+ Status status = flr->Instantiate(name, attrs, handle);
+ if (!status.ok()) {
+ return status;
+ }
+ return Status::OK();
}
Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
@@ -1091,7 +1088,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Handle handle;
- TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle));
+ TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}},
+ &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 12947e284a..53a14121d4 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -88,6 +88,16 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
std::move(custom_kernel_creator), nullptr /* cluster_flr */) {}
/* static */
+string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
+ const AttrSlice& attrs) {
+ const AttrValue* value;
+ if (!attrs.Find("_target", &value).ok()) {
+ return "";
+ }
+ return DeviceNameUtils::CanonicalizeDeviceName(value->s());
+}
+
+/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, int64 src_incarnation,
@@ -230,23 +240,22 @@ string ProcessFunctionLibraryRuntime::GetDeviceName(
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle) {
*handle = kInvalidHandle;
- FunctionLibraryRuntime* flr = GetFLR(options.target);
+ string target = ObtainFunctionTarget(attrs);
+ FunctionLibraryRuntime* flr = GetFLR(target);
if (flr != nullptr) {
- return flr->Instantiate(function_name, attrs, options, handle);
+ return flr->Instantiate(function_name, attrs, handle);
}
if (parent_ == nullptr) {
return errors::Internal(
- "Currently don't support instantiating functions on device: ",
- options.target);
+ "Currently don't support instantiating functions on device: ", target);
}
FunctionLibraryRuntime::Handle cluster_handle;
- TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
- options, &cluster_handle));
+ TF_RETURN_IF_ERROR(
+ parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle));
string function_key = Canonicalize(function_name, attrs);
- *handle = AddHandle(function_key, options.target, cluster_handle);
+ *handle = AddHandle(function_key, target, cluster_handle);
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 38003b7726..3aa7b87286 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -53,6 +53,11 @@ class ProcessFunctionLibraryRuntime {
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator);
+ // Given a list of attrs on a function, extracts the "_target" attribute which
+ // indicates which device to run the function on. If it can't find the _target
+ // attribute, returns "". Canonicalizes the device name.
+ static string ObtainFunctionTarget(const AttrSlice& attrs);
+
// Sends `tensors_to_send` from `source_device` to `target_device` using
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
// Rendezvous. `device_context` should be the DeviceContext of the device
@@ -116,7 +121,6 @@ class ProcessFunctionLibraryRuntime {
// Allows for function_name to be instantiated on different devices
// as specified in attrs.
Status Instantiate(const string& function_name, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle);
// Delegates to the local FLR that owns state corresponding to `handle` and
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index f11b7a851f..270e46dfe9 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -49,12 +49,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
- test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ test::function::Attrs attrs, const std::vector<Tensor>& args,
+ std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
- Status status =
- proc_flr_->Instantiate(name, attrs, instantiate_opts, &handle);
+ Status status = proc_flr_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
return status;
}
@@ -144,6 +142,21 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
rendezvous_->Unref();
}
+TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
+ AttrSlice empty_attrs;
+ string target =
+ ProcessFunctionLibraryRuntime::ObtainFunctionTarget(empty_attrs);
+ EXPECT_EQ("", target);
+
+ AttrValueMap attr_values;
+ AttrValue v;
+ v.set_s("/job:a/replica:0/task:0/cpu:1");
+ AddAttr("_target", v, &attr_values);
+ AttrSlice attrs(&attr_values);
+ target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+ EXPECT_EQ("/job:a/replica:0/task:0/device:CPU:1", target);
+}
+
TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) {
Init({});
int64 incarnation;
@@ -165,8 +178,10 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
opts.remote_execution = true;
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
- TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+ TF_CHECK_OK(
+ Run("XTimesTwo", opts,
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+ {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
rendezvous_->Unref();
}
@@ -178,8 +193,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(
- Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:0"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts,
+ {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
@@ -194,11 +209,15 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+ TF_CHECK_OK(
+ Run("XTimesTwo", opts,
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+ {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
- TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+ TF_CHECK_OK(
+ Run("XTimesFour", opts,
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+ {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
rendezvous_->Unref();
}
@@ -210,13 +229,13 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(
- Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts,
+ {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
- TF_CHECK_OK(
- Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts,
+ {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
@@ -230,13 +249,11 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("FindDevice", opts, {},
- {"/job:a/replica:0/task:0/device:CPU:0"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
- TF_CHECK_OK(Run("FindDevice", opts, {},
- {"/job:a/replica:0/task:0/device:CPU:1"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
index 3a8d591236..d84b69d06b 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
@@ -26,10 +26,10 @@ namespace tensorflow {
/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
- const OpDef& sig, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
+ const OpDef& sig, AttrSlice attrs, GraphDef* g,
std::vector<string>* send_keys, std::vector<string>* recv_keys) {
- const string& target = options.target;
+ const string& target =
+ ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
// Construct recv nodes for each input argument.
int i = 0;
for (const auto& in : sig.input_arg()) {
@@ -119,16 +119,16 @@ ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
Status ClusterFunctionLibraryRuntime::Instantiate(
const string& function_name, const FunctionLibraryDefinition& lib_def,
- AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
- FunctionLibraryRuntime::LocalHandle* handle) {
- WorkerInterface* wi =
- worker_session_->worker_cache->CreateWorker(options.target);
+ AttrSlice attrs, FunctionLibraryRuntime::LocalHandle* handle) {
+ const string& target =
+ ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+ WorkerInterface* wi = worker_session_->worker_cache->CreateWorker(target);
if (wi == nullptr) {
std::vector<string> workers;
worker_session_->worker_cache->ListWorkers(&workers);
return errors::InvalidArgument(
- "Could not find worker with target: ", options.target,
+ "Could not find worker with target: ", target,
" Available workers: ", str_util::Join(workers, ", "));
}
@@ -137,8 +137,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
const OpDef& sig = fdef->signature();
GraphDef gdef;
std::vector<string> send_keys, recv_keys;
- TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef,
- &send_keys, &recv_keys));
+ TF_RETURN_IF_ERROR(
+ ConstructFunctionGraph(sig, attrs, &gdef, &send_keys, &recv_keys));
*gdef.mutable_library() = lib_def.ToProto();
RegisterGraphRequest req;
@@ -152,8 +152,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
mutex_lock l(mu_);
*handle = function_data_.size();
- function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi,
- send_keys, recv_keys));
+ function_data_.push_back(
+ FunctionData(resp.graph_handle(), target, wi, send_keys, recv_keys));
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
index 3deb80dff7..dd4ea68f57 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
@@ -34,7 +34,6 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle) override;
void Run(const FunctionLibraryRuntime::Options& opts,
@@ -43,10 +42,10 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
FunctionLibraryRuntime::DoneCallback done) override;
private:
- static Status ConstructFunctionGraph(
- const OpDef& sig, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
- std::vector<string>* send_keys, std::vector<string>* recv_keys);
+ static Status ConstructFunctionGraph(const OpDef& sig, AttrSlice attrs,
+ GraphDef* g,
+ std::vector<string>* send_keys,
+ std::vector<string>* recv_keys);
friend class ClusterFunctionLibraryRuntimeTest;
mutable mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
index 98512bce18..6dd8b9ec73 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
@@ -47,31 +47,30 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
new ClusterFunctionLibraryRuntime(worker_session_.get()));
}
- Status ConstructFunctionGraphHelper(
- const OpDef& sig, test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
- std::vector<string>* send_keys, std::vector<string>* recv_keys) {
+ Status ConstructFunctionGraphHelper(const OpDef& sig,
+ test::function::Attrs attrs, GraphDef* g,
+ std::vector<string>* send_keys,
+ std::vector<string>* recv_keys) {
return ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
- sig, attrs, options, g, send_keys, recv_keys);
+ sig, attrs, g, send_keys, recv_keys);
}
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def,
test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* local_handle) {
- return cluster_flr_->Instantiate(function_name, lib_def, attrs, options,
+ return cluster_flr_->Instantiate(function_name, lib_def, attrs,
local_handle);
}
- Status InstantiateAndRun(
- const string& function_name, const FunctionLibraryDefinition& lib_def,
- test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ Status InstantiateAndRun(const string& function_name,
+ const FunctionLibraryDefinition& lib_def,
+ test::function::Attrs attrs,
+ const std::vector<Tensor>& args,
+ std::vector<Tensor*> rets) {
FunctionLibraryRuntime::LocalHandle handle;
- TF_RETURN_IF_ERROR(cluster_flr_->Instantiate(function_name, lib_def, attrs,
- options, &handle));
+ TF_RETURN_IF_ERROR(
+ cluster_flr_->Instantiate(function_name, lib_def, attrs, &handle));
Notification done;
FunctionLibraryRuntime::Options opts;
@@ -104,9 +103,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) {
GraphDef actual;
std::vector<string> send_keys, recv_keys;
TF_CHECK_OK(ConstructFunctionGraphHelper(
- test::function::Swap().signature(), {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/device:CPU:0"}, &actual, &send_keys,
- &recv_keys));
+ test::function::Swap().signature(),
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, &actual,
+ &send_keys, &recv_keys));
GraphDef expected;
protobuf::TextFormat::ParseFromString(R"(
node {
@@ -206,7 +205,7 @@ node {
attr {
key: "_target"
value {
- s: "/job:a/replica:0/task:0/device:CPU:0"
+ s: "/job:a/replica:0/task:0/cpu:0"
}
}
}
@@ -310,9 +309,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, DISABLED_InstantiateAndRun) {
Tensor y;
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- TF_EXPECT_OK(InstantiateAndRun("XTimesTwoInt32", lib_def, {},
- {"/job:localhost/replica:0/task:1/cpu:0"}, {x},
- {&y}));
+ TF_EXPECT_OK(InstantiateAndRun(
+ "XTimesTwoInt32", lib_def,
+ {{"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, {x}, {&y}));
test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({2, 4, 6, 8}));
}
@@ -325,9 +324,10 @@ TEST_F(ClusterFunctionLibraryRuntimeTest,
Tensor y1, y2;
auto x1 = test::AsTensor<float>({1, 2, 3, 4});
auto x2 = test::AsTensor<float>({4, 3, 2, 1});
- TF_EXPECT_OK(InstantiateAndRun("Swap", lib_def, {{"T", DT_FLOAT}},
- {"/job:localhost/replica:0/task:1/cpu:0"},
- {x1, x2}, {&y1, &y2}));
+ TF_EXPECT_OK(InstantiateAndRun(
+ "Swap", lib_def,
+ {{"T", DT_FLOAT}, {"_target", "/job:localhost/replica:0/task:1/cpu:0"}},
+ {x1, x2}, {&y1, &y2}));
test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({4, 3, 2, 1}));
test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2, 3, 4}));
}
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 783015481a..d757e962e5 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -795,17 +795,12 @@ uint64 FunctionDefHash(const FunctionDef& fdef) {
return h;
}
-string Canonicalize(const string& funcname, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options) {
+string Canonicalize(const string& funcname, AttrSlice attrs) {
std::vector<string> entries;
- entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1));
+ entries.reserve(attrs.size());
for (auto p : attrs) {
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
}
- if (!options.target.empty()) {
- entries.push_back(
- strings::StrCat("_target", "=", str_util::CEscape(options.target)));
- }
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e5d0e49dbb..1a579ab631 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -234,6 +234,15 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
// same.
uint64 FunctionDefHash(const FunctionDef& fdef);
+// Returns a canonicalized string for the instantiation of the
+// function of the given "name" and attributes "attrs".
+//
+// The returned string is guaranteed to be stable within one address
+// space. But it may be change as the implementation
+// evolves. Therefore, it should not be persisted or compared across
+// address spaces.
+string Canonicalize(const string& funcname, AttrSlice attrs);
+
class CallFrameInterface {
public:
virtual ~CallFrameInterface() {}
@@ -409,23 +418,9 @@ class FunctionLibraryRuntime {
//
// Returns OK and fills in "handle" if the instantiation succeeds.
// Otherwise returns an error and "handle" is undefined.
- struct InstantiateOptions {
- // The canonical device name of the device on which the function
- // should be instantiated. If empty, the function will be
- // instantiated on the local device.
- string target;
-
- // TODO(b/70352992): Add an API for allowing a different
- // FunctionLibraryDefinition to be overlaid on this runtime's library.
- };
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
- const InstantiateOptions& options,
Handle* handle) = 0;
- Status Instantiate(const string& function_name, AttrSlice attrs,
- Handle* handle) {
- return Instantiate(function_name, attrs, {}, handle);
- }
// Releases state associated with the handle.
virtual Status ReleaseHandle(Handle handle) = 0;
@@ -507,19 +502,6 @@ class FunctionLibraryRuntime {
typedef uint64 LocalHandle;
};
-// Returns a canonicalized string for the instantiation of the
-// function of the given "name", attributes "attrs", and "options".
-//
-// The returned string is guaranteed to be stable within one address
-// space. But it may be change as the implementation
-// evolves. Therefore, it should not be persisted or compared across
-// address spaces.
-string Canonicalize(const string& funcname, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options);
-inline string Canonicalize(const string& funcname, AttrSlice attrs) {
- return Canonicalize(funcname, attrs, {});
-}
-
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
@@ -532,11 +514,10 @@ class DistributedFunctionLibraryRuntime {
virtual ~DistributedFunctionLibraryRuntime() {}
// The _target attr in attrs determines where the function is instantiated.
- virtual Status Instantiate(
- const string& function_name, const FunctionLibraryDefinition& lib_def,
- AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
- FunctionLibraryRuntime::LocalHandle* handle) = 0;
+ virtual Status Instantiate(const string& function_name,
+ const FunctionLibraryDefinition& lib_def,
+ AttrSlice attrs,
+ FunctionLibraryRuntime::LocalHandle* handle) = 0;
// opts.runner isn't used for execution.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index facac10f66..f469f41e06 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -296,19 +296,21 @@ class RemoteCallOp : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
const Tensor* target;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
+ AttrValueMap attr_values = func_.attr();
+ AttrValue v;
const string& target_device =
DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()());
+ v.set_s(target_device);
+ AddAttr("_target", v, &attr_values);
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library is provided."),
done);
- AttrValueMap attr_values = func_.attr();
FunctionLibraryRuntime::Handle handle;
- OP_REQUIRES_OK_ASYNC(ctx,
- lib->Instantiate(func_.name(), AttrSlice(&attr_values),
- {target_device}, &handle),
- done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle),
+ done);
OpInputList arguments;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 00c38eac10..909150eb6a 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -154,8 +154,7 @@ class OnSessionInitRequest(object):
sess: A tensorflow Session object.
"""
- _check_type(sess, (session.SessionInterface,
- monitored_session.MonitoredSession))
+ _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
self.session = sess
@@ -359,8 +358,7 @@ class BaseDebugWrapperSession(session.SessionInterface):
NotImplementedError: If a non-DirectSession sess object is received.
"""
- _check_type(sess, (session.SessionInterface,
- monitored_session.MonitoredSession))
+ _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
# The session being wrapped.
self._sess = sess
diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py
index 5240e0d0ad..73e08ce7d5 100644
--- a/tensorflow/python/debug/wrappers/framework_test.py
+++ b/tensorflow/python/debug/wrappers/framework_test.py
@@ -271,9 +271,9 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testSessionInitInvalidSessionType(self):
"""Attempt to wrap a non-Session-type object should cause an exception."""
- sess = "not a session"
+ wrapper = TestDebugWrapperSessionBadAction(self._sess)
with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
- TestDebugWrapperSessionBadAction(sess)
+ TestDebugWrapperSessionBadAction(wrapper)
def testSessionInitBadActionValue(self):
with self.assertRaisesRegexp(