diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/framework/op_kernel.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/framework/op_kernel.cc')
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 749 |
1 files changed, 749 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc new file mode 100644 index 0000000000..eb83d393f0 --- /dev/null +++ b/tensorflow/core/framework/op_kernel.cc @@ -0,0 +1,749 @@ +#include "tensorflow/core/framework/op_kernel.h" + +#include <unordered_map> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +namespace { + +Status MatchSignatureHelper(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs, + const DataTypeSlice inputs, + const DataTypeSlice outputs) { + bool signature_mismatch = false; + + if (inputs.size() != expected_inputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) { + if (!TypesCompatible(expected_inputs[i], inputs[i])) { + signature_mismatch = true; + } + } + + if (outputs.size() != expected_outputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) { + if (!TypesCompatible(expected_outputs[i], outputs[i])) { + signature_mismatch = true; + } + } + + if (signature_mismatch) { + return errors::InvalidArgument("Signature mismatch, have: ", + DataTypeSliceString(inputs), "->", + DataTypeSliceString(outputs), " expected: ", + DataTypeSliceString(expected_inputs), "->", + DataTypeSliceString(expected_outputs)); + } + return Status::OK(); +} + +// Check HostMemory backward compatibility. +bool CheckHostMemoryCompatibility(const DeviceType device_type, + const OpKernel* kernel) { + if (device_type == DEVICE_GPU) { + for (int i = 0; i < kernel->num_inputs(); ++i) { + if (kernel->input_type(i) == DT_INT32 && + kernel->input_memory_types()[i] != HOST_MEMORY) { + return false; + } + } + for (int i = 0; i < kernel->num_outputs(); ++i) { + if (kernel->output_type(i) == DT_INT32 && + kernel->output_memory_types()[i] != HOST_MEMORY) { + return false; + } + } + } + return true; +} + +} // namespace + +// OpKernel ------------------------------------------------------------------ + +OpKernel::OpKernel(OpKernelConstruction* context) + : def_(context->def()), + input_types_(context->input_types().begin(), + context->input_types().end()), + output_types_(context->output_types().begin(), + context->output_types().end()), + input_name_map_(context->num_inputs()), + output_name_map_(context->num_outputs()) { + OP_REQUIRES_OK(context, + NameRangesForNode(def_, context->op_def(), &input_name_map_, + &output_name_map_)); + + // By default, the input and output memory types are always in device memory, + // but can be overridden by individual implementations of OpKernels in their + // constructor. + input_memory_types_ = MemoryTypeVector(input_types_.size(), DEVICE_MEMORY); + output_memory_types_ = MemoryTypeVector(output_types_.size(), DEVICE_MEMORY); + // TODO(yuanbyu): For now we assume the memory types of function + // inputs/outputs to be DEVICE_MEMORY. + auto lib = context->function_library(); + if (lib == nullptr || !lib->IsDefined(def_.op())) { + OP_REQUIRES_OK(context, MemoryTypesForNode( + context->device_type(), def_, context->op_def(), + input_name_map_, output_name_map_, + &input_memory_types_, &output_memory_types_)); + // Log all the uses of int32 on GPU. + // TODO(yunabyu): Remove once everyone transitions to HostMemory. + if (VLOG_IS_ON(2)) { + if (!CheckHostMemoryCompatibility(context->device_type(), this)) { + VLOG(2) << "Using int32 on GPU at node: " << SummarizeNodeDef(def()); + } + } + } +} + +Status OpKernel::InputRange(const string& input_name, int* start, + int* stop) const { + const auto result = input_name_map_.find(input_name); + if (result == input_name_map_.end()) { + return errors::InvalidArgument("Unknown input name: ", input_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +Status OpKernel::OutputRange(const string& output_name, int* start, + int* stop) const { + const auto result = output_name_map_.find(output_name); + if (result == output_name_map_.end()) { + return errors::InvalidArgument("Unknown output name: ", output_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +void AsyncOpKernel::Compute(OpKernelContext* context) { + Notification n; + ComputeAsync(context, [&n]() { n.Notify(); }); + n.WaitForNotification(); +} + +// PersistentTensor ---------------------------------------------------------- + +Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) { + // the caller has to have a valid context + CHECK(context); + return &tensor_; +} + +Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { + context->NotifyUseOfPersistentTensor(tensor_); + return &tensor_; +} + +// OpKernelConstruction ------------------------------------------------------ + +Status OpKernelConstruction::MatchSignature( + const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { + return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_, + output_types_); +} + +Status OpKernelConstruction::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp) { + Tensor new_temp(allocator_, type, shape); + + if (!new_temp.IsInitialized() && shape.num_elements() > 0) { + return errors::ResourceExhausted( + "OOM when allocating temporary tensor with shape", shape.DebugString()); + } + *out_temp = new_temp; + return Status::OK(); +} + +Status OpKernelConstruction::allocate_persistent( + DataType type, const TensorShape& shape, PersistentTensor* out_persistent, + Tensor** out_tensor) { + // for now just do the same thing as allocate_temp + // TODO(misard) add specific memory tracking for persistent tensors + Tensor persistent; + Status s = allocate_temp(type, shape, &persistent); + if (!s.ok()) { + return s; + } + *out_persistent = PersistentTensor(persistent); + Tensor* allocated = out_persistent->AccessTensor(this); + if (out_tensor) { + *out_tensor = allocated; + } + return s; +} + +// OpKernelContext ----------------------------------------------------------- + +OpKernelContext::OpKernelContext(const Params& params) + : params_(params), + outputs_(params.op_kernel->output_types().size()), + output_allocation_types_(params.op_kernel->output_types().size()) { + Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); + eigen_gpu_device_ = params_.device->MakeGpuDevice(params_.op_device_context, + eigen_gpu_allocator); +} + +OpKernelContext::~OpKernelContext() { + for (TensorValue& value : outputs_) { + if (!value.is_ref()) { + delete value.tensor; + } + } + for (Tensor* t : temp_tensors_) delete t; + delete eigen_gpu_device_; +} + +Status OpKernelContext::input(const string& name, const Tensor** tensor) const { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + if ((*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used ref input name '", name, + "' when immutable input was expected"); + } + *tensor = (*params_.inputs)[start].tensor; + return Status::OK(); +} + +Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + *out_mutex = input_ref_mutex(start); + return Status::OK(); +} + +Status OpKernelContext::mutable_input(const string& name, Tensor* tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!(*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used immutable input name '", name, + "' when ref input was expected"); + } + // return a copy of the Ref acquired while holding the mutex + if (lock_held) { + *tensor = *(*params_.inputs)[start].tensor; + } else { + mutex_lock l(*input_ref_mutex(start)); + *tensor = *(*params_.inputs)[start].tensor; + } + return Status::OK(); +} + +Status OpKernelContext::replace_ref_input(const string& name, + const Tensor& tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!(*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used immutable input name '", name, + "' when ref input was expected"); + } + replace_ref_input(start, tensor, lock_held); + return Status::OK(); +} + +Status OpKernelContext::input_list(const string& name, + OpInputList* list) const { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + *list = OpInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::mutable_input_list(const string& name, + OpMutableInputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + *list = OpMutableInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::output_list(const string& name, OpOutputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + *list = OpOutputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::allocate_output(const string& name, + const TensorShape& shape, + Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor); +} + +Status OpKernelContext::allocate_output(const string& name, + const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor, attr); +} + +Status OpKernelContext::set_output(const string& name, const Tensor& tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + set_output(start, tensor); + return Status::OK(); +} + +Status OpKernelContext::set_output_ref(const string& name, mutex* mu, + Tensor* tensor_for_ref) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + set_output_ref(start, mu, tensor_for_ref); + return Status::OK(); +} + +Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *tensor = mutable_output(start); + return Status::OK(); +} + +Status OpKernelContext::release_output(const string& name, TensorValue* value) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *value = release_output(start); + return Status::OK(); +} + +bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { + const auto& inputs = *params_.inputs; + for (size_t i = 1; i < inputs.size(); ++i) { + if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) { + SetStatus(errors::InvalidArgument( + "Inputs to operation ", op->name(), " of type ", op->type_string(), + " must have the same size and shape. Input 0: ", + inputs[0]->shape().DebugString(), " != input ", i, ": ", + inputs[i]->shape().DebugString())); + return false; + } + } + return true; +} + +Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs) { + DataTypeVector inputs; + for (const TensorValue& t : *params_.inputs) { + inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype()); + } + DataTypeVector outputs = params_.op_kernel->output_types(); + return MatchSignatureHelper(expected_inputs, expected_outputs, inputs, + outputs); +} + +// OpKernel registration ------------------------------------------------------ + +struct KernelRegistration { + KernelRegistration(const KernelDef& d, + kernel_factory::OpKernelRegistrar::Factory f) + : def(d), factory(f) {} + const KernelDef def; + const kernel_factory::OpKernelRegistrar::Factory factory; +}; + +// This maps from 'op_type' + DeviceType to the set of KernelDefs and +// factory functions for instantiating the OpKernel that matches the +// KernelDef. +typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry; + +static KernelRegistry* GlobalKernelRegistry() { + static KernelRegistry* global_kernel_registry = new KernelRegistry; + return global_kernel_registry; +} + +static string Key(const string& op_type, DeviceType device_type, + const string& label) { + return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", + label); +} + +namespace kernel_factory { + +OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def, + Factory factory) { + const string key = + Key(kernel_def->op(), DeviceType(kernel_def->device_type()), + kernel_def->label()); + GlobalKernelRegistry()->insert( + std::make_pair(key, KernelRegistration(*kernel_def, factory))); + delete kernel_def; +} + +} // namespace kernel_factory + +namespace { + +// Helper for AttrsMatch(). +bool InTypeList(DataType dt, const AttrValue& type_list) { + for (int in_list : type_list.list().type()) { + if (dt == in_list) return true; + } + return false; +} + +// Returns whether the attrs in the NodeDef satisfy the constraints in +// the kernel_def. Returns an error if attrs in kernel_def are not +// found, or have a mismatching type. +Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, + bool* match) { + *match = false; + AttrSlice attrs(node_def); + for (const auto& constraint : kernel_def.constraint()) { + if (constraint.allowed_values().list().type_size() == 0) { + return errors::Unimplemented( + "KernelDef '", kernel_def.ShortDebugString(), + " has constraint on attr '", constraint.name(), + "' with unsupported type: ", + SummarizeAttrValue(constraint.allowed_values())); + } + + const AttrValue* found = attrs.Find(constraint.name()); + if (found) { + if (found->type() != DT_INVALID) { + if (!InTypeList(found->type(), constraint.allowed_values())) { + return Status::OK(); + } + } else { + if (!AttrValueHasType(*found, "list(type)").ok()) { + return errors::InvalidArgument( + "KernelDef '", kernel_def.ShortDebugString(), + "' has constraint on attr '", constraint.name(), + "' that has value '", SummarizeAttrValue(*found), + "' that does not have type 'type' or 'list(type)' in NodeDef '", + SummarizeNodeDef(node_def), "'"); + } + + for (int t : found->list().type()) { + if (!InTypeList(static_cast<DataType>(t), + constraint.allowed_values())) { + return Status::OK(); + } + } + } + } else { + return errors::InvalidArgument( + "OpKernel '", kernel_def.op(), "' has constraint on attr '", + constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def), + "', KernelDef: '", kernel_def.ShortDebugString(), "'"); + } + } + *match = true; + return Status::OK(); +} + +Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, + const KernelRegistration** reg) { + *reg = nullptr; + string label; // Label defaults to empty if not found in NodeDef. + GetNodeAttr(node_def, "_kernel", &label); + const string key = Key(node_def.op(), device_type, label); + auto regs = GlobalKernelRegistry()->equal_range(key); + for (auto iter = regs.first; iter != regs.second; ++iter) { + // If there is a kernel registered for the op and device_type, + // check that the attrs match. + bool match; + TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match)); + if (match) { + if (*reg != nullptr) { + return errors::InvalidArgument( + "Multiple OpKernel registrations match NodeDef '", + SummarizeNodeDef(node_def), "': '", (*reg)->def.ShortDebugString(), + "' and '", iter->second.def.ShortDebugString(), "'"); + } + *reg = &iter->second; + } + } + return Status::OK(); +} + +} // namespace + +Status SupportedDeviceTypesForNode( + const std::vector<DeviceType>& prioritized_types, const NodeDef& def, + DeviceTypeVector* device_types) { + // TODO(zhifengc): Changes the callers (SimplePlacer and + // DynamicPlacer) to consider the possibility that 'def' is call to + // a user-defined function and only calls this + // SupportedDeviceTypesForNode for primitive ops. + Status s; + const OpDef* op_def = OpRegistry::Global()->LookUp(def.op(), &s); + if (op_def) { + for (const DeviceType& device_type : prioritized_types) { + const KernelRegistration* reg = nullptr; + TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, def, ®)); + if (reg != nullptr) device_types->push_back(device_type); + } + } else { + // Assumes that all device types support this node. + for (const DeviceType& device_type : prioritized_types) { + device_types->push_back(device_type); + } + } + return Status::OK(); +} + +std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type, + DeviceBase* device, + Allocator* allocator, + const NodeDef& node_def, + Status* status) { + OpKernel* kernel = nullptr; + *status = CreateOpKernel(device_type, device, allocator, nullptr, node_def, + &kernel); + return std::unique_ptr<OpKernel>(kernel); +} + +Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const NodeDef& node_def, OpKernel** kernel) { + VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); + + // Look up the Op registered for this op name. + Status s; + const OpDef* op_def = OpRegistry::Global()->LookUp(node_def.op(), &s); + if (op_def == nullptr) return s; + + // Validate node_def against OpDef. + s = ValidateNodeDef(node_def, *op_def); + if (!s.ok()) return s; + + // Look up kernel registration. + const KernelRegistration* registration; + s = FindKernelRegistration(device_type, node_def, ®istration); + if (!s.ok()) { + errors::AppendToMessage(&s, " when instantiating ", node_def.op()); + return s; + } + if (registration == nullptr) { + s.Update(errors::NotFound("No registered '", node_def.op(), + "' OpKernel for ", DeviceTypeString(device_type), + " devices compatible with node ", + SummarizeNodeDef(node_def))); + return s; + } + + // Get signature from the OpDef & NodeDef + DataTypeVector inputs; + DataTypeVector outputs; + s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); + if (!s.ok()) { + errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def)); + return s; + } + + // Everything needed for OpKernel construction. + OpKernelConstruction context(device_type, device, allocator, &node_def, + op_def, flib, inputs, outputs, &s); + *kernel = (*registration->factory)(&context); + if (!s.ok()) { + delete *kernel; + *kernel = nullptr; + } + return s; +} + +namespace { // Helper for MemoryTypesForNode. +// Fills memory_types for either input or output, setting everything +// to DEVICE_MEMORY except those args in host_memory_args. Removes +// elements of host_memory_args that were used. +void MemoryTypesHelper(const NameRangeMap& name_map, + std::vector<string>* host_memory_args, + MemoryTypeVector* memory_types) { + // Set total to the largest endpoint of anything in the name_map. + int total = 0; + for (const auto& item : name_map) { + total = std::max(total, item.second.second); + } + + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. + memory_types->clear(); + memory_types->resize(total, DEVICE_MEMORY); + + // Update args that have been marked as in "HOST_MEMORY". + size_t keep = 0; + for (size_t i = 0; i < host_memory_args->size(); ++i) { + auto iter = name_map.find((*host_memory_args)[i]); + if (iter != name_map.end()) { + for (int j = iter->second.first; j < iter->second.second; ++j) { + (*memory_types)[j] = HOST_MEMORY; + } + } else { + // (*host_memory_args)[i] not found, save it for the next pass. + if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i]; + ++keep; + } + } + host_memory_args->resize(keep); +} +} // namespace + +Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, + const OpDef& op_def, + const NameRangeMap& input_name_map, + const NameRangeMap& output_name_map, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + Status status; + const KernelRegistration* registration; + TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, ndef, ®istration)); + + if (registration != nullptr) { + const auto& from_proto = registration->def.host_memory_arg(); + std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); + MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types); + MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", str_util::Join(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(op_def)); + } + } + return status; +} + +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + DeviceType device_type, const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + // Look up the Op registered for this op name. + Status status; + const OpDef* op_def = op_registry->LookUp(ndef.op(), &status); + if (op_def == nullptr) return status; + + NameRangeMap inputs, outputs; + status = NameRangesForNode(ndef, *op_def, &inputs, &outputs); + if (!status.ok()) return status; + + return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs, + input_memory_types, output_memory_types); +} + +namespace { + +bool FindArgInOp(const string& arg_name, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { + for (const auto& arg : args) { + if (arg_name == arg.name()) { + return true; + } + } + return false; +} + +} // namespace + +Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry) { + Status unused_status; + for (const auto& key_registration : *GlobalKernelRegistry()) { + const KernelDef& kernel_def(key_registration.second.def); + const OpDef* op_def = op_registry->LookUp(kernel_def.op(), &unused_status); + if (op_def == nullptr) { + // TODO(josh11b): Make this a hard error. + LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString() + << "') for unknown op: " << kernel_def.op(); + continue; + } + for (const auto& host_memory_arg : kernel_def.host_memory_arg()) { + if (!FindArgInOp(host_memory_arg, op_def->input_arg()) && + !FindArgInOp(host_memory_arg, op_def->output_arg())) { + return errors::InvalidArgument("HostMemory arg '", host_memory_arg, + "' not found in OpDef: ", + SummarizeOpDef(*op_def)); + } + } + } + return Status::OK(); +} + +template <> +const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const { + return eigen_cpu_device(); +} + +template <> +const Eigen::GpuDevice& OpKernelContext::eigen_device() const { + return eigen_gpu_device(); +} + +} // namespace tensorflow |