aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2017-09-08 13:30:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-08 13:34:27 -0700
commit450c3b5626030bd02ef6c86f8387cb2ca213dfe5 (patch)
tree13994c6c1625084e99ae0f37d9b0833d9c468539 /tensorflow/core/common_runtime/function_test.cc
parent82cc6529f4c8d23013096bb5f79514247aa73433 (diff)
Using rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device executions now.
PiperOrigin-RevId: 168038285
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc16
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index a9f06c4df0..7eac1674e7 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
@@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
+ FunctionLibraryRuntime::Options opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
@@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
};
Notification done;
- FunctionLibraryRuntime::Options opts;
opts.runner = &runner;
std::vector<Tensor> out;
Status status;
@@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
if (!status.ok()) {
return status;
}
- return Run(flr, handle, args, std::move(rets));
+ FunctionLibraryRuntime::Options opts;
+ return Run(flr, handle, opts, args, std::move(rets));
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
@@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
Tensor y;
+ FunctionLibraryRuntime::Options opts;
+ opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
+ opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
- TF_CHECK_OK(Run(flr1_, handle, {}, {&y}));
+ TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({})));
- TF_CHECK_OK(Run(flr2_, handle, {}, {&y}));
+ opts.remote_execution = true;
+ opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
+ TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({})));
+ opts.rendezvous->Unref();
}
namespace {