diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-02-11 10:06:15 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-11 11:49:17 -0800 |
commit | 7415dfae93d4ab3e1f15f874bf7a42f82cf8b377 (patch) | |
tree | bfadf0512203625350fb21416bef3526595f944d /tensorflow | |
parent | 73b9dd18ce3017829edef4a5b4190a4f0579369c (diff) |
Moves MemoryType inference code out of OpKernel so that it can be reused.
Change: 114448861
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/memory_types.cc | 106 | ||||
-rw-r--r-- | tensorflow/core/framework/memory_types.h | 37 | ||||
-rw-r--r-- | tensorflow/core/framework/memory_types_test.cc | 71 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 88 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 22 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel_test.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_partition.cc | 5 |
8 files changed, 233 insertions, 134 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6324fe4a8a..efcc7f11b5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -172,6 +172,7 @@ tf_cuda_library( "framework/function.h", "framework/kernel_def_builder.h", "framework/lookup_interface.h", + "framework/memory_types.h", "framework/node_def_builder.h", "framework/node_def_util.h", "framework/numeric_op.h", diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc new file mode 100644 index 0000000000..5394edbf25 --- /dev/null +++ b/tensorflow/core/framework/memory_types.cc @@ -0,0 +1,106 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/memory_types.h" + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +// Fills memory_types for either input or output, setting everything +// to DEVICE_MEMORY except those args in host_memory_args. Removes +// elements of host_memory_args that were used. 
+void MemoryTypesHelper(const NameRangeMap& name_map, + std::vector<string>* host_memory_args, + MemoryTypeVector* memory_types) { + // Set total to the largest endpoint of anything in the name_map. + int total = 0; + for (const auto& item : name_map) { + total = std::max(total, item.second.second); + } + + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. + memory_types->clear(); + memory_types->resize(total, DEVICE_MEMORY); + + // Update args that have been marked as in "HOST_MEMORY". + size_t keep = 0; + for (size_t i = 0; i < host_memory_args->size(); ++i) { + auto iter = name_map.find((*host_memory_args)[i]); + if (iter != name_map.end()) { + for (int j = iter->second.first; j < iter->second.second; ++j) { + (*memory_types)[j] = HOST_MEMORY; + } + } else { + // (*host_memory_args)[i] not found, save it for the next pass. + if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i]; + ++keep; + } + } + host_memory_args->resize(keep); +} + +Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, + const OpDef& op_def, + const NameRangeMap& input_name_map, + const NameRangeMap& output_name_map, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + Status status; + const KernelDef* kdef = nullptr; + TF_RETURN_IF_ERROR(FindKernelDef(device_type, ndef, &kdef)); + + if (kdef != nullptr) { + const auto& from_proto = kdef->host_memory_arg(); + std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); + MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types); + MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", str_util::Join(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(op_def)); + } + } + return status; +} + +} // namespace + +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + DeviceType device_type, const 
NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + // Look up the Op registered for this op name. + Status status; + const OpDef* op_def = op_registry->LookUp(ndef.op(), &status); + if (op_def == nullptr) return status; + + NameRangeMap inputs; + NameRangeMap outputs; + status = NameRangesForNode(ndef, *op_def, &inputs, &outputs); + if (!status.ok()) return status; + + return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs, + input_memory_types, output_memory_types); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/memory_types.h b/tensorflow/core/framework/memory_types.h new file mode 100644 index 0000000000..b9a0cf3207 --- /dev/null +++ b/tensorflow/core/framework/memory_types.h @@ -0,0 +1,37 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// Returns into *{input,output}_memory_types the memory type of each +// {input,output} tensor. +// +// REQUIRES: * '*_memory_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). 
+Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + DeviceType device_type, const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc new file mode 100644 index 0000000000..7d593514a8 --- /dev/null +++ b/tensorflow/core/framework/memory_types_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/memory_types.h" + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class DummyKernel : public OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +REGISTER_OP("HostMemoryTest") + .Input("a: float") + .Input("b: T") + .Input("c: N * string") + .Output("o: N * T") + .Attr("T: type") + .Attr("N: int"); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") + .Device(DEVICE_GPU) + .HostMemory("a") + .HostMemory("c") + .HostMemory("o"), + DummyKernel); + +TEST(MemoryTypesForNode, Simple) { + NodeDef node_def; + TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") + .Input(FakeInput()) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(3)) + .Finalize(&node_def)); + MemoryTypeVector input, output; + + TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input); + EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output); + + TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY}), + input); + EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 9bcb101ef9..f988cf41fb 100644 --- a/tensorflow/core/framework/op_kernel.cc 
+++ b/tensorflow/core/framework/op_kernel.cc @@ -19,6 +19,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" @@ -550,6 +551,14 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, } // namespace +Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, + const KernelDef** def) { + const KernelRegistration* reg; + TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, ®)); + *def = ®->def; + return Status::OK(); +} + Status SupportedDeviceTypesForNode( const std::vector<DeviceType>& prioritized_types, const NodeDef& def, DeviceTypeVector* device_types) { @@ -627,7 +636,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, // the kernel's input and output memory types. MemoryTypeVector input_memory_types; MemoryTypeVector output_memory_types; - TF_RETURN_IF_ERROR(MemoryTypesForNode(*OpRegistry::Global(), device_type, + TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type, node_def, &input_memory_types, &output_memory_types)); @@ -643,83 +652,6 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, return s; } -namespace { // Helper for MemoryTypesForNode. -// Fills memory_types for either input or output, setting everything -// to DEVICE_MEMORY except those args in host_memory_args. Removes -// elements of host_memory_args that were used. -void MemoryTypesHelper(const NameRangeMap& name_map, - std::vector<string>* host_memory_args, - MemoryTypeVector* memory_types) { - // Set total to the largest endpoint of anything in the name_map. - int total = 0; - for (const auto& item : name_map) { - total = std::max(total, item.second.second); - } - - // Now that we know the size, fill with the default 'DEVICE_MEMORY'. 
- memory_types->clear(); - memory_types->resize(total, DEVICE_MEMORY); - - // Update args that have been marked as in "HOST_MEMORY". - size_t keep = 0; - for (size_t i = 0; i < host_memory_args->size(); ++i) { - auto iter = name_map.find((*host_memory_args)[i]); - if (iter != name_map.end()) { - for (int j = iter->second.first; j < iter->second.second; ++j) { - (*memory_types)[j] = HOST_MEMORY; - } - } else { - // (*host_memory_args)[i] not found, save it for the next pass. - if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i]; - ++keep; - } - } - host_memory_args->resize(keep); -} -} // namespace - -Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, - const OpDef& op_def, - const NameRangeMap& input_name_map, - const NameRangeMap& output_name_map, - MemoryTypeVector* input_memory_types, - MemoryTypeVector* output_memory_types) { - Status status; - const KernelRegistration* registration; - TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, ndef, ®istration)); - - if (registration != nullptr) { - const auto& from_proto = registration->def.host_memory_arg(); - std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); - MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types); - MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types); - if (!host_memory_args.empty()) { - return errors::InvalidArgument( - "HostMemory args '", str_util::Join(host_memory_args, "', '"), - "' not found in OpDef: ", SummarizeOpDef(op_def)); - } - } - return status; -} - -Status MemoryTypesForNode(const OpRegistryInterface& op_registry, - DeviceType device_type, const NodeDef& ndef, - MemoryTypeVector* input_memory_types, - MemoryTypeVector* output_memory_types) { - // Look up the Op registered for this op name. 
- Status status; - const OpDef* op_def = op_registry.LookUp(ndef.op(), &status); - if (op_def == nullptr) return status; - - NameRangeMap inputs; - NameRangeMap outputs; - status = NameRangesForNode(ndef, *op_def, &inputs, &outputs); - if (!status.ok()) return status; - - return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs, - input_memory_types, output_memory_types); -} - namespace { bool FindArgInOp(const string& arg_name, diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index a762e9316a..28f4ffacb8 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1014,23 +1014,6 @@ Status SupportedDeviceTypesForNode( const std::vector<DeviceType>& prioritized_types, const NodeDef& def, DeviceTypeVector* device_types); -// Returns into *{input,output}_memory_types the memory type of each -// {input,output} tensor. -// -// REQUIRES: * '*_memory_types' is not nullptr. -// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). -Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, - const OpDef& op_def, - const NameRangeMap& input_name_map, - const NameRangeMap& output_name_map, - MemoryTypeVector* input_memory_types, - MemoryTypeVector* output_memory_types); - -Status MemoryTypesForNode(const OpRegistryInterface& op_registry, - DeviceType device_type, const NodeDef& ndef, - MemoryTypeVector* input_memory_types, - MemoryTypeVector* output_memory_types); - // Call once after Op registration has completed. Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry); @@ -1057,6 +1040,11 @@ typedef ::tensorflow::KernelDefBuilder Name; void* GlobalKernelRegistry(); +// If node_def has a corresponding kernel registered on device_type, +// returns OK and fill in the kernel def. +Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, + const KernelDef** def); + // Treats 'registry_ptr' as a pointer to KernelRegistry. 
For each kernel 'k' // registered with the current library's global kernel registry (obtained by // calling GlobalKernelRegistry()), inserts 'k' into registry_ptr. diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index a531b10cde..4275ad0f73 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -745,43 +745,6 @@ TEST_F(GetAttrTest, TypeList) { EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]); } -REGISTER_OP("HostMemoryTest") - .Input("a: float") - .Input("b: T") - .Input("c: N * string") - .Output("o: N * T") - .Attr("T: type") - .Attr("N: int"); -REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); -REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") - .Device(DEVICE_GPU) - .HostMemory("a") - .HostMemory("c") - .HostMemory("o"), - DummyKernel); - -TEST(MemoryTypesForNode, Simple) { - NodeDef node_def; - TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") - .Input(FakeInput()) - .Input(FakeInput(DT_BOOL)) - .Input(FakeInput(3)) - .Finalize(&node_def)); - MemoryTypeVector input, output; - - TF_EXPECT_OK(MemoryTypesForNode(*OpRegistry::Global(), DEVICE_CPU, node_def, - &input, &output)); - EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input); - EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output); - - TF_EXPECT_OK(MemoryTypesForNode(*OpRegistry::Global(), DEVICE_GPU, node_def, - &input, &output)); - EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, - HOST_MEMORY, HOST_MEMORY}), - input); - EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output); -} - class BaseKernel : public ::tensorflow::OpKernel { public: explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {} diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 16efeb777a..e73a7aa3e5 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -19,8 
+19,8 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -643,7 +644,7 @@ Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { input_memory_types.resize(node->num_inputs()); output_memory_types.clear(); output_memory_types.resize(node->num_outputs()); - status = MemoryTypesForNode(*g.op_registry(), DeviceType(parsed.type), + status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type), node->def(), &input_memory_types, &output_memory_types); if (!status.ok()) return status; |