 tensorflow/core/common_runtime/graph_runner.cc      |  8 +++++++-
 tensorflow/core/common_runtime/graph_runner_test.cc | 27 +++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index d4dc8f0057..514a63590b 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/log_memory.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
@@ -175,8 +176,13 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
     Rendezvous::ParsedKey parsed;
     TF_RETURN_IF_ERROR(Rendezvous::ParseKey(output_key, &parsed));
     bool is_dead;
+    Tensor output_tensor;
     TF_RETURN_IF_ERROR(
-        rendez->Recv(parsed, Rendezvous::Args(), &(*outputs)[i], &is_dead));
+        rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
+    // Does a deep copy so that ownership of the tensor isn't tied to the
+    // allocator of the cpu device we created above. The allocator could be
+    // deleted along with the device.
+    (*outputs)[i] = tensor::DeepCopy(output_tensor);
   }
 
   return Status::OK();
diff --git a/tensorflow/core/common_runtime/graph_runner_test.cc b/tensorflow/core/common_runtime/graph_runner_test.cc
index ccb44af0ec..e969ee8df7 100644
--- a/tensorflow/core/common_runtime/graph_runner_test.cc
+++ b/tensorflow/core/common_runtime/graph_runner_test.cc
@@ -53,6 +53,33 @@ TEST(GraphRunnerTest, SingleConst) {
   ExpectEqual(42.0f, outputs[0].scalar<float>()());
 }
 
+// If not using DeepCopy, and the allocator is deleted with the cpu-device,
+// this test will seg-fault.
+TEST(GraphRunnerTest, DeepCopy) {
+  Scope root = Scope::NewRootScope();
+  auto p1 = ops::Placeholder(root.WithOpName("p1"), DT_FLOAT);
+  auto p2 = ops::Placeholder(root.WithOpName("p2"), DT_FLOAT);
+  auto add = ops::Add(root.WithOpName("add"), p1, p2);
+
+  Tensor p1_data(DT_FLOAT, TensorShape({}));
+  Tensor p2_data(DT_FLOAT, TensorShape({}));
+  p1_data.scalar<float>()() = 1.0f;
+  p2_data.scalar<float>()() = 2.0f;
+  std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data},
+                                                   {"p2:0", p2_data}};
+
+  // Create and destroy the GraphRunner, and ensure that the outputs are
+  // consumable beyond the lifetime of GraphRunner.
+  std::vector<Tensor> outputs;
+  {
+    GraphRunner graph_runner(Env::Default());
+    Status s =
+        graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs);
+    TF_ASSERT_OK(s);
+  }
+  ExpectEqual(3.0f, outputs[0].scalar<float>()());
+}
+
 TEST(GraphRunnerTest, MultiFetchConst) {
   Scope root = Scope::NewRootScope();
   auto c = ops::Const(root, 42.0f);