diff options
31 files changed, 962 insertions, 385 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 37d9062612..dc0a2fc157 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -543,6 +543,7 @@ cc_library( "common_runtime/kernel_benchmark_testlib.h", "framework/fake_input.h", "framework/function_testlib.h", + "framework/shape_inference_testutil.h", "framework/tensor_testutil.h", "graph/testlib.h", # TODO(josh11b): Drop this once users are depending on @@ -559,6 +560,7 @@ cc_library( ":lib", ":proto_text", ":protos_all_cc", + ":shape_inference_testutil", ":tensor_testutil", ":test", "//tensorflow/core/kernels:constant_op", @@ -748,6 +750,8 @@ filegroup( srcs = [ "//tensorflow/core:framework/fake_input.cc", "//tensorflow/core:framework/fake_input.h", + "//tensorflow/core:framework/shape_inference_testutil.cc", + "//tensorflow/core:framework/shape_inference_testutil.h", "//tensorflow/core:framework/tensor_testutil.cc", "//tensorflow/core:framework/tensor_testutil.h", "//tensorflow/core:platform/test.h", @@ -1212,6 +1216,19 @@ cc_library( ], ) +cc_library( + name = "shape_inference_testutil", + testonly = 1, + srcs = ["framework/shape_inference_testutil.cc"], + hdrs = ["framework/shape_inference_testutil.h"], + copts = tf_copts(), + deps = [ + ":framework", + ":lib", + ":test", + ], +) + # Main program for tests cc_library( name = "test_main", diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index c243b7fc18..1ccf66ed34 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -301,9 +301,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( lib_def_(lib_def), optimizer_(optimizer_options) { get_func_sig_ = [this](const string& op, const OpDef** sig) { - Status s; - *sig = lib_def_->LookUp(op, &s); - return s; + return lib_def_->LookUpOpDef(op, sig); }; create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) { return CreateKernel(ndef, kernel); @@ -689,9 +687,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { - Status s; - auto sig = lib_def_->LookUp(func, &s); - return s.ok() && sig->is_stateful(); + const OpDef* op_def; + const Status s = lib_def_->LookUpOpDef(func, &op_def); + return s.ok() && op_def->is_stateful(); } FunctionLibraryRuntime* NewFunctionLibraryRuntime( diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 13c8de8384..d5fc71195f 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -38,9 +38,7 @@ namespace tensorflow { typedef FunctionDefHelper FDH; Status GetOpSig(const string& op, const OpDef** sig) { - Status s; - *sig = OpRegistry::Global()->LookUp(op, &s); - return s; + return OpRegistry::Global()->LookUpOpDef(op, sig); } void FunctionTestSchedClosure(std::function<void()> fn) { diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index be6dfd600c..0d437346a4 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -708,14 +708,19 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { FunctionLibraryDefinition::FunctionLibraryDefinition( const FunctionLibraryDefinition& other) - : function_defs_(other.function_defs_), func_grad_(other.func_grad_) {} + : func_grad_(other.func_grad_) { + for (const auto& it : other.function_defs_) { + TF_CHECK_OK(AddFunctionDef(it.second->fdef)); + } +} FunctionLibraryDefinition::FunctionLibraryDefinition( const FunctionDefLibrary& def_lib) : function_defs_(def_lib.function_size()) { for (const auto& fdef : def_lib.function()) { // The latter function definition wins. - function_defs_[fdef.signature().name()] = fdef; + auto& ptr = function_defs_[fdef.signature().name()]; + ptr.reset(new FunctionDefAndOpRegistration(fdef)); } for (const auto& grad : def_lib.gradient()) { func_grad_[grad.function_name()] = grad.gradient_func(); @@ -729,16 +734,18 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { if (iter == function_defs_.end()) { return nullptr; } else { - return &iter->second; + return &iter->second->fdef; } } Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { - if (!function_defs_.insert({fdef.signature().name(), fdef}).second) { + auto& ptr = function_defs_[fdef.signature().name()]; + if (ptr != nullptr) { return errors::InvalidArgument("Function with name: ", fdef.signature().name(), " already exists in function library."); } + ptr.reset(new FunctionDefAndOpRegistration(fdef)); return Status::OK(); } @@ -746,19 +753,20 @@ string FunctionLibraryDefinition::FindGradient(const string& func) const { return gtl::FindWithDefault(func_grad_, func, ""); } -const OpDef* FunctionLibraryDefinition::LookUp(const string& op, - Status* status) const { - auto fdef = Find(op); - if (fdef != nullptr) { - return &(fdef->signature()); +Status FunctionLibraryDefinition::LookUp( + const string& op, const OpRegistrationData** op_reg_data) const { + auto iter = function_defs_.find(op); + if (iter != function_defs_.end()) { + *op_reg_data = &iter->second->op_registration_data; + return Status::OK(); } - return OpRegistry::Global()->LookUp(op, status); + return OpRegistry::Global()->LookUp(op, op_reg_data); } FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { FunctionDefLibrary lib; for (const auto& f : function_defs_) { - *lib.add_function() = f.second; + *lib.add_function() = f.second->fdef; } for (const auto& g : func_grad_) { GradientDef* gd = lib.add_gradient(); @@ -845,7 +853,10 @@ FunctionDef FunctionDefHelper::Define(const string& name, for (const auto& a : arg_def) b.Input(a); for (const auto& r : ret_def) b.Output(r); for (const auto& a : attr_def) b.Attr(a); - TF_CHECK_OK(b.Finalize(fdef.mutable_signature())); + + OpRegistrationData op_reg_data; + TF_CHECK_OK(b.Finalize(&op_reg_data)); + fdef.mutable_signature()->Swap(&op_reg_data.op_def); for (const auto& n : node_def) { *(fdef.add_node()) = n.ToProto(); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 0b05aa4a8d..708861fb9d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -277,14 +277,25 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // // If "op" is defined in the library, returns its signature. // Otherwise, assume "op" is a primitive op and returns its op - // signature. - const OpDef* LookUp(const string& op, Status* status) const override; + // signature and shape inference function. + Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const override; // Returns a proto representation of the state of this function library. FunctionDefLibrary ToProto() const; private: - std::unordered_map<string, FunctionDef> function_defs_; + // TODO(cwhipkey): support shape functions in FunctionDefLibrary. + struct FunctionDefAndOpRegistration { + FunctionDefAndOpRegistration(const FunctionDef& fdef_in) + : fdef(fdef_in), op_registration_data(fdef.signature()) {} + + FunctionDef fdef; + OpRegistrationData op_registration_data; + }; + + std::unordered_map<string, std::unique_ptr<FunctionDefAndOpRegistration>> + function_defs_; std::unordered_map<string, string> func_grad_; }; diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 5849051062..d7e967ec3d 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -33,9 +33,7 @@ namespace tensorflow { typedef FunctionDefHelper FDH; Status GetOpSig(const string& op, const OpDef** sig) { - Status s; - *sig = OpRegistry::Global()->LookUp(op, &s); - return s; + return OpRegistry::Global()->LookUpOpDef(op, sig); } REGISTER_OP("One") @@ -643,12 +641,12 @@ TEST(FunctionLibraryDefinitionTest, LookUp) { *proto.add_function() = test::function::XTimesTwo(); FunctionLibraryDefinition lib_def(proto); - Status s; - EXPECT_EQ(lib_def.LookUp("XTimes16", &s), nullptr); + const OpDef* op_def; + EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok()); - auto found = lib_def.LookUp("XTimesTwo", &s); - ASSERT_NE(found, nullptr); - EXPECT_EQ(found->DebugString(), + TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def)); + ASSERT_NE(op_def, nullptr); + EXPECT_EQ(op_def->DebugString(), test::function::XTimesTwo().signature().DebugString()); } @@ -662,14 +660,15 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); // Test lookup of first function. - Status s; - auto first = lib_def.LookUp("XTimesTwo", &s); + const OpDef* first; + TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first)); ASSERT_NE(first, nullptr); EXPECT_EQ(first->DebugString(), test::function::XTimesTwo().signature().DebugString()); // Test lookup of second function. - auto second = lib_def.LookUp("WXPlusB", &s); + const OpDef* second; + TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second)); ASSERT_NE(second, nullptr); EXPECT_EQ(second->DebugString(), test::function::WXPlusB().signature().DebugString()); @@ -689,18 +688,14 @@ TEST(FunctionLibraryDefinitionTest, ToProto) { FunctionLibraryDefinition lib_def2(proto2); // Test that the first function exists in both libraries. - Status s; - auto f1 = lib_def1.LookUp("XTimesTwo", &s); - TF_EXPECT_OK(s); - auto f2 = lib_def1.LookUp("XTimesTwo", &s); - TF_EXPECT_OK(s); + const OpDef *f1, *f2, *f3, *f4; + TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1)); + TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2)); EXPECT_EQ(f1->DebugString(), f2->DebugString()); // Test that the second function exists in both libraries. - auto f3 = lib_def1.LookUp("WXPlusB", &s); - TF_EXPECT_OK(s); - auto f4 = lib_def1.LookUp("WXPlusB", &s); - TF_EXPECT_OK(s); + TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3)); + TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4)); EXPECT_EQ(f3->DebugString(), f4->DebugString()); } diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index e1e0729e8d..35b765cc51 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -56,17 +56,14 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, node_offset, " with total nodes in graph: ", graph_def->node_size()); } - Status s; for (int i = node_offset; i < graph_def->node_size(); ++i) { NodeDef* node_def = graph_def->mutable_node(i); - const OpDef* op_def = op_registry.LookUp(node_def->op(), &s); - if (!s.ok()) { - return s; - } + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def)); AddDefaultsToNodeDef(*op_def, node_def); } - return s; + return Status::OK(); } Status RemoveNewDefaultAttrsFromGraphDef( @@ -77,12 +74,13 @@ Status RemoveNewDefaultAttrsFromGraphDef( std::vector<string> to_remove; for (int n = 0; n < graph_def->node_size(); ++n) { NodeDef* node_def = graph_def->mutable_node(n); - const OpDef* producer_op_def = - producer_op_registry.LookUp(node_def->op(), &s); - if (!s.ok()) return s; - const OpDef* consumer_op_def = - consumer_op_registry.LookUp(node_def->op(), &s); - if (!s.ok()) return s; + const OpDef* producer_op_def; + const OpDef* consumer_op_def; + + TF_RETURN_IF_ERROR( + producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def)); + TF_RETURN_IF_ERROR( + consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def)); for (const auto& attr : node_def->attr()) { // If the attr is not in consumer_op_def and doesn't start with '_'... @@ -172,13 +170,12 @@ Status StrippedOpListForGraph(const GraphDef& graph_def, OpsUsedByGraph(graph_def, &used_ops); // Build the stripped op list in sorted order, ignoring functions. - Status status; stripped_op_list->clear_op(); for (const string& op_name : used_ops) { - const OpDef* op = op_registry.LookUp(op_name, &status); - if (!op) return status; + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def)); OpDef* stripped_op = stripped_op_list->add_op(); - stripped_op->CopyFrom(*op); + stripped_op->CopyFrom(*op_def); RemoveDescriptionsFromOpDef(stripped_op); } return Status::OK(); diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc index b4df868193..98f6e9b89b 100644 --- a/tensorflow/core/framework/graph_def_util_test.cc +++ b/tensorflow/core/framework/graph_def_util_test.cc @@ -27,12 +27,19 @@ limitations under the License. namespace tensorflow { namespace { +Status FinalizeOpDef(OpDefBuilder b, OpDef* op_def) { + OpRegistrationData op_reg_data; + const Status s = b.Finalize(&op_reg_data); + *op_def = op_reg_data.op_def; + return s; +} + // Producer and consumer have default for an attr -> graph unchanged. TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) { OpList op_list; - TF_ASSERT_OK(OpDefBuilder("NoChangeWithDefault") - .Attr("a: int = 12") - .Finalize(op_list.add_op())); + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"), + op_list.add_op())); OpListOpRegistry registry(&op_list); GraphDef graph_def; @@ -51,9 +58,8 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) { // Producer and consumer both have an attr -> graph unchanged. TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) { OpList op_list; - TF_ASSERT_OK(OpDefBuilder("NoChangeNoDefault") - .Attr("a: int") - .Finalize(op_list.add_op())); + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"), + op_list.add_op())); OpListOpRegistry registry(&op_list); GraphDef graph_def; @@ -75,13 +81,13 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) { // attr removed from graph (and so able to be consumed). TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) { OpList consumer_op_list; - TF_ASSERT_OK(OpDefBuilder("UsesDefault").Finalize(consumer_op_list.add_op())); + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op())); OpListOpRegistry consumer_registry(&consumer_op_list); OpList producer_op_list; - TF_ASSERT_OK(OpDefBuilder("UsesDefault") - .Attr("a: int = 17") - .Finalize(producer_op_list.add_op())); + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"), + producer_op_list.add_op())); OpListOpRegistry producer_registry(&producer_op_list); GraphDef produced_graph_def; @@ -107,14 +113,14 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) { // graph unchanged (but not able to be consumed by consumer). TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) { OpList consumer_op_list; - TF_ASSERT_OK( - OpDefBuilder("ChangedFromDefault").Finalize(consumer_op_list.add_op())); + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), + consumer_op_list.add_op())); OpListOpRegistry consumer_registry(&consumer_op_list); OpList producer_op_list; - TF_ASSERT_OK(OpDefBuilder("ChangedFromDefault") - .Attr("a: int = 17") - .Finalize(producer_op_list.add_op())); + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"), + producer_op_list.add_op())); OpListOpRegistry producer_registry(&producer_op_list); GraphDef produced_graph_def; @@ -136,11 +142,13 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) { // Attrs starting with underscores should not be removed. TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) { OpList consumer_op_list; - TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(consumer_op_list.add_op())); + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op())); OpListOpRegistry consumer_registry(&consumer_op_list); OpList producer_op_list; - TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(producer_op_list.add_op())); + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op())); // Add the _underscore attr manually since OpDefBuilder would complain OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr(); attr->set_name("_underscore"); diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 776bd5e2fc..d538228494 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -83,13 +83,12 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, MemoryTypeVector* inp_mtypes, MemoryTypeVector* out_mtypes) { // Look up the Op registered for this op name. - Status status; - const OpDef* op_def = op_registry->LookUp(ndef.op(), &status); - if (op_def == nullptr) return status; + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(ndef.op(), &op_def)); // Look up the Kernel registered for this node def. const KernelDef* kdef = nullptr; - status = + Status status = FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */); if (!status.ok() || HasTypeList(*op_def)) { diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index 47b9fa5b9e..9385d1266a 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -38,13 +38,12 @@ void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) { NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, const OpRegistryInterface* op_registry) { node_def_.set_name(name.ToString()); - Status status; - op_def_ = op_registry->LookUp(op_name.ToString(), &status); - if (op_def_ == nullptr) { + const Status status = op_registry->LookUpOpDef(op_name.ToString(), &op_def_); + if (status.ok()) { + Initialize(); + } else { errors_.push_back(status.error_message()); inputs_specified_ = 0; - } else { - Initialize(); } } diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc index d2b29397cd..861ad9d7d0 100644 --- a/tensorflow/core/framework/node_def_builder_test.cc +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -32,7 +32,9 @@ class NodeDefBuilderTest : public ::testing::Test { protected: // Specify an OpDef via an OpDefBuilder. void Op(const OpDefBuilder& op_def_builder) { - TF_EXPECT_OK(op_def_builder.Finalize(&op_def_)); + OpRegistrationData op_reg_data; + TF_EXPECT_OK(op_def_builder.Finalize(&op_reg_data)); + op_def_ = op_reg_data.op_def; } // Resets builder_ with a new NodeDefBuilder using the Op from the last call diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index bc02f39d30..b60709c5a4 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -28,9 +28,9 @@ namespace tensorflow { namespace { OpDef ToOpDef(const OpDefBuilder& builder) { - OpDef op_def; - TF_EXPECT_OK(builder.Finalize(&op_def)); - return op_def; + OpRegistrationData op_reg_data; + TF_EXPECT_OK(builder.Finalize(&op_reg_data)); + return op_reg_data.op_def; } NodeDef ToNodeDef(const string& text) { diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index bbe0513814..2ea735a790 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -32,38 +32,51 @@ namespace tensorflow { OpRegistryInterface::~OpRegistryInterface() {} +Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, + const OpDef** op_def) const { + *op_def = nullptr; + const OpRegistrationData* op_reg_data = nullptr; + TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data)); + *op_def = &op_reg_data->op_def; + return Status::OK(); +} + OpRegistry::OpRegistry() : initialized_(false) {} -void OpRegistry::Register(const OpDef& op_def) { +OpRegistry::~OpRegistry() { + for (const auto& e : registry_) delete e.second; +} + +void OpRegistry::Register(std::unique_ptr<OpRegistrationData> op_reg_data) { + OpRegistrationData* raw_ptr = op_reg_data.get(); + mutex_lock lock(mu_); if (initialized_) { - TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: " - << SummarizeOpDef(op_def); + TF_QCHECK_OK(RegisterAlreadyLocked(std::move(op_reg_data))); } else { - deferred_.push_back(op_def); + deferred_.push_back(std::move(op_reg_data)); } if (watcher_) { - watcher_(op_def); + watcher_(raw_ptr->op_def); } } -const OpDef* OpRegistry::LookUp(const string& op_type_name, - Status* status) const { - const OpDef* op_def = nullptr; +Status OpRegistry::LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const { + *op_reg_data = nullptr; + const OpRegistrationData* res = nullptr; + bool first_call = false; { // Scope for lock. mutex_lock lock(mu_); first_call = CallDeferred(); - op_def = gtl::FindWithDefault(registry_, op_type_name, nullptr); + res = gtl::FindWithDefault(registry_, op_type_name, nullptr); // Note: Can't hold mu_ while calling Export() below. } if (first_call) { TF_QCHECK_OK(ValidateKernelRegistrations(*this)); } - if (op_def == nullptr) { - status->Update( - errors::NotFound("Op type not registered '", op_type_name, "'")); - VLOG(1) << status->ToString(); + if (res == nullptr) { static bool first_unregistered = true; if (first_unregistered) { OpList op_list; @@ -74,15 +87,20 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name, } first_unregistered = false; } + Status status = + errors::NotFound("Op type not registered '", op_type_name, "'"); + VLOG(1) << status.ToString(); + return status; } - return op_def; + *op_reg_data = res; + return Status::OK(); } void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) { mutex_lock lock(mu_); CallDeferred(); - for (auto p : registry_) { - op_defs->push_back(*p.second); + for (const auto& p : registry_) { + op_defs->push_back(p.second->op_def); } } @@ -100,8 +118,8 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const { mutex_lock lock(mu_); CallDeferred(); - std::vector<std::pair<string, const OpDef*>> sorted(registry_.begin(), - registry_.end()); + std::vector<std::pair<string, const OpRegistrationData*>> sorted( + registry_.begin(), registry_.end()); std::sort(sorted.begin(), sorted.end()); auto out = ops->mutable_op(); @@ -110,7 +128,7 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const { for (const auto& item : sorted) { if (include_internal || !StringPiece(item.first).starts_with("_")) { - *out->Add() = *item.second; + *out->Add() = item.second->op_def; } } } @@ -128,23 +146,23 @@ string OpRegistry::DebugString(bool include_internal) const { bool OpRegistry::CallDeferred() const { if (initialized_) return false; initialized_ = true; - for (const auto& op_def : deferred_) { - TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: " - << SummarizeOpDef(op_def); + for (int i = 0; i < deferred_.size(); ++i) { + TF_QCHECK_OK(RegisterAlreadyLocked(std::move(deferred_[i]))); } deferred_.clear(); return true; } -Status OpRegistry::RegisterAlreadyLocked(const OpDef& def) const { - TF_RETURN_IF_ERROR(ValidateOpDef(def)); +Status OpRegistry::RegisterAlreadyLocked( + std::unique_ptr<OpRegistrationData> op_reg_data) const { + TF_RETURN_IF_ERROR(ValidateOpDef(op_reg_data->op_def)); - std::unique_ptr<OpDef> copy(new OpDef(def)); - if (gtl::InsertIfNotPresent(®istry_, def.name(), copy.get())) { - copy.release(); // Ownership transferred to op_registry + if (gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), + op_reg_data.get())) { + op_reg_data.release(); // Ownership transferred to op_registry return Status::OK(); } else { - return errors::AlreadyExists("Op with name ", def.name()); + return errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); } } @@ -158,19 +176,25 @@ OpRegistry* OpRegistry::Global() { OpListOpRegistry::OpListOpRegistry(const OpList* op_list) { for (const OpDef& op_def : op_list->op()) { - index_[op_def.name()] = &op_def; + auto* op_reg_data = new OpRegistrationData(); + op_reg_data->op_def = op_def; + index_[op_def.name()] = op_reg_data; } } -const OpDef* OpListOpRegistry::LookUp(const string& op_type_name, - Status* status) const { +OpListOpRegistry::~OpListOpRegistry() { + for (const auto& e : index_) delete e.second; +} + +Status OpListOpRegistry::LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const { auto iter = index_.find(op_type_name); if (iter == index_.end()) { - status->Update( - errors::NotFound("Op type not registered '", op_type_name, "'")); - return nullptr; + *op_reg_data = nullptr; + return errors::NotFound("Op type not registered '", op_type_name, "'"); } - return iter->second; + *op_reg_data = iter->second; + return Status::OK(); } // Other registration --------------------------------------------------------- @@ -178,9 +202,9 @@ const OpDef* OpListOpRegistry::LookUp(const string& op_type_name, namespace register_op { OpDefBuilderReceiver::OpDefBuilderReceiver( const OpDefBuilderWrapper<true>& wrapper) { - OpDef op_def; - wrapper.builder().Finalize(&op_def); - OpRegistry::Global()->Register(op_def); + std::unique_ptr<OpRegistrationData> data(new OpRegistrationData); + wrapper.builder().Finalize(data.get()); + OpRegistry::Global()->Register(std::move(data)); } } // namespace register_op diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 2a99597721..cd484dacd4 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -42,11 +43,14 @@ class OpRegistryInterface { public: virtual ~OpRegistryInterface(); - // Returns nullptr and sets *status if no OpDef is registered under that - // name, otherwise returns the registered OpDef. + // Returns an error status and sets *op_reg_data to nullptr if no OpDef is + // registered under that name, otherwise returns the registered OpDef. // Caller must not delete the returned pointer. - virtual const OpDef* LookUp(const string& op_type_name, - Status* status) const = 0; + virtual Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const = 0; + + // Shorthand for calling LookUp to get the OpDef. + Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const; }; // The standard implementation of OpRegistryInterface, along with a @@ -62,17 +66,13 @@ class OpRegistryInterface { class OpRegistry : public OpRegistryInterface { public: OpRegistry(); - ~OpRegistry() override {} + ~OpRegistry() override; - // Calls func() and registers the returned OpDef. Since Register() - // is normally called during program initialization (before main()), - // we defer calling func() until the first call to LookUp() or - // Export() (if one of those has already been called, func() is - // called immediately). - void Register(const OpDef& op_def); + // Calls watcher and registers the passed OpDef. + void Register(std::unique_ptr<OpRegistrationData> op_data); - const OpDef* LookUp(const string& op_type_name, - Status* status) const override; + Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const override; // Fills *ops with all registered OpDefs (except those with names // starting with '_' if include_internal == false). @@ -111,15 +111,19 @@ class OpRegistry : public OpRegistryInterface { // time it is called. bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Add 'def' to the registry. On failure, or if there is already an - // OpDef with that name registered, returns a non-okay status. - Status RegisterAlreadyLocked(const OpDef& def) const - EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Add 'def' to the registry with additional data 'data'. On failure, or if + // there is already an OpDef with that name registered, returns a non-okay + // status. + Status RegisterAlreadyLocked(std::unique_ptr<OpRegistrationData> op_data) + const EXCLUSIVE_LOCKS_REQUIRED(mu_); mutable mutex mu_; // Functions in deferred_ may only be called with mu_ held. - mutable std::vector<OpDef> deferred_ GUARDED_BY(mu_); - mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_); + mutable std::vector<std::unique_ptr<OpRegistrationData>> deferred_ + GUARDED_BY(mu_); + // Values are owned. + mutable std::unordered_map<string, const OpRegistrationData*> registry_ + GUARDED_BY(mu_); mutable bool initialized_ GUARDED_BY(mu_); // Registry watcher. @@ -127,16 +131,21 @@ class OpRegistry : public OpRegistryInterface { }; // An adapter to allow an OpList to be used as an OpRegistryInterface. +// +// Note that shape inference functions are not passed in to OpListOpRegistry, so +// it will return an unusable shape inference function for every op it supports; +// therefore, it should only be used in contexts where this is okay. class OpListOpRegistry : public OpRegistryInterface { public: // Does not take ownership of op_list, *op_list must outlive *this. OpListOpRegistry(const OpList* op_list); - ~OpListOpRegistry() override {} - const OpDef* LookUp(const string& op_type_name, - Status* status) const override; + ~OpListOpRegistry() override; + Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const override; private: - std::unordered_map<string, const OpDef*> index_; + // Values are owned. + std::unordered_map<string, const OpRegistrationData*> index_; }; // Treats 'registry_ptr' as a pointer to OpRegistry, and calls @@ -217,6 +226,10 @@ class OpDefBuilderWrapper<true> { builder_.Doc(text); return *this; } + OpDefBuilderWrapper<true>& SetShapeFn(const OpShapeInferenceFn& fn) { + builder_.SetShapeFn(fn); + return *this; + } const ::tensorflow::OpDefBuilder& builder() const { return builder_; } private: @@ -237,6 +250,9 @@ class OpDefBuilderWrapper<false> { OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; } OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; } OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; } + OpDefBuilderWrapper<false>& SetShapeFn(const OpShapeInferenceFn& fn) { + return *this; + } }; struct OpDefBuilderReceiver { diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc index 9cb1ead451..fd6a988563 100644 --- a/tensorflow/core/framework/op_compatibility_test.cc +++ b/tensorflow/core/framework/op_compatibility_test.cc @@ -40,11 +40,9 @@ class TestKernel : public OpKernel { class OpCompatibilityTest : public OpsTestBase { protected: const OpDef* RegisteredOpDef() { - Status status; - const OpDef* new_op_def = - OpRegistry::Global()->LookUp(node_def()->op(), &status); - TF_CHECK_OK(status); - return new_op_def; + const OpDef* op_def; + TF_CHECK_OK(OpRegistry::Global()->LookUpOpDef(node_def()->op(), &op_def)); + return op_def; } void ExpectSuccess(const OpDef& old_op_def) { @@ -170,11 +168,11 @@ REGISTER_OP("AddAttr").Output("ndef: string").Attr("a: int = 42"); REGISTER_KERNEL_BUILDER(Name("AddAttr").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AddAttr) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK( - OpDefBuilder("AddAttr").Output("ndef: string").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("add_attr", &old_op_def).Finalize(node_def())); - ExpectSuccess(old_op_def); + OpDefBuilder("AddAttr").Output("ndef: string").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_attr", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); EXPECT_EQ("add_attr = AddAttr[a=42]()", Result()); } @@ -183,15 +181,15 @@ REGISTER_OP("LessStrict").Output("ndef: string").Attr("a: {'A', 'B', 'C'}"); REGISTER_KERNEL_BUILDER(Name("LessStrict").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, LessStrict) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("LessStrict") .Output("ndef: string") .Attr("a: {'A', 'B'}") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("less_strict", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("less_strict", &old_op.op_def) .Attr("a", "B") .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("less_strict = LessStrict[a=\"B\"]()", Result()); } @@ -201,15 +199,15 @@ REGISTER_KERNEL_BUILDER(Name("RemoveRestriction").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, RemoveRestriction) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("RemoveRestriction") .Output("ndef: string") .Attr("a: {int32, bool}") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op.op_def) .Attr("a", DT_INT32) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("remove_restriction = RemoveRestriction[a=DT_INT32]()", Result()); } @@ -218,17 +216,17 @@ REGISTER_OP("AttrOrder").Output("ndef: string").Attr("a: int").Attr("b: bool"); REGISTER_KERNEL_BUILDER(Name("AttrOrder").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AttrOrder) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AttrOrder") .Output("ndef: string") .Attr("b: bool") .Attr("a: int") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("attr_order", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("attr_order", &old_op.op_def) .Attr("b", true) .Attr("a", 7) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("attr_order = AttrOrder[a=7, b=true]()", Result()); } @@ -237,15 +235,15 @@ REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234"); REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AddDefault) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AddDefault") .Output("ndef: string") .Attr("a: int") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def) .Attr("a", 765) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("add_default = AddDefault[a=765]()", Result()); } @@ -255,14 +253,14 @@ REGISTER_OP("RemoveDefault").Output("ndef: string").Attr("a: int"); REGISTER_KERNEL_BUILDER(Name("RemoveDefault").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, RemoveDefault) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("RemoveDefault") .Output("ndef: string") .Attr("a: int = 91") - .Finalize(&old_op_def)); + .Finalize(&old_op)); TF_ASSERT_OK( - NodeDefBuilder("remove_default", &old_op_def).Finalize(node_def())); - ExpectSuccess(old_op_def); + NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); EXPECT_EQ("remove_default = RemoveDefault[a=91]()", Result()); } @@ -275,15 +273,15 @@ REGISTER_OP("TypePolymorphic") REGISTER_KERNEL_BUILDER(Name("TypePolymorphic").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, TypePolymorphic) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("TypePolymorphic") .Input("a: int32") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("type_polymorphic", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("type_polymorphic", &old_op.op_def) .Input(FakeInput()) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("type_polymorphic = TypePolymorphic[T=DT_INT32](a)", Result()); } @@ -296,15 +294,15 @@ REGISTER_OP("MakeList") REGISTER_KERNEL_BUILDER(Name("MakeList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, MakeList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("MakeList") .Input("a: int32") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def) .Input(FakeInput()) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("make_list = MakeList[N=1](a)", Result()); } @@ -319,15 +317,15 @@ REGISTER_OP("MakePolyList") REGISTER_KERNEL_BUILDER(Name("MakePolyList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, MakePolyList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("MakePolyList") .Input("a: int32") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("make_poly_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_poly_list", &old_op.op_def) .Input(FakeInput()) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("make_poly_list = MakePolyList[N=1, T=DT_INT32](a)", Result()); } @@ -340,15 +338,15 @@ REGISTER_OP("MakeAnyList") REGISTER_KERNEL_BUILDER(Name("MakeAnyList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, MakeAnyList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("MakeAnyList") .Input("a: int32") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("make_any_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_any_list", &old_op.op_def) .Input(FakeInput()) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("make_any_list = MakeAnyList[T=[DT_INT32]](a)", Result()); } @@ -362,16 +360,16 @@ REGISTER_OP("PolyIntoList") REGISTER_KERNEL_BUILDER(Name("PolyIntoList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, PolyIntoList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("PolyIntoList") .Input("a: T") .Output("ndef: string") .Attr("T: type") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("poly_into_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("poly_into_list", &old_op.op_def) .Input(FakeInput(DT_INT32)) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("poly_into_list = PolyIntoList[N=1, T=DT_INT32](a)", Result()); } @@ -387,17 +385,17 @@ REGISTER_KERNEL_BUILDER(Name("MakeMultipleSameList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, MakeMultipleSameList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("MakeMultipleSameList") .Input("a: int32") .Input("b: int32") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def) .Input(FakeInput()) .Input(FakeInput()) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("make_list = MakeMultipleSameList[N=2](a, b)", Result()); } @@ -411,17 +409,17 @@ REGISTER_KERNEL_BUILDER(Name("MakeMultipleAnyList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, MakeMultipleAnyList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("MakeMultipleAnyList") .Input("a: int32") .Input("b: float") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def) .Input(FakeInput()) .Input(FakeInput()) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("make_list = MakeMultipleAnyList[T=[DT_INT32, DT_FLOAT]](a, b)", Result()); } @@ -431,15 +429,15 @@ REGISTER_OP("ChangeName").Input("y: int32").Output("ndef: string"); REGISTER_KERNEL_BUILDER(Name("ChangeName").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, ChangeName) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("ChangeName") .Input("x: int32") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("change_name", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("change_name", &old_op.op_def) .Input(FakeInput()) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("change_name = ChangeName[](a)", Result()); } @@ -452,11 +450,12 @@ REGISTER_OP("AddNInts") REGISTER_KERNEL_BUILDER(Name("AddNInts").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AddNInts) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK( - OpDefBuilder("AddNInts").Output("ndef: string").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("add_n_ints", &old_op_def).Finalize(node_def())); - ExpectSuccess(old_op_def); + OpDefBuilder("AddNInts").Output("ndef: string").Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("add_n_ints", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); EXPECT_EQ("add_n_ints = AddNInts[N=0]()", Result()); } @@ -470,11 +469,12 @@ REGISTER_OP("AddNSame") REGISTER_KERNEL_BUILDER(Name("AddNSame").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AddNSame) { - OpDef old_op_def; + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AddNSame").Output("ndef: string").Finalize(&old_op)); TF_ASSERT_OK( - OpDefBuilder("AddNSame").Output("ndef: string").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("add_n_same", &old_op_def).Finalize(node_def())); - ExpectSuccess(old_op_def); + NodeDefBuilder("add_n_same", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); EXPECT_EQ("add_n_same = AddNSame[N=0, T=DT_BOOL]()", Result()); } @@ -490,16 +490,16 @@ REGISTER_KERNEL_BUILDER(Name("AddNSameAsExisting").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AddNSameAsExisting) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AddNSameAsExisting") .Input("a: T") .Output("ndef: string") .Attr("T: type") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("add_n_same_as_existing", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_n_same_as_existing", &old_op.op_def) .Input(FakeInput(DT_STRING)) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("add_n_same_as_existing = AddNSameAsExisting[N=0, T=DT_STRING](a)", Result()); } @@ -513,12 +513,12 @@ REGISTER_OP("AddAnyList") REGISTER_KERNEL_BUILDER(Name("AddAnyList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AddAnyList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK( - OpDefBuilder("AddAnyList").Output("ndef: string").Finalize(&old_op_def)); + OpDefBuilder("AddAnyList").Output("ndef: string").Finalize(&old_op)); TF_ASSERT_OK( - NodeDefBuilder("add_any_list", &old_op_def).Finalize(node_def())); - ExpectSuccess(old_op_def); + NodeDefBuilder("add_any_list", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); EXPECT_EQ("add_any_list = AddAnyList[T=[]]()", Result()); } @@ -530,16 +530,16 @@ REGISTER_OP("ShorterAnyList") REGISTER_KERNEL_BUILDER(Name("ShorterAnyList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, ShorterAnyList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("ShorterAnyList") .Input("a: T") .Output("ndef: string") .Attr("T: list(type) >= 2") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("shorter_any_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("shorter_any_list", &old_op.op_def) .Input(FakeInput(2, DT_BOOL)) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("shorter_any_list = ShorterAnyList[T=[DT_BOOL, DT_BOOL]](a, a:1)", Result()); } @@ -551,16 +551,16 @@ REGISTER_OP("ShorterSameList") REGISTER_KERNEL_BUILDER(Name("ShorterSameList").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, ShorterSameList) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("ShorterSameList") .Input("a: N * int32") .Output("ndef: string") .Attr("N: int >= 2") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("shorter_same_list", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("shorter_same_list", &old_op.op_def) .Input(FakeInput(2)) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("shorter_same_list = ShorterSameList[N=2](a, a:1)", Result()); } @@ -571,15 +571,15 @@ REGISTER_KERNEL_BUILDER(Name("AttrRemoveRestriction").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AttrRemoveRestriction) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AttrRemoveRestriction") .Attr("t: {int32,int64}") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op.op_def) .Attr("t", DT_INT32) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("remove_restriction = AttrRemoveRestriction[t=DT_INT32]()", Result()); } @@ -593,15 +593,15 @@ REGISTER_KERNEL_BUILDER(Name("AttrLessRestrictive").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AttrLessRestrictive) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AttrLessRestrictive") .Attr("t: {int32, int64}") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("less_restrictive", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("less_restrictive", &old_op.op_def) .Attr("t", DT_INT32) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("less_restrictive = AttrLessRestrictive[t=DT_INT32]()", Result()); } @@ -611,15 +611,15 @@ REGISTER_OP("AttrRemoveMin").Attr("n: int").Output("ndef: string"); REGISTER_KERNEL_BUILDER(Name("AttrRemoveMin").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AttrRemoveMin) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AttrRemoveMin") .Attr("n: int >= 3") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("remove_min", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("remove_min", &old_op.op_def) .Attr("n", 4) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("remove_min = AttrRemoveMin[n=4]()", Result()); } @@ -629,15 +629,15 @@ REGISTER_OP("AttrLowerMin").Attr("n: int >= 1").Output("ndef: string"); REGISTER_KERNEL_BUILDER(Name("AttrLowerMin").Device(DEVICE_CPU), TestKernel); TEST_F(OpCompatibilityTest, AttrLowerMin) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AttrLowerMin") .Attr("n: int >= 3") .Output("ndef: string") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("lower_min", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("lower_min", &old_op.op_def) .Attr("n", 4) .Finalize(node_def())); - ExpectSuccess(old_op_def); + ExpectSuccess(old_op.op_def); EXPECT_EQ("lower_min = AttrLowerMin[n=4]()", Result()); } @@ -647,11 +647,12 @@ TEST_F(OpCompatibilityTest, AttrLowerMin) { REGISTER_OP("RemoveAttr"); TEST_F(OpCompatibilityTest, RemoveAttrFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("RemoveAttr").Attr("a: int").Finalize(&old_op_def)); - TF_ASSERT_OK( - NodeDefBuilder("fails", &old_op_def).Attr("a", 3).Finalize(node_def())); - ExpectInvalid(old_op_def, "NodeDef mentions attr 'a' not in", + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RemoveAttr").Attr("a: int").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Attr("a", 3) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, "NodeDef mentions attr 'a' not in", "Attr 'a' removed"); } @@ -659,10 +660,10 @@ TEST_F(OpCompatibilityTest, RemoveAttrFails) { REGISTER_OP("AddAttrNoDefault").Attr("a: int"); TEST_F(OpCompatibilityTest, AddAttrNoDefaultFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("AddAttrNoDefault").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def())); - ExpectInvalid(old_op_def, "NodeDef missing attr 'a'", + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddAttrNoDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, "NodeDef missing attr 'a'", "Attr 'a' added without default"); } @@ -670,10 +671,10 @@ TEST_F(OpCompatibilityTest, AddAttrNoDefaultFails) { REGISTER_OP("AddSingleInput").Input("a: int32"); TEST_F(OpCompatibilityTest, AddSingleInputFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("AddSingleInput").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def())); - ExpectInvalid(old_op_def, + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddSingleInput").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, "expected inputs 'int32' do not match 0 inputs specified", "Input signature mismatch '' vs. 'int32'"); } @@ -690,28 +691,28 @@ REGISTER_OP("AddListBigDefault") .Attr("T: list(type) = [DT_INT32]"); TEST_F(OpCompatibilityTest, AddNIntsBigDefaultFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("AddNIntsBigDefault").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def())); - ExpectInvalid(old_op_def, + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddNIntsBigDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, "expected inputs 'int32' do not match 0 inputs specified", "Input signature mismatch '' vs. 'int32'"); } TEST_F(OpCompatibilityTest, AddNSameBigDefaultFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("AddNSameBigDefault").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def())); - ExpectInvalid(old_op_def, + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddNSameBigDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, "expected inputs 'int32' do not match 0 inputs specified", "Input signature mismatch '' vs. 'int32'"); } TEST_F(OpCompatibilityTest, AddListBigDefaultFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("AddListBigDefault").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def())); - ExpectInvalid(old_op_def, + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddListBigDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, "expected inputs 'int32' do not match 0 inputs specified", "Input signature mismatch '' vs. 'int32'"); } @@ -721,13 +722,12 @@ TEST_F(OpCompatibilityTest, AddListBigDefaultFails) { REGISTER_OP("ChangeType").Input("a: float"); TEST_F(OpCompatibilityTest, ChangeTypeFails) { - OpDef old_op_def; - TF_ASSERT_OK( - OpDefBuilder("ChangeType").Input("a: int32").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def) + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("ChangeType").Input("a: int32").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) .Input(FakeInput()) .Finalize(node_def())); - ExpectTypeMismatch(old_op_def, + ExpectTypeMismatch(old_op.op_def, "Input signature mismatch 'int32' vs. 'float'"); } @@ -736,17 +736,18 @@ TEST_F(OpCompatibilityTest, ChangeTypeFails) { REGISTER_OP("ChangeOrder").Input("a: float").Input("b: int32"); TEST_F(OpCompatibilityTest, ChangeOrderFails) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("ChangeOrder") .Input("b: int32") .Input("a: float") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) .Input(FakeInput()) .Input(FakeInput()) .Finalize(node_def())); ExpectTypeMismatch( - old_op_def, "Input signature mismatch 'int32, float' vs. 'float, int32'"); + old_op.op_def, + "Input signature mismatch 'int32, float' vs. 'float, int32'"); } // Can't remove inputs/outputs. @@ -754,13 +755,12 @@ TEST_F(OpCompatibilityTest, ChangeOrderFails) { REGISTER_OP("RemoveInput"); TEST_F(OpCompatibilityTest, RemoveInputFails) { - OpDef old_op_def; - TF_ASSERT_OK( - OpDefBuilder("RemoveInput").Input("a: float").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def) + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RemoveInput").Input("a: float").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) .Input(FakeInput()) .Finalize(node_def())); - ExpectInvalid(old_op_def, + ExpectInvalid(old_op.op_def, "expected inputs '' do not match 1 inputs specified", "Input signature mismatch 'float' vs. ''"); } @@ -770,13 +770,13 @@ TEST_F(OpCompatibilityTest, RemoveInputFails) { REGISTER_OP("ChangeAttrType").Attr("a: int"); TEST_F(OpCompatibilityTest, ChangeAttrTypeFails) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK( - OpDefBuilder("ChangeAttrType").Attr("a: bool").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def) + OpDefBuilder("ChangeAttrType").Attr("a: bool").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) .Attr("a", true) .Finalize(node_def())); - ExpectInvalid(old_op_def, "value with type 'bool' when 'int' expected", + ExpectInvalid(old_op.op_def, "value with type 'bool' when 'int' expected", "Attr 'a' changed type 'bool' -> 'int'"); } @@ -785,12 +785,14 @@ TEST_F(OpCompatibilityTest, ChangeAttrTypeFails) { REGISTER_OP("AttrFromList").Attr("a: int"); TEST_F(OpCompatibilityTest, AttrFromListFails) { - OpDef old_op_def; - TF_ASSERT_OK( - OpDefBuilder("AttrFromList").Attr("a: list(int)").Finalize(&old_op_def)); + OpRegistrationData old_op; TF_ASSERT_OK( - NodeDefBuilder("fails", &old_op_def).Attr("a", {5}).Finalize(node_def())); - ExpectInvalid(old_op_def, "value with type 'list(int)' when 'int' expected", + OpDefBuilder("AttrFromList").Attr("a: list(int)").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Attr("a", {5}) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "value with type 'list(int)' when 'int' expected", "Attr 'a' changed type 'list(int)' -> 'int'"); } @@ -799,11 +801,13 @@ TEST_F(OpCompatibilityTest, AttrFromListFails) { REGISTER_OP("AttrToList").Attr("a: list(int)"); TEST_F(OpCompatibilityTest, AttrToListFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("AttrToList").Attr("a: int").Finalize(&old_op_def)); - TF_ASSERT_OK( - NodeDefBuilder("fails", &old_op_def).Attr("a", 5).Finalize(node_def())); - ExpectInvalid(old_op_def, "value with type 'int' when 'list(int)' expected", + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrToList").Attr("a: int").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Attr("a", 5) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "value with type 'int' when 'list(int)' expected", "Attr 'a' changed type 'int' -> 'list(int)'"); } @@ -812,15 +816,16 @@ TEST_F(OpCompatibilityTest, AttrToListFails) { REGISTER_OP("PolymorphicToAnyList").Input("a: T").Attr("T: list(type)"); TEST_F(OpCompatibilityTest, PolymorphicToAnyListFails) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("PolymorphicToAnyList") .Input("a: T") .Attr("T: type") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) .Input(FakeInput(DT_INT32)) .Finalize(node_def())); - ExpectInvalid(old_op_def, "value with type 'type' when 'list(type)' expected", + ExpectInvalid(old_op.op_def, + "value with type 'type' when 'list(type)' expected", "Attr 'T' changed type 'type' -> 'list(type)'"); } @@ -832,16 +837,17 @@ REGISTER_OP("SameToAnyList") .Attr("N: int = 1"); TEST_F(OpCompatibilityTest, SameToAnyListFails) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("SameToAnyList") .Input("a: N * T") .Attr("T: type") .Attr("N: int") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) .Input(FakeInput(1, DT_INT32)) .Finalize(node_def())); - ExpectInvalid(old_op_def, "value with type 'type' when 'list(type)' expected", + ExpectInvalid(old_op.op_def, + "value with type 'type' when 'list(type)' expected", "Attr 'T' changed type 'type' -> 'list(type)'"); } @@ -850,13 +856,13 @@ TEST_F(OpCompatibilityTest, SameToAnyListFails) { REGISTER_OP("AttrAddRestriction").Attr("t: {int32, int64}"); TEST_F(OpCompatibilityTest, AttrAddRestrictionFails) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK( - OpDefBuilder("AttrAddRestriction").Attr("t: type").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("add_restriction", &old_op_def) + OpDefBuilder("AttrAddRestriction").Attr("t: type").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_restriction", &old_op.op_def) .Attr("t", DT_BOOL) .Finalize(node_def())); - ExpectInvalid(old_op_def, + ExpectInvalid(old_op.op_def, "Value for attr 't' of bool is not in the list of allowed " "values: int32, int64", "Attr 't' has a stricter set of allowed values; from " @@ -868,14 +874,14 @@ TEST_F(OpCompatibilityTest, AttrAddRestrictionFails) { REGISTER_OP("AttrMoreRestrictive").Attr("t: {int32, int64}"); TEST_F(OpCompatibilityTest, AttrMoreRestrictiveFails) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK(OpDefBuilder("AttrMoreRestrictive") .Attr("t: {int32, int64, bool}") - .Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("more_restrictive", &old_op_def) + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("more_restrictive", &old_op.op_def) .Attr("t", DT_BOOL) .Finalize(node_def())); - ExpectInvalid(old_op_def, + ExpectInvalid(old_op.op_def, "Value for attr 't' of bool is not in the list of allowed " "values: int32, int64", "Attr 't' has a stricter set of allowed values; from " @@ -887,11 +893,12 @@ TEST_F(OpCompatibilityTest, AttrMoreRestrictiveFails) { REGISTER_OP("AttrAddMin").Attr("n: int >= 3"); TEST_F(OpCompatibilityTest, AttrAddMinFails) { - OpDef old_op_def; - TF_ASSERT_OK(OpDefBuilder("AttrAddMin").Attr("n: int").Finalize(&old_op_def)); - TF_ASSERT_OK( - NodeDefBuilder("add_min", &old_op_def).Attr("n", 2).Finalize(node_def())); - ExpectInvalid(old_op_def, + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrAddMin").Attr("n: int").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_min", &old_op.op_def) + .Attr("n", 2) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, "Value for attr 'n' of 2 must be at least minimum 3", "Attr 'n' has a higher minimum; from no minimum to 3"); } @@ -901,13 +908,13 @@ TEST_F(OpCompatibilityTest, AttrAddMinFails) { REGISTER_OP("AttrRaiseMin").Attr("n: int >= 3"); TEST_F(OpCompatibilityTest, AttrRaiseMinFails) { - OpDef old_op_def; + OpRegistrationData old_op; TF_ASSERT_OK( - OpDefBuilder("AttrRaiseMin").Attr("n: int >= 1").Finalize(&old_op_def)); - TF_ASSERT_OK(NodeDefBuilder("raise_min", &old_op_def) + OpDefBuilder("AttrRaiseMin").Attr("n: int >= 1").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("raise_min", &old_op.op_def) .Attr("n", 2) .Finalize(node_def())); - ExpectInvalid(old_op_def, + ExpectInvalid(old_op.op_def, "Value for attr 'n' of 2 must be at least minimum 3", "Attr 'n' has a higher minimum; from 1 to 3"); } diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 2ec8b22458..7e021f7b64 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -491,7 +491,7 @@ void FinalizeDoc(const string& text, OpDef* op_def, } // namespace OpDefBuilder::OpDefBuilder(StringPiece op_name) { - op_def_.set_name(op_name.ToString()); // NOLINT + op_def()->set_name(op_name.ToString()); // NOLINT } OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) { @@ -513,7 +513,7 @@ OpDefBuilder& OpDefBuilder::Output(StringPiece spec) { OpDefBuilder& OpDefBuilder::Doc(StringPiece text) { if (!doc_.empty()) { errors_.push_back( - strings::StrCat("Extra call to Doc() for Op ", op_def_.name())); + strings::StrCat("Extra call to Doc() for Op ", op_def()->name())); } else { doc_.assign(text.data(), text.size()); } @@ -522,41 +522,52 @@ OpDefBuilder& OpDefBuilder::Doc(StringPiece text) { #endif OpDefBuilder& OpDefBuilder::SetIsCommutative() { - op_def_.set_is_commutative(true); + op_def()->set_is_commutative(true); return *this; } OpDefBuilder& OpDefBuilder::SetIsAggregate() { - op_def_.set_is_aggregate(true); + op_def()->set_is_aggregate(true); return *this; } OpDefBuilder& OpDefBuilder::SetIsStateful() { - op_def_.set_is_stateful(true); + op_def()->set_is_stateful(true); return *this; } OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() { - op_def_.set_allows_uninitialized_input(true); + op_def()->set_allows_uninitialized_input(true); return *this; } OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) { - if (op_def_.has_deprecation()) { + if (op_def()->has_deprecation()) { errors_.push_back( - strings::StrCat("Deprecated called twice for Op ", op_def_.name())); + strings::StrCat("Deprecated called twice for Op ", op_def()->name())); } else { - OpDeprecation* deprecation = op_def_.mutable_deprecation(); + OpDeprecation* deprecation = op_def()->mutable_deprecation(); deprecation->set_version(version); deprecation->set_explanation(explanation.ToString()); } return *this; } -Status OpDefBuilder::Finalize(OpDef* op_def) const { +OpDefBuilder& OpDefBuilder::SetShapeFn(const OpShapeInferenceFn& fn) { + if (op_reg_data_.shape_inference_fn != nullptr) { + errors_.push_back( + strings::StrCat("SetShapeFn called twice for Op ", op_def()->name())); + } else { + op_reg_data_.shape_inference_fn = fn; + } + return *this; +} + +Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { std::vector<string> errors = errors_; - *op_def = op_def_; + *op_reg_data = op_reg_data_; + OpDef* op_def = &op_reg_data->op_def; for (StringPiece attr : attrs_) { FinalizeAttr(attr, op_def, &errors); } diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h index 4ca5927d8a..14950eff28 100644 --- a/tensorflow/core/framework/op_def_builder.h +++ b/tensorflow/core/framework/op_def_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Class and associated machinery for specifying an Op's OpDef for Op -// registration. +// Class and associated machinery for specifying an Op's OpDef and shape +// inference function for Op registration. #ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ #define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ @@ -24,9 +24,25 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { +namespace shape_inference { +class InferenceContext; +} +typedef std::function<Status(shape_inference::InferenceContext* c)> + OpShapeInferenceFn; + +struct OpRegistrationData { + public: + OpRegistrationData() {} + OpRegistrationData(const OpDef& def) : op_def(def) {} + + OpDef op_def; + OpShapeInferenceFn shape_inference_fn; +}; + // Builder class passed to the REGISTER_OP() macro. class OpDefBuilder { public: @@ -111,14 +127,21 @@ class OpDefBuilder { OpDefBuilder& Doc(StringPiece text) { return *this; } #endif - // Sets *op_def to the requested OpDef, or returns an error. + OpDefBuilder& SetShapeFn(const OpShapeInferenceFn& fn); + + // Sets op_reg_data->op_def to the requested OpDef and + // op_reg_data->shape_inference_fn to the requested shape inference function, + // or returns an error. // Must be called after all of the above methods. + // // Note that OpDefBuilder only reports parsing errors. You should also // call ValidateOpDef() to detect other problems. - Status Finalize(OpDef* op_def) const; + Status Finalize(OpRegistrationData* op_reg_data) const; private: - OpDef op_def_; + OpDef* op_def() { return &op_reg_data_.op_def; } + + OpRegistrationData op_reg_data_; std::vector<string> attrs_; std::vector<string> inputs_; std::vector<string> outputs_; diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index ff986746ec..15f9c9671f 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -38,10 +39,12 @@ class OpDefBuilderTest : public ::testing::Test { protected: OpDefBuilder b() { return OpDefBuilder("Test"); } - void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto) { - OpDef op_def; - Status status = builder.Finalize(&op_def); + void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto, + OpShapeInferenceFn* shape_fn_out = nullptr) { + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); TF_EXPECT_OK(status); + OpDef& op_def = op_reg_data.op_def; if (status.ok()) { OpDef expected; protobuf::TextFormat::ParseFromString( @@ -50,13 +53,18 @@ class OpDefBuilderTest : public ::testing::Test { CanonicalizeAttrTypeListOrder(&op_def); CanonicalizeAttrTypeListOrder(&expected); EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString()); + + if (shape_fn_out) { + *shape_fn_out = op_reg_data.shape_inference_fn; + } } } void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) { - OpDef op_def; - Status status = builder.Finalize(&op_def); + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); TF_EXPECT_OK(status); + OpDef& op_def = op_reg_data.op_def; if (status.ok()) { OpDef expected; protobuf::TextFormat::ParseFromString( @@ -66,8 +74,8 @@ class OpDefBuilderTest : public ::testing::Test { } void ExpectFailure(const OpDefBuilder& builder, string error) { - OpDef op_def; - Status status = builder.Finalize(&op_def); + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); EXPECT_FALSE(status.ok()); if (!status.ok()) { EXPECT_EQ(status.error_message(), error); @@ -573,5 +581,26 @@ attr { )proto"); } +TEST_F(OpDefBuilderTest, SetShapeFn) { + auto fn = OpShapeInferenceFn([](shape_inference::InferenceContext* c) { + return errors::Unknown("ShapeFn was called"); + }); + OpShapeInferenceFn fn_out; + ExpectSuccess( + b().SetShapeFn(fn).Attr("dtype: type"), + "attr { name: \"dtype\" type: \"type\" allowed_values { list { } } }", + &fn_out); + ASSERT_TRUE(fn_out != nullptr); + EXPECT_EQ("ShapeFn was called", fn_out(nullptr).error_message()); +} + +TEST_F(OpDefBuilderTest, SetShapeFnCalledTwiceFailure) { + auto fn = OpShapeInferenceFn([](shape_inference::InferenceContext* c) { + return errors::Unknown("ShapeFn was called"); + }); + ExpectFailure(b().SetShapeFn(fn).SetShapeFn(fn), + "SetShapeFn called twice for Op Test"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc index 71924b4c98..e24b645683 100644 --- a/tensorflow/core/framework/op_def_util_test.cc +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -37,13 +37,13 @@ class ValidateOpDefTest : public ::testing::Test { Status TestProto(const string& text) { return ValidateOpDef(FromText(text)); } Status TestBuilder(const OpDefBuilder& builder) { - OpDef op_def; - Status status = builder.Finalize(&op_def); + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); TF_EXPECT_OK(status); if (!status.ok()) { return status; } else { - return ValidateOpDef(op_def); + return ValidateOpDef(op_reg_data.op_def); } } diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 30de279e7f..4df0ffeda7 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -757,9 +757,9 @@ Status SupportedDeviceTypesForNode( // DynamicPlacer) to consider the possibility that 'def' is call to // a user-defined function and only calls this // SupportedDeviceTypesForNode for primitive ops. - Status s; - const OpDef* op_def = OpRegistry::Global()->LookUp(def.op(), &s); - if (op_def) { + const OpRegistrationData* op_reg_data; + const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data); + if (s.ok()) { for (const DeviceType& device_type : prioritized_types) { const KernelRegistration* reg = nullptr; TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, def, ®)); @@ -790,9 +790,9 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); // Look up the Op registered for this op name. - Status s; - const OpDef* op_def = OpRegistry::Global()->LookUp(node_def.op(), &s); - if (op_def == nullptr) return s; + const OpDef* op_def = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def); + if (!s.ok()) return s; // Validate node_def against OpDef. s = ValidateNodeDef(node_def, *op_def); @@ -858,22 +858,23 @@ bool FindArgInOp(StringPiece arg_name, } // namespace Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { - Status unused_status; for (const auto& key_registration : *GlobalKernelRegistryTyped()) { const KernelDef& kernel_def(key_registration.second.def); - const OpDef* op_def = op_registry.LookUp(kernel_def.op(), &unused_status); - if (op_def == nullptr) { + const OpRegistrationData* op_reg_data; + const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data); + if (!status.ok()) { // TODO(josh11b): Make this a hard error. LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "') for unknown op: " << kernel_def.op(); continue; } + const OpDef& op_def = op_reg_data->op_def; for (const auto& host_memory_arg : kernel_def.host_memory_arg()) { - if (!FindArgInOp(host_memory_arg, op_def->input_arg()) && - !FindArgInOp(host_memory_arg, op_def->output_arg())) { + if (!FindArgInOp(host_memory_arg, op_def.input_arg()) && + !FindArgInOp(host_memory_arg, op_def.output_arg())) { return errors::InvalidArgument("HostMemory arg '", host_memory_arg, "' not found in OpDef: ", - SummarizeOpDef(*op_def)); + SummarizeOpDef(op_def)); } } } diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 17b9d33386..5269314583 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -125,8 +125,8 @@ Status InferenceContext::WithValue(const Dimension* dim, int64 value, return Status::OK(); } *out = nullptr; - return errors::InvalidArgument("Dimension must be size ", value, - " but is size ", existing); + return errors::InvalidArgument("Dimension must be ", value, " but is ", + existing); } Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1, @@ -142,7 +142,7 @@ Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1, return Status::OK(); } else { *out = nullptr; - return errors::InvalidArgument("Dimensions must be equal size, but are ", + return errors::InvalidArgument("Dimensions must be equal, but are ", Value(d0), " and ", Value(d1)); } } @@ -181,7 +181,8 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1, return_s1 = false; } else if (v0 != v1) { *out = nullptr; - return errors::InvalidArgument("Dimensions must be equal size, but are ", + return errors::InvalidArgument("Dimension ", i, + " in both shapes must be equal, but are ", Value(d0), " and ", Value(d1)); } } diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index c0681a91fc..55e8b90326 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -91,8 +91,9 @@ class InferenceContext { // not available at the time of shape inference. const Tensor* input_tensor(int idx) const { return input_tensors_[idx]; } - void set_output(int idx, const Shape* shape); + void set_output(int idx, const Shape* shape) { outputs_[idx] = shape; } int num_outputs() const { return outputs_.size(); } + const Shape* output(int idx) { return outputs_[idx]; } // idx can be negative for an offset from end of dimensions. const Dimension* Dim(const Shape* s, int32 idx) { return s->dims_[idx]; } diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index b8d21e5e38..27788ac03d 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -113,11 +113,11 @@ TEST(ShapeInferenceTest, WithValue) { // WithRank on dimension with known size. out1 = d0; - EXPECT_EQ("Invalid argument: Dimension must be size 0 but is size 1", + EXPECT_EQ("Invalid argument: Dimension must be 0 but is 1", c.WithValue(d0, 0, &out1).ToString()); EXPECT_TRUE(out1 == nullptr); out1 = d0; - EXPECT_EQ("Invalid argument: Dimension must be size 2 but is size 1", + EXPECT_EQ("Invalid argument: Dimension must be 2 but is 1", c.WithValue(d0, 2, &out1).ToString()); EXPECT_TRUE(out1 == nullptr); EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok()); @@ -159,10 +159,10 @@ TEST(ShapeInferenceTest, MergeDim) { EXPECT_TRUE(d2_b == out); // Merging inequal values is an error. - EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 2 and 1", + EXPECT_EQ("Invalid argument: Dimensions must be equal, but are 2 and 1", c.Merge(d2, d1, &out).ToString()); EXPECT_TRUE(out == nullptr); - EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 1 and 2", + EXPECT_EQ("Invalid argument: Dimensions must be equal, but are 1 and 2", c.Merge(d1, d2, &out).ToString()); EXPECT_TRUE(out == nullptr); } @@ -210,11 +210,13 @@ TEST(ShapeInferenceTest, MergeShape) { // Incompatible merges give errors and set out to nullptr. out = s_unknown; - EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 2 and 3", + EXPECT_EQ(("Invalid argument: Dimension 1 in both shapes must be equal, but " + "are 2 and 3"), c.Merge(s_u_2, s_1_3, &out).ToString()); EXPECT_TRUE(out == nullptr); out = s_unknown; - EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 3 and 2", + EXPECT_EQ(("Invalid argument: Dimension 1 in both shapes must be equal, but " + "are 3 and 2"), c.Merge(s_1_3, s_u_2, &out).ToString()); EXPECT_TRUE(out == nullptr); out = s_unknown; diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc new file mode 100644 index 0000000000..9b56014edb --- /dev/null +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -0,0 +1,153 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/shape_inference_testutil.h" + +#include <unordered_map> + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +using shape_inference::Dimension; +using shape_inference::Shape; +using errors::Unknown; + +Status InferShapes(const string& op_name, const string& ins, + const string& expected_outs) { + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data)); + const int num_outputs = op_reg_data->op_def.output_arg_size(); + + std::vector<string> ins_v = str_util::Split(ins, ';'); + shape_inference::InferenceContext c(ins_v, num_outputs); + TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c)); + + std::unordered_map<const Dimension*, std::pair<int, int>> + dim_to_input_and_dim_idx; + std::unordered_map<const Shape*, int> shape_to_input_idx; + for (int i = 0; i < c.num_inputs(); ++i) { + auto in = c.input(i); + shape_to_input_idx[in] = i; + for (int j = 0; j < c.Rank(in); ++j) { + dim_to_input_and_dim_idx[c.Dim(in, j)] = std::make_pair(i, j); + } + } + if (expected_outs == "e") { + return Unknown("Shape inference should have returned error"); + } + + // Verify the output shape. + std::vector<string> expected_outs_v = str_util::Split(expected_outs, ';'); + if (num_outputs != expected_outs_v.size()) { + return Unknown("Wrong number of expected outputs (", expected_outs_v.size(), + " vs ", num_outputs, ")"); + } + for (int i = 0; i < num_outputs; ++i) { + string err_prefix = strings::StrCat("Output ", i); + StringPiece expected(expected_outs_v[i]); + const shape_inference::Shape* out = c.output(i); + const int in_index = gtl::FindWithDefault(shape_to_input_idx, out, -1); + if (expected.starts_with("in")) { + if (in_index == -1) { + return Unknown(err_prefix, " did not match any input shape"); + } + auto v = str_util::Split(expected, '|'); + if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) == + v.end()) { + return Unknown(err_prefix, " matched input ", in_index, + " and should have matched one of (", expected, ")"); + } + continue; + } + if (in_index != -1) { + return Unknown(err_prefix, " matched input ", in_index, + " and should have not matched an input shape"); + } + if (expected == "?") { + if (c.RankKnown(out)) { + return Unknown(err_prefix, " expected to be unknown but was ", + c.DebugString(out)); + } + continue; + } + + // Verify the dimensions. + CHECK(expected.starts_with("[") && expected.ends_with("]")); + expected.remove_prefix(1); + expected.remove_suffix(1); + + // Split expected as a dimension. + auto expected_dims = str_util::Split(expected, ','); + if (!c.RankKnown(out)) { + return Unknown(err_prefix, " expected rank ", expected_dims.size(), + " but was ?"); + } + if (c.Rank(out) != expected_dims.size()) { + return Unknown(err_prefix, " expected rank ", expected_dims.size(), + " but was ", c.Rank(out)); + } + for (int j = 0; j < expected_dims.size(); ++j) { + err_prefix = strings::StrCat("Output dim ", i, ",", j); + StringPiece expected_dim(expected_dims[j]); + const Dimension* out_dim = c.Dim(out, j); + std::pair<int, int> in_dim_idx = gtl::FindWithDefault( + dim_to_input_and_dim_idx, out_dim, std::make_pair(-1, -1)); + if (expected_dim == "?") { + if (in_dim_idx.first != -1) { + return Unknown(err_prefix, + " expected to be unknown but matched input d", + in_dim_idx.first, "_", in_dim_idx.second); + } else if (c.ValueKnown(out_dim)) { + return Unknown(err_prefix, " expected to be unknown but was ", + c.Value(out_dim)); + } + } else if (expected_dim.starts_with("d")) { + // Compare the dimension values. + auto v = str_util::Split(expected_dim, '|'); + if (in_dim_idx.first == -1) { + return Unknown(err_prefix, " did not match any input dim"); + } + if (std::find(v.begin(), v.end(), + strings::StrCat("d", in_dim_idx.first, "_", + in_dim_idx.second)) == v.end()) { + return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_", + in_dim_idx.second, " and should have matched one of ", + expected_dim); + } + } else { + // Parse it as a value. + int64 value = -1; + if (!strings::safe_strto64(expected_dim, &value)) { + return Unknown(err_prefix, " expected dim failed to parse as int64"); + } + if (in_dim_idx.first != -1) { + return Unknown(err_prefix, " expected to be ", value, + " but matched input d", in_dim_idx.first, "_", + in_dim_idx.second); + } else if (value != c.Value(out_dim)) { + return Unknown(err_prefix, " expected to be ", value, " but was ", + c.DebugString(out_dim)); + } + } + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h new file mode 100644 index 0000000000..f2581247d9 --- /dev/null +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -0,0 +1,56 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ + +#include <vector> +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +// Contains utilities for writing tests for shape inference functions. + +namespace tensorflow { + +// Run shape inference for <op_name>, given inputs specified by <ins> +// and returns an error if the inferred shape does not match expected_outs. +// +// <ins> is a semicolon separated list of shapes. Each shape is formatted +// according to the formatting per +// shape_inference::InferenceContext::InferenceContext. +// +// <expected_outs> is a semicolon separated list of shapes. Each shape is +// formatted as one of: +// * ? - an unknown shape, but not matching an input shape +// * in0|in2|... - output shape must be the same as one of these input shapes. +// * [1,?,d0_0|d0_1] - output shape is of known rank, with comma-separated +// dimension values. +// Each dimension value is one of: +// * a constant, which means that constant not equal to a specific input +// * ?, which means an unknown dim size not equal to a specific input +// * d0_0|d1_2, indicating that the dim size must be equal to one of +// the given input dimensions; the first number is the input # and +// the second is which dimension in that input it corresponds to. +// <expected_outs> can be "e"; this is used to indicate that shape inference +// should have failed. +Status InferShapes(const string& op_name, const string& ins, + const string& expected_outs); + +#define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message()) +#define INFER_ERROR(s, op, i) \ + EXPECT_EQ(s, InferShapes(op, i, "x").error_message()) + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ diff --git a/tensorflow/core/framework/shape_inference_testutil_test.cc b/tensorflow/core/framework/shape_inference_testutil_test.cc new file mode 100644 index 0000000000..d16ec04264 --- /dev/null +++ b/tensorflow/core/framework/shape_inference_testutil_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/shape_inference_testutil.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; + +namespace { + +static OpShapeInferenceFn* global_fn_ptr = nullptr; +REGISTER_OP("OpOneOut") + .Input("inputs: N * T") + .Output("o1: T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + return (*global_fn_ptr)(c); + })); +REGISTER_OP("OpTwoOut") + .Input("inputs: N * T") + .Output("o1: T") + .Output("o2: T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + return (*global_fn_ptr)(c); + })); + +string RunInferShapes(const string& op_name, const string& ins, + const string& expected_outs, OpShapeInferenceFn fn) { + global_fn_ptr = &fn; + return InferShapes(op_name, ins, expected_outs).error_message(); +} + +} // namespace + +TEST(ShapeInferenceTestutilTest, Failures) { + auto fn_copy_input_0 = [](InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }; + auto fn_copy_input_2 = [](InferenceContext* c) { + c->set_output(0, c->input(2)); + return Status::OK(); + }; + auto fn_output_unknown_shapes = [](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->CreateUnknownShape()); + } + return Status::OK(); + }; + auto fn_output_1_2 = [](InferenceContext* c) { + c->set_output(0, c->CreateShape({c->CreateDim(1), c->CreateDim(2)})); + return Status::OK(); + }; + auto fn_output_u_2 = [](InferenceContext* c) { + c->set_output(0, c->CreateShape({c->CreateUnknownDim(), c->CreateDim(2)})); + return Status::OK(); + }; + const string& op = "OpOneOut"; + + EXPECT_EQ("Shape inference should have returned error", + RunInferShapes(op, "[1];[2];[1]", "e", fn_copy_input_0)); + EXPECT_EQ("Wrong number of expected outputs (2 vs 1)", + RunInferShapes(op, "[1];[2];[1]", "[1];[2]", fn_copy_input_0)); + EXPECT_EQ("Op type not registered 'NoSuchOp'", + RunInferShapes("NoSuchOp", "", "", fn_copy_input_0)); + + // Wrong shape error messages. + EXPECT_EQ( + "Output 0 matched input 0 and should have not matched an input shape", + RunInferShapes(op, "[1];[2];[1]", "?", fn_copy_input_0)); + EXPECT_EQ("Output 0 matched input 0 and should have matched one of (in2)", + RunInferShapes(op, "[1];[2];[1]", "in2", fn_copy_input_0)); + EXPECT_EQ("Output 0 matched input 0 and should have matched one of (in1|in2)", + RunInferShapes(op, "[1];[2];[1]", "in1|in2", fn_copy_input_0)); + EXPECT_EQ( + "Output 0 matched input 2 and should have not matched an input shape", + RunInferShapes(op, "[1];[2];[1]", "[1]", fn_copy_input_2)); + EXPECT_EQ("Output 0 did not match any input shape", + RunInferShapes(op, "[1];[2];[1]", "in0|in1", fn_output_1_2)); + EXPECT_EQ("Output 0 expected to be unknown but was [1,2]", + RunInferShapes(op, "[1];[2];[1]", "?", fn_output_1_2)); + EXPECT_EQ("Output 0 expected rank 3 but was 2", + RunInferShapes(op, "[1];[2];[1]", "[1,2,3]", fn_output_1_2)); + EXPECT_EQ( + "Output 0 expected rank 2 but was ?", + RunInferShapes(op, "[1];[2];[1]", "[1,2]", fn_output_unknown_shapes)); + + // Wrong shape error messages on the second output. + EXPECT_EQ("Output 1 expected rank 3 but was ?", + RunInferShapes("OpTwoOut", "[1];[2];[1]", "?;[1,2,3]", + fn_output_unknown_shapes)); + + // Wrong dimension error messages. + EXPECT_EQ("Output dim 0,1 expected to be 3 but was 2", + RunInferShapes(op, "[1];[2];[1]", "[1,3]", fn_output_1_2)); + EXPECT_EQ("Output dim 0,0 expected to be 2 but was 1", + RunInferShapes(op, "[1];[2];[1]", "[2,2]", fn_output_1_2)); + EXPECT_EQ("Output dim 0,0 expected to be unknown but was 1", + RunInferShapes(op, "[1];[2];[1]", "[?,2]", fn_output_1_2)); + EXPECT_EQ("Output dim 0,1 expected to be 1 but was 2", + RunInferShapes(op, "[1];[2];[1]", "[?,1]", fn_output_u_2)); + EXPECT_EQ("Output dim 0,0 expected to be 1 but was ?", + RunInferShapes(op, "[0,1,?];[2];[1]", "[1,2]", fn_output_u_2)); + auto fn = [](InferenceContext* c) { + c->set_output( + 0, c->CreateShape({c->Dim(c->input(0), 1), c->CreateDim(2), + c->CreateUnknownDim(), c->Dim(c->input(2), 0)})); + return Status::OK(); + }; + const string ins = "[0,1,?];[2];[1]"; + EXPECT_EQ("Output dim 0,0 expected to be unknown but matched input d0_1", + RunInferShapes(op, ins, "[?,2,?,d2_0]", fn)); + EXPECT_EQ("Output dim 0,0 expected to be 0 but matched input d0_1", + RunInferShapes(op, ins, "[0,2,?,d2_0]", fn)); + EXPECT_EQ( + "Output dim 0,0 matched input d0_1 and should have matched one of d0_0", + RunInferShapes(op, ins, "[d0_0,2,?,d2_0]", fn)); + EXPECT_EQ("Output dim 0,0 expected dim failed to parse as int64", + RunInferShapes(op, ins, "[x,2,?,d2_0]", fn)); + EXPECT_EQ( + "Output dim 0,0 matched input d0_1 and should have matched one of " + "d0_0|d0_2", + RunInferShapes(op, ins, "[d0_0|d0_2,2,?,d2_0]", fn)); + EXPECT_EQ("Output dim 0,1 expected to be unknown but was 2", + RunInferShapes(op, ins, "[d0_1,?,?,d0_0|d2_0]", fn)); + EXPECT_EQ("Output dim 0,2 expected to be 8 but was ?", + RunInferShapes(op, ins, "[d0_1,2,8,d0_0|d2_0]", fn)); + EXPECT_EQ("Output dim 0,2 did not match any input dim", + RunInferShapes(op, ins, "[d0_1,2,d0_1|d2_0,d0_0|d2_0]", fn)); + EXPECT_EQ("", // OK, no error. + RunInferShapes(op, ins, "[d0_1,2,?,d0_0|d2_0]", fn)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 363fe5623c..3607db7e81 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -205,8 +205,9 @@ Graph::~Graph() { } Node* Graph::AddNode(const NodeDef& node_def, Status* status) { - const OpDef* op_def = ops_->LookUp(node_def.op(), status); - if (op_def == nullptr) return nullptr; + const OpDef* op_def; + status->Update(ops_->LookUpOpDef(node_def.op(), &op_def)); + if (!status->ok()) return nullptr; DataTypeVector inputs; DataTypeVector outputs; diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc index 21d9e25284..b5fbb5fa27 100644 --- a/tensorflow/core/graph/validate.cc +++ b/tensorflow/core/graph/validate.cc @@ -30,8 +30,8 @@ Status ValidateGraphDef(const GraphDef& graph_def, const int version = graph_def.versions().producer(); for (const NodeDef& node_def : graph_def.node()) { // Look up the OpDef for the node_def's op name. - const OpDef* op_def = op_registry.LookUp(node_def.op(), &s); - TF_RETURN_IF_ERROR(s); + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def.op(), &op_def)); TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def)); TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, version)); } diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc index a7357df15d..cb6d107cad 100644 --- a/tensorflow/core/graph/validate_test.cc +++ b/tensorflow/core/graph/validate_test.cc @@ -95,8 +95,10 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) { } TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) { + OpRegistrationData op_reg_data; + TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(&op_reg_data)); OpList op_list; - TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(op_list.add_op())); + *op_list.add_op() = op_reg_data.op_def; const string graph_def_str = "node { name: 'A' op: 'UniqueSnowflake' }"; GraphDef graph_def; auto parser = protobuf::TextFormat::Parser(); @@ -105,8 +107,10 @@ TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) { } TEST(ValidateGraphDefAgainstOpListTest, GraphWithGlobalOpNotInOpList) { + OpRegistrationData op_reg_data; + TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(&op_reg_data)); OpList op_list; - TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(op_list.add_op())); + *op_list.add_op() = op_reg_data.op_def; const string graph_def_str = "node { name: 'A' op: 'FloatInput' }"; GraphDef graph_def; auto parser = protobuf::TextFormat::Parser(); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 46398ca5a9..7743173996 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -15,9 +15,14 @@ limitations under the License. #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +typedef shape_inference::Dimension Dimension; +typedef shape_inference::InferenceContext InferenceContext; +typedef shape_inference::Shape Shape; + REGISTER_OP("AddN") .Input("inputs: N * T") .Output("sum: T") @@ -25,6 +30,16 @@ REGISTER_OP("AddN") .Attr("T: numbertype") .SetIsCommutative() .SetIsAggregate() + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + const Shape* cur = c->input(c->num_inputs() - 1); + for (int i = c->num_inputs() - 2; i >= 0; --i) { + TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), + "From merging shape ", i, + " with other shapes."); + } + c->set_output(0, cur); + return Status::OK(); + })) .Doc(R"doc( Add all input tensors element wise. diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc new file mode 100644 index 0000000000..89d6ed14d1 --- /dev/null +++ b/tensorflow/core/ops/math_ops_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +yo + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(MathOpsTest, AddN_ShapeFn) { + INFER_OK("AddN", "?;?", "in0|in1"); + INFER_OK("AddN", "[1,?]", "in0"); + INFER_OK("AddN", "[1,2];[?,2]", "in0"); + INFER_OK("AddN", "[1,2];[1,2]", "in0|in1"); + INFER_OK("AddN", "[?,2];[1,2]", "in1"); + INFER_OK("AddN", "[1,?];[?,2];[1,2]", "in2"); + INFER_OK("AddN", "[1,2];[?,2];[1,?]", "in0"); + INFER_OK("AddN", "?;?;[1,2]", "in2"); + INFER_OK("AddN", "[1,?];[?,2]", "[d0_0,d1_1]"); + INFER_OK("AddN", "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]"); + INFER_OK("AddN", "[?,2];[1,?]", "[d1_0,d0_1]"); + + INFER_ERROR(("Dimension 1 in both shapes must be equal, but are 2 and " + "4\n\tFrom merging shape 0 with other shapes."), + "AddN", "[1,2];?;[1,4]"); + INFER_ERROR(("Shapes must be equal rank, but are 2 and 3\n\tFrom merging " + "shape 1 with other shapes."), + "AddN", "?;[1,2];?;[1,2,3]"); +} + +} // end namespace tensorflow |