aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-02-07 13:44:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 13:48:12 -0800
commita271c36b5ead4686b72d972b193bf1f534a92ffd (patch)
tree0e1ed2e06e4e4e0bb4e13a13e681db69b7f80ae5 /tensorflow/core
parentd1eeceb562d6defc4b517ad83cea56a894ff4c98 (diff)
[tf.data] Move the C++ Dataset class implementations to the framework library.
This enables the use of the `DatasetOpKernel` subclasses in custom op library code. A subsequent change will move `tf.contrib.data` kernel implementations to a custom op library. Implementation note: This change moves some classes from "tensorflow/core/graph/..." into the framework library, which does not include any code in "tensorflow/core/common_runtime/...". To break the dependency from "tensorflow/core/framework/dataset.cc" to "tensorflow/core/common_runtime/...", the `GraphDefBuilderToGraph()` method has been split out from the `GraphDefBuilder` class (where it was previously exposed as the `GraphDefBuilder::ToGraph()` utility method) and added to a new "tensorflow/core/graph/graph_def_builder_util.h" module. This method depends on ".../graph/graph_constructor.cc", which depends directly on ".../common_runtime/shape_refiner.h" and indirectly on ".../common_runtime/graph_runner.h". Since this method was used only in tests, these have been updated to point to the new utility method. PiperOrigin-RevId: 184888903
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/BUILD12
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc3
-rw-r--r--tensorflow/core/framework/dataset.cc (renamed from tensorflow/core/kernels/data/dataset.cc)8
-rw-r--r--tensorflow/core/framework/dataset.h6
-rw-r--r--tensorflow/core/graph/algorithm_test.cc5
-rw-r--r--tensorflow/core/graph/graph_def_builder.cc11
-rw-r--r--tensorflow/core/graph/graph_def_builder.h8
-rw-r--r--tensorflow/core/graph/graph_def_builder_test.cc3
-rw-r--r--tensorflow/core/graph/graph_def_builder_util.cc28
-rw-r--r--tensorflow/core/graph/graph_def_builder_util.h35
-rw-r--r--tensorflow/core/graph/subgraph_test.cc3
-rw-r--r--tensorflow/core/kernels/data/BUILD3
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc17
13 files changed, 106 insertions, 36 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 7fade697de..c25aac3acf 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -784,6 +784,7 @@ tf_cuda_library(
"graph/graph.h",
"graph/graph_constructor.h",
"graph/graph_def_builder.h",
+ "graph/graph_def_builder_util.h",
"graph/node_builder.h",
"graph/validate.h",
"graph/while_context.h",
@@ -1718,6 +1719,9 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [
"platform/variant_coding.h",
"graph/edgeset.h",
"graph/graph.h",
+ "graph/graph_def_builder.h",
+ "graph/node_builder.h",
+ "graph/tensor_id.h",
] + glob(
[
"example/**/*.h",
@@ -1804,6 +1808,9 @@ tf_cuda_library(
] + [
"graph/edgeset.cc",
"graph/graph.cc",
+ "graph/graph_def_builder.cc",
+ "graph/node_builder.cc",
+ "graph/tensor_id.cc",
"graph/while_context.h",
"graph/while_context.cc",
],
@@ -1932,6 +1939,7 @@ GRAPH_HDRS = [
"graph/graph.h",
"graph/graph_constructor.h", # NOTE(mrry): Don't include the .cc since it depends on common_runtime.
"graph/graph_def_builder.h",
+ "graph/graph_def_builder_util.h",
"graph/graph_partition.h",
"graph/mkl_layout_pass.h",
"graph/mkl_tfconversion_pass.h",
@@ -1952,12 +1960,9 @@ tf_cuda_library(
"graph/colors.cc",
"graph/control_flow.cc",
"graph/costmodel.cc",
- "graph/graph_def_builder.cc",
"graph/graph_partition.cc",
- "graph/node_builder.cc",
"graph/optimizer_cse.cc",
"graph/subgraph.cc",
- "graph/tensor_id.cc",
"graph/validate.cc",
],
hdrs = GRAPH_HDRS,
@@ -1986,6 +1991,7 @@ tf_cuda_library(
"common_runtime/shape_refiner.h",
"framework/versions.h",
"graph/graph_constructor.cc", # Depends on common_runtime.
+ "graph/graph_def_builder_util.cc", # Depends on common_runtime.
"public/session.h",
"public/session_options.h",
"public/version.h",
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 02c9cd5313..098024d219 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -193,7 +194,7 @@ class PlacerTest : public ::testing::Test {
// Builds the given graph, and (if successful) indexes the node
// names for use in placement, and later lookup.
Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) {
- TF_RETURN_IF_ERROR(builder.ToGraph(out_graph));
+ TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph));
nodes_by_name_.clear();
for (Node* node : out_graph->nodes()) {
nodes_by_name_[node->name()] = node->id();
diff --git a/tensorflow/core/kernels/data/dataset.cc b/tensorflow/core/framework/dataset.cc
index d18cb16018..4145ef7bc9 100644
--- a/tensorflow/core/kernels/data/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/dataset.h"
+
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -265,10 +265,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
MakeDataset(ctx, input, another_input, output);
}
-Allocator* IteratorContext::allocator(AllocatorAttributes attrs) {
- return params_.lib->device()->GetAllocator(attrs);
-}
-
const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
"_DATASET_GRAPH_OUTPUT_NODE";
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 96566c285a..6ab23d92a4 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -274,7 +274,7 @@ class IteratorContext {
std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;
// The Allocator to be used to allocate the output of an iterator.
- Allocator* allocator = nullptr;
+ std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
};
explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -301,7 +301,9 @@ class IteratorContext {
void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }
- Allocator* allocator(AllocatorAttributes attrs);
+ Allocator* allocator(AllocatorAttributes attrs) {
+ return params_.allocator_getter(attrs);
+ }
private:
Params params_;
diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc
index 0cdcdb6685..99ced0c0f5 100644
--- a/tensorflow/core/graph/algorithm_test.cc
+++ b/tensorflow/core/graph/algorithm_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
@@ -81,7 +82,7 @@ TEST(AlgorithmTest, ReversePostOrder) {
BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3"));
Graph g(OpRegistry::Global());
- TF_ASSERT_OK(b.ToGraph(&g));
+ TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
std::vector<Node*> order;
// Test reverse post order:
@@ -139,7 +140,7 @@ TEST(AlgorithmTest, ReversePostOrderStable) {
BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3"));
Graph g(OpRegistry::Global());
- TF_ASSERT_OK(b.ToGraph(&g));
+ TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
std::vector<Node*> order;
// Test reverse post order generates expected ordering.
diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc
index 33d2021f38..7a58347bd1 100644
--- a/tensorflow/core/graph/graph_def_builder.cc
+++ b/tensorflow/core/graph/graph_def_builder.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <utility>
-#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -72,16 +71,6 @@ Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
return status_;
}
-Status GraphDefBuilder::ToGraph(Graph* graph) const {
- if (status_.ok()) {
- GraphDef graph_def;
- graph_.ToGraphDef(&graph_def);
- GraphConstructorOptions opts;
- TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph));
- }
- return status_;
-}
-
string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
if (name_.empty()) return graph_->NewName(op);
return name_;
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
index a2c0c4d553..776a74c6d8 100644
--- a/tensorflow/core/graph/graph_def_builder.h
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -161,14 +161,6 @@ class GraphDefBuilder {
// successful, and if so fill *graph_def.
Status ToGraphDef(GraphDef* graph_def) const;
- // Like ToGraphDef(), but converts to a Graph (using the default
- // GraphConstructorOptions).
- // TODO(josh11b): Make this faster; right now it converts
- // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
- // edges from the source and to the sink node, resolves back edges
- // by name), and makes sure the resulting graph is valid.
- Status ToGraph(Graph* graph) const;
-
// Adds the function and gradient definitions in `fdef_lib` to this graph's op
// registry. Ignores duplicate functions, and returns a bad status if an
// imported function differs from an existing function or op with the same
diff --git a/tensorflow/core/graph/graph_def_builder_test.cc b/tensorflow/core/graph/graph_def_builder_test.cc
index e928c81b45..be3c2be800 100644
--- a/tensorflow/core/graph/graph_def_builder_test.cc
+++ b/tensorflow/core/graph/graph_def_builder_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -34,7 +35,7 @@ TEST(GraphDefBuilderTest, Version) {
// Check version when we convert to a Graph
Graph graph(OpRegistry::Global());
- TF_EXPECT_OK(builder.ToGraph(&graph));
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, &graph));
ASSERT_EQ(graph.versions().producer(), TF_GRAPH_DEF_VERSION);
ASSERT_EQ(graph.versions().min_consumer(), TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
diff --git a/tensorflow/core/graph/graph_def_builder_util.cc b/tensorflow/core/graph/graph_def_builder_util.cc
new file mode 100644
index 0000000000..102c72185f
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder_util.cc
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+
+#include "tensorflow/core/graph/graph_constructor.h"
+
+namespace tensorflow {
+
+Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph) {
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(builder.ToGraphDef(&graph_def));
+ GraphConstructorOptions opts;
+ return ConvertGraphDefToGraph(opts, graph_def, graph);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_def_builder_util.h b/tensorflow/core/graph/graph_def_builder_util.h
new file mode 100644
index 0000000000..4a157e5b71
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder_util.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_
+
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class Graph;
+
+// Converts the `GraphDef` being built by `builder` to a `Graph` and
+// stores it in `*graph`.
+// TODO(josh11b): Make this faster; right now it converts
+// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
+// edges from the source and to the sink node, resolves back edges
+// by name), and makes sure the resulting graph is valid.
+Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
index fde1ea1743..7219d9812f 100644
--- a/tensorflow/core/graph/subgraph_test.cc
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -361,7 +362,7 @@ static void BM_SubgraphHelper(int iters, int num_nodes,
last_node = ops::SourceOp("In", b.opts().WithName(name));
}
}
- TF_CHECK_OK(b.ToGraph(&g));
+ TF_CHECK_OK(GraphDefBuilderToGraph(b, &g));
}
std::vector<string> fed;
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index c4e21257ff..8e91baaa1c 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -44,9 +44,10 @@ tf_kernel_library(
],
)
+# TODO(mrry): Remove this empty forwarding library.
cc_library(
name = "dataset",
- srcs = ["dataset.cc"],
+ srcs = [],
hdrs = ["dataset.h"],
deps = [
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index fc3e291afb..d7d4ad5cf7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -160,6 +160,10 @@ class IteratorResource : public ResourceBase {
params.runner = *(ctx->runner());
params.function_library = flib_def;
params.lib = lib_;
+ DeviceBase* device = lib_->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
IteratorContext iter_ctx(std::move(params));
TF_RETURN_IF_ERROR(captured_iterator->Restore(&iter_ctx, reader));
@@ -605,6 +609,11 @@ class ToSingleElementOp : public AsyncOpKernel {
params.env = ctx->env();
params.runner = *(ctx->runner());
params.lib = ctx->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+
IteratorContext iter_ctx(std::move(params));
std::vector<Tensor> components;
@@ -863,6 +872,10 @@ class IteratorGetNextOp : public AsyncOpKernel {
};
params.runner = *(ctx->runner());
params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
IteratorContext iter_ctx(std::move(params));
OP_REQUIRES_OK_ASYNC(
@@ -905,6 +918,10 @@ class IteratorGetNextSyncOp : public OpKernel {
};
params.runner = *(ctx->runner());
params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
IteratorContext iter_ctx(std::move(params));
OP_REQUIRES_OK(ctx,