diff options
author | Rohan Jain <rohanj@google.com> | 2017-09-08 13:30:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-08 13:34:27 -0700 |
commit | 450c3b5626030bd02ef6c86f8387cb2ca213dfe5 (patch) | |
tree | 13994c6c1625084e99ae0f37d9b0833d9c468539 /tensorflow/core/common_runtime/function_test.cc | |
parent | 82cc6529f4c8d23013096bb5f79514247aa73433 (diff) |
Using rendezvous manager to pass args / rets between devices during function remote execution. This enables CPU->GPU remote device executions now.
PiperOrigin-RevId: 168038285
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index a9f06c4df0..7eac1674e7 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function_testlib.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, + FunctionLibraryRuntime::Options opts, const std::vector<Tensor>& args, std::vector<Tensor*> rets) { std::atomic<int32> call_count(0); std::function<void(std::function<void()>)> runner = @@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { }; Notification done; - FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector<Tensor> out; Status status; @@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { if (!status.ok()) { return status; } - return Run(flr, handle, args, std::move(rets)); + FunctionLibraryRuntime::Options opts; + return Run(flr, handle, opts, args, std::move(rets)); } std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr, @@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); Tensor y; + FunctionLibraryRuntime::Options opts; + opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get()); + opts.source_device = "/device:CPU:1"; // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. - TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); + TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, TensorShape({}))); - TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); + opts.remote_execution = true; + opts.source_device = "/job:localhost/replica:0/task:0/cpu:2"; + TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, TensorShape({}))); + opts.rendezvous->Unref(); } namespace { |