aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-02-11 10:06:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-11 11:49:17 -0800
commit7415dfae93d4ab3e1f15f874bf7a42f82cf8b377 (patch)
treebfadf0512203625350fb21416bef3526595f944d /tensorflow
parent73b9dd18ce3017829edef4a5b4190a4f0579369c (diff)
Moves MemoryType inference code out of OpKernel so that it can be reused.
Change: 114448861
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/framework/memory_types.cc106
-rw-r--r--tensorflow/core/framework/memory_types.h37
-rw-r--r--tensorflow/core/framework/memory_types_test.cc71
-rw-r--r--tensorflow/core/framework/op_kernel.cc88
-rw-r--r--tensorflow/core/framework/op_kernel.h22
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc37
-rw-r--r--tensorflow/core/graph/graph_partition.cc5
8 files changed, 233 insertions, 134 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6324fe4a8a..efcc7f11b5 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -172,6 +172,7 @@ tf_cuda_library(
"framework/function.h",
"framework/kernel_def_builder.h",
"framework/lookup_interface.h",
+ "framework/memory_types.h",
"framework/node_def_builder.h",
"framework/node_def_util.h",
"framework/numeric_op.h",
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
new file mode 100644
index 0000000000..5394edbf25
--- /dev/null
+++ b/tensorflow/core/framework/memory_types.cc
@@ -0,0 +1,106 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/memory_types.h"
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Fills memory_types for either input or output, setting everything
+// to DEVICE_MEMORY except those args in host_memory_args. Removes
+// elements of host_memory_args that were used.
+void MemoryTypesHelper(const NameRangeMap& name_map,
+ std::vector<string>* host_memory_args,
+ MemoryTypeVector* memory_types) {
+ // Set total to the largest endpoint of anything in the name_map.
+ int total = 0;
+ for (const auto& item : name_map) {
+ total = std::max(total, item.second.second);
+ }
+
+ // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
+ memory_types->clear();
+ memory_types->resize(total, DEVICE_MEMORY);
+
+ // Update args that have been marked as in "HOST_MEMORY".
+ size_t keep = 0;
+ for (size_t i = 0; i < host_memory_args->size(); ++i) {
+ auto iter = name_map.find((*host_memory_args)[i]);
+ if (iter != name_map.end()) {
+ for (int j = iter->second.first; j < iter->second.second; ++j) {
+ (*memory_types)[j] = HOST_MEMORY;
+ }
+ } else {
+ // (*host_memory_args)[i] not found, save it for the next pass.
+ if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i];
+ ++keep;
+ }
+ }
+ host_memory_args->resize(keep);
+}
+
+Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef,
+ const OpDef& op_def,
+ const NameRangeMap& input_name_map,
+ const NameRangeMap& output_name_map,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types) {
+ Status status;
+ const KernelDef* kdef = nullptr;
+ TF_RETURN_IF_ERROR(FindKernelDef(device_type, ndef, &kdef));
+
+ if (kdef != nullptr) {
+ const auto& from_proto = kdef->host_memory_arg();
+ std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
+ MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types);
+ MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types);
+ if (!host_memory_args.empty()) {
+ return errors::InvalidArgument(
+ "HostMemory args '", str_util::Join(host_memory_args, "', '"),
+ "' not found in OpDef: ", SummarizeOpDef(op_def));
+ }
+ }
+ return status;
+}
+
+} // namespace
+
+Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
+ DeviceType device_type, const NodeDef& ndef,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types) {
+ // Look up the Op registered for this op name.
+ Status status;
+ const OpDef* op_def = op_registry->LookUp(ndef.op(), &status);
+ if (op_def == nullptr) return status;
+
+ NameRangeMap inputs;
+ NameRangeMap outputs;
+ status = NameRangesForNode(ndef, *op_def, &inputs, &outputs);
+ if (!status.ok()) return status;
+
+ return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs,
+ input_memory_types, output_memory_types);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/memory_types.h b/tensorflow/core/framework/memory_types.h
new file mode 100644
index 0000000000..b9a0cf3207
--- /dev/null
+++ b/tensorflow/core/framework/memory_types.h
@@ -0,0 +1,37 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+// Returns into *{input,output}_memory_types the memory type of each
+// {input,output} tensor.
+//
+// REQUIRES: * '*_memory_types' is not nullptr.
+// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
+Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
+ DeviceType device_type, const NodeDef& ndef,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc
new file mode 100644
index 0000000000..7d593514a8
--- /dev/null
+++ b/tensorflow/core/framework/memory_types_test.cc
@@ -0,0 +1,71 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/memory_types.h"
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+class DummyKernel : public OpKernel {
+ public:
+ explicit DummyKernel(tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+ void Compute(tensorflow::OpKernelContext* context) override {}
+};
+
+REGISTER_OP("HostMemoryTest")
+ .Input("a: float")
+ .Input("b: T")
+ .Input("c: N * string")
+ .Output("o: N * T")
+ .Attr("T: type")
+ .Attr("N: int");
+REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
+ .Device(DEVICE_GPU)
+ .HostMemory("a")
+ .HostMemory("c")
+ .HostMemory("o"),
+ DummyKernel);
+
+TEST(MemoryTypesForNode, Simple) {
+ NodeDef node_def;
+ TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
+ .Input(FakeInput())
+ .Input(FakeInput(DT_BOOL))
+ .Input(FakeInput(3))
+ .Finalize(&node_def));
+ MemoryTypeVector input, output;
+
+ TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
+ &input, &output));
+ EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input);
+ EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output);
+
+ TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
+ &input, &output));
+ EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, HOST_MEMORY}),
+ input);
+ EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 9bcb101ef9..f988cf41fb 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
@@ -550,6 +551,14 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
} // namespace
+Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
+ const KernelDef** def) {
+ const KernelRegistration* reg;
+ TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, &reg));
+ *def = &reg->def;
+ return Status::OK();
+}
+
Status SupportedDeviceTypesForNode(
const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
DeviceTypeVector* device_types) {
@@ -627,7 +636,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
// the kernel's input and output memory types.
MemoryTypeVector input_memory_types;
MemoryTypeVector output_memory_types;
- TF_RETURN_IF_ERROR(MemoryTypesForNode(*OpRegistry::Global(), device_type,
+ TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
node_def, &input_memory_types,
&output_memory_types));
@@ -643,83 +652,6 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
return s;
}
-namespace { // Helper for MemoryTypesForNode.
-// Fills memory_types for either input or output, setting everything
-// to DEVICE_MEMORY except those args in host_memory_args. Removes
-// elements of host_memory_args that were used.
-void MemoryTypesHelper(const NameRangeMap& name_map,
- std::vector<string>* host_memory_args,
- MemoryTypeVector* memory_types) {
- // Set total to the largest endpoint of anything in the name_map.
- int total = 0;
- for (const auto& item : name_map) {
- total = std::max(total, item.second.second);
- }
-
- // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
- memory_types->clear();
- memory_types->resize(total, DEVICE_MEMORY);
-
- // Update args that have been marked as in "HOST_MEMORY".
- size_t keep = 0;
- for (size_t i = 0; i < host_memory_args->size(); ++i) {
- auto iter = name_map.find((*host_memory_args)[i]);
- if (iter != name_map.end()) {
- for (int j = iter->second.first; j < iter->second.second; ++j) {
- (*memory_types)[j] = HOST_MEMORY;
- }
- } else {
- // (*host_memory_args)[i] not found, save it for the next pass.
- if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i];
- ++keep;
- }
- }
- host_memory_args->resize(keep);
-}
-} // namespace
-
-Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef,
- const OpDef& op_def,
- const NameRangeMap& input_name_map,
- const NameRangeMap& output_name_map,
- MemoryTypeVector* input_memory_types,
- MemoryTypeVector* output_memory_types) {
- Status status;
- const KernelRegistration* registration;
- TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, ndef, &registration));
-
- if (registration != nullptr) {
- const auto& from_proto = registration->def.host_memory_arg();
- std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
- MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types);
- MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types);
- if (!host_memory_args.empty()) {
- return errors::InvalidArgument(
- "HostMemory args '", str_util::Join(host_memory_args, "', '"),
- "' not found in OpDef: ", SummarizeOpDef(op_def));
- }
- }
- return status;
-}
-
-Status MemoryTypesForNode(const OpRegistryInterface& op_registry,
- DeviceType device_type, const NodeDef& ndef,
- MemoryTypeVector* input_memory_types,
- MemoryTypeVector* output_memory_types) {
- // Look up the Op registered for this op name.
- Status status;
- const OpDef* op_def = op_registry.LookUp(ndef.op(), &status);
- if (op_def == nullptr) return status;
-
- NameRangeMap inputs;
- NameRangeMap outputs;
- status = NameRangesForNode(ndef, *op_def, &inputs, &outputs);
- if (!status.ok()) return status;
-
- return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs,
- input_memory_types, output_memory_types);
-}
-
namespace {
bool FindArgInOp(const string& arg_name,
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index a762e9316a..28f4ffacb8 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1014,23 +1014,6 @@ Status SupportedDeviceTypesForNode(
const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
DeviceTypeVector* device_types);
-// Returns into *{input,output}_memory_types the memory type of each
-// {input,output} tensor.
-//
-// REQUIRES: * '*_memory_types' is not nullptr.
-// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
-Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef,
- const OpDef& op_def,
- const NameRangeMap& input_name_map,
- const NameRangeMap& output_name_map,
- MemoryTypeVector* input_memory_types,
- MemoryTypeVector* output_memory_types);
-
-Status MemoryTypesForNode(const OpRegistryInterface& op_registry,
- DeviceType device_type, const NodeDef& ndef,
- MemoryTypeVector* input_memory_types,
- MemoryTypeVector* output_memory_types);
-
// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
@@ -1057,6 +1040,11 @@ typedef ::tensorflow::KernelDefBuilder Name;
void* GlobalKernelRegistry();
+// If node_def has a corresponding kernel registered on device_type,
+// returns OK and fill in the kernel def.
+Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
+ const KernelDef** def);
+
// Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k'
// registered with the current library's global kernel registry (obtained by
// calling GlobalKernelRegistry()), inserts 'k' into registry_ptr.
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index a531b10cde..4275ad0f73 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -745,43 +745,6 @@ TEST_F(GetAttrTest, TypeList) {
EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]);
}
-REGISTER_OP("HostMemoryTest")
- .Input("a: float")
- .Input("b: T")
- .Input("c: N * string")
- .Output("o: N * T")
- .Attr("T: type")
- .Attr("N: int");
-REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
-REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
- .Device(DEVICE_GPU)
- .HostMemory("a")
- .HostMemory("c")
- .HostMemory("o"),
- DummyKernel);
-
-TEST(MemoryTypesForNode, Simple) {
- NodeDef node_def;
- TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
- .Input(FakeInput())
- .Input(FakeInput(DT_BOOL))
- .Input(FakeInput(3))
- .Finalize(&node_def));
- MemoryTypeVector input, output;
-
- TF_EXPECT_OK(MemoryTypesForNode(*OpRegistry::Global(), DEVICE_CPU, node_def,
- &input, &output));
- EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input);
- EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output);
-
- TF_EXPECT_OK(MemoryTypesForNode(*OpRegistry::Global(), DEVICE_GPU, node_def,
- &input, &output));
- EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
- HOST_MEMORY, HOST_MEMORY}),
- input);
- EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output);
-}
-
class BaseKernel : public ::tensorflow::OpKernel {
public:
explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 16efeb777a..e73a7aa3e5 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph_def_builder.h"
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -643,7 +644,7 @@ Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
input_memory_types.resize(node->num_inputs());
output_memory_types.clear();
output_memory_types.resize(node->num_outputs());
- status = MemoryTypesForNode(*g.op_registry(), DeviceType(parsed.type),
+ status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type),
node->def(), &input_memory_types,
&output_memory_types);
if (!status.ok()) return status;