about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/graph_runner.cc8
-rw-r--r--tensorflow/core/common_runtime/graph_runner_test.cc27
2 files changed, 34 insertions, 1 deletion
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index d4dc8f0057..514a63590b 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -175,8 +176,13 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
Rendezvous::ParsedKey parsed;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(output_key, &parsed));
bool is_dead;
+ Tensor output_tensor;
TF_RETURN_IF_ERROR(
- rendez->Recv(parsed, Rendezvous::Args(), &(*outputs)[i], &is_dead));
+ rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
+ // Does a deep copy so that ownership of the tensor isn't tied to the
+ // allocator of the cpu device we created above. The allocator could be
+ // deleted along with the device.
+ (*outputs)[i] = tensor::DeepCopy(output_tensor);
}
return Status::OK();
diff --git a/tensorflow/core/common_runtime/graph_runner_test.cc b/tensorflow/core/common_runtime/graph_runner_test.cc
index ccb44af0ec..e969ee8df7 100644
--- a/tensorflow/core/common_runtime/graph_runner_test.cc
+++ b/tensorflow/core/common_runtime/graph_runner_test.cc
@@ -53,6 +53,33 @@ TEST(GraphRunnerTest, SingleConst) {
ExpectEqual(42.0f, outputs[0].scalar<float>()());
}
+// If not using DeepCopy, and the allocator is deleted with the cpu-device,
+// this test will seg-fault.
+TEST(GraphRunnerTest, DeepCopy) {
+ Scope root = Scope::NewRootScope();
+ auto p1 = ops::Placeholder(root.WithOpName("p1"), DT_FLOAT);
+ auto p2 = ops::Placeholder(root.WithOpName("p2"), DT_FLOAT);
+ auto add = ops::Add(root.WithOpName("add"), p1, p2);
+
+ Tensor p1_data(DT_FLOAT, TensorShape({}));
+ Tensor p2_data(DT_FLOAT, TensorShape({}));
+ p1_data.scalar<float>()() = 1.0f;
+ p2_data.scalar<float>()() = 2.0f;
+ std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data},
+ {"p2:0", p2_data}};
+
+ // Create and destroy the GraphRunner, and ensure that the outputs are
+ // consumable beyond the lifetime of GraphRunner.
+ std::vector<Tensor> outputs;
+ {
+ GraphRunner graph_runner(Env::Default());
+ Status s =
+ graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs);
+ TF_ASSERT_OK(s);
+ }
+ ExpectEqual(3.0f, outputs[0].scalar<float>()());
+}
+
TEST(GraphRunnerTest, MultiFetchConst) {
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, 42.0f);