diff options
author | 2017-07-06 13:59:28 -0700 | |
---|---|---|
committer | 2017-07-06 14:03:01 -0700 | |
commit | 7d5c74a9c89333f4f5b0524856907ac1b9f43a7b (patch) | |
tree | 92961ba96c8bfacf8548b3097f87e86ac355e931 /tensorflow/core/framework | |
parent | 2caec3af18a5c81f4c02f049cb5f39ca71700795 (diff) |
Move duplicate detection logic from Graph to FunctionLibraryDefinition
Turns out this is more useful, since there are many function libraries
that don't belong to a graph. This will be used in a future
change. Note that this maintains the current behavior of Graph.
In addition, updates FunctionDefsEqual() to handle unset attr entries
(I ran into this when using this in said future change).
PiperOrigin-RevId: 161126628
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/function.cc | 61 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 6 | ||||
-rw-r--r-- | tensorflow/core/framework/function_test.cc | 47 |
3 files changed, 85 insertions, 29 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 9d43aab5a5..c6a20bf3ce 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -724,6 +724,23 @@ string DebugStringWhole(const GraphDef& gdef) { return ret; } +namespace { + +// Returns the name -> attr mapping of fdef's attrs that have a value set. In +// Python, it's possible to access unset attrs, which returns a default value +// and adds an unset attr to the map. +std::map<StringPiece, AttrValue> GetSetAttrs(const FunctionDef& fdef) { + std::map<StringPiece, AttrValue> set_attrs; + for (auto iter : fdef.attr()) { + if (iter.second.value_case() != AttrValue::VALUE_NOT_SET) { + set_attrs[iter.first] = iter.second; + } + } + return set_attrs; +} + +} // end namespace + bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { // NOTE(skyewm): Using MessageDifferencer would be better here, but that is // currently not included in tensorflow/core/platform/default/protobuf.h, so @@ -736,10 +753,12 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { f2.signature().SerializeToString(&sig2); if (sig1 != sig2) return false; - if (f1.attr().size() != f2.attr().size()) return false; - for (auto iter1 : f1.attr()) { - auto iter2 = f2.attr().find(iter1.first); - if (iter2 == f2.attr().end()) return false; + std::map<StringPiece, AttrValue> f1_attrs = GetSetAttrs(f1); + std::map<StringPiece, AttrValue> f2_attrs = GetSetAttrs(f2); + if (f1_attrs.size() != f2_attrs.size()) return false; + for (auto iter1 : f1_attrs) { + auto iter2 = f2_attrs.find(iter1.first); + if (iter2 == f2_attrs.end()) return false; if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false; } @@ -883,11 +902,17 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { } Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { - auto& ptr = function_defs_[fdef.signature().name()]; - if (ptr != nullptr) { - return errors::InvalidArgument("Function with name: ", - fdef.signature().name(), - " already exists in function library."); + std::unique_ptr<FunctionDefAndOpRegistration>* entry = + &function_defs_[fdef.signature().name()]; + if (*entry != nullptr) { + if (!FunctionDefsEqual((*entry)->fdef, fdef)) { + return errors::InvalidArgument( + "Cannot add function '", fdef.signature().name(), + "' because a different function with the same name already " + "exists."); + } + // Ignore duplicate FunctionDefs + return Status::OK(); } const OpDef* op_def; if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) { @@ -895,19 +920,27 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { "Cannot add function '", fdef.signature().name(), "' because an op with the same name already exists."); } - ptr.reset(new FunctionDefAndOpRegistration(fdef)); + entry->reset(new FunctionDefAndOpRegistration(fdef)); return Status::OK(); } Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { - if (func_grad_.count(grad.function_name()) > 0) { - return errors::InvalidArgument("Gradient for function '", - grad.function_name(), "' already exists."); + string* entry = &func_grad_[grad.function_name()]; + if (!entry->empty()) { + if (*entry != grad.gradient_func()) { + return errors::InvalidArgument( + "Cannot assign gradient function '", grad.gradient_func(), "' to '", + grad.function_name(), "' because it already has gradient function ", + "'", *entry, "'"); + } + // Ignore duplicate GradientDefs + return Status::OK(); } - func_grad_[grad.function_name()] = grad.gradient_func(); + *entry = grad.gradient_func(); return Status::OK(); } +// TODO(skyewm): don't modify FunctionLibraryDefinition in case of error Status FunctionLibraryDefinition::AddLibrary( const FunctionLibraryDefinition& other) { for (auto iter : other.function_defs_) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index b8d5b8797a..2342e08b38 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -287,20 +287,24 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const FunctionDef* Find(const string& func) const; // Adds function definition 'fdef' to this function library. - // Returns status 'ok' on success, or error otherwise. + // Returns status 'ok' on success, or error otherwise. This is a no-op if + // 'fdef' already exists in this function library. // If 'fdef' is successfully added to the library, it will be accessible // from 'LookUp' and included in the proto returned by 'ToProto'. Status AddFunctionDef(const FunctionDef& fdef); // Adds gradient definition 'grad' to this function library. + // This is a no-op if 'grad' already exists in this function library. // If 'grad' is successfully added, it will be accessible via 'FindGradient' // and included in the proto returned by 'ToProto'. Status AddGradientDef(const GradientDef& grad); // Adds the functions and gradients in 'other' to this function library. + // Duplicate functions and gradients are ignored. Status AddLibrary(const FunctionLibraryDefinition& other); // Adds the functions and gradients in 'lib_def' to this function library. + // Duplicate functions and gradients are ignored. Status AddLibrary(const FunctionDefLibrary& lib_def); // If the gradient function for 'func' is specified explicitly in diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 2ecdc36c11..1173384a1e 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -971,6 +971,10 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { EXPECT_EQ(s.error_message(), "Cannot add function 'Add' because an op with the same name " "already exists."); + + // Already-added functions don't produce error + TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); + TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); } TEST(FunctionLibraryDefinitionTest, AddGradientDef) { @@ -984,12 +988,16 @@ TEST(FunctionLibraryDefinitionTest, AddGradientDef) { grad.set_gradient_func(test::function::XTimesFour().signature().name()); TF_EXPECT_OK(lib_def.AddGradientDef(grad)); + // Already-added gradients don't produce error + TF_EXPECT_OK(lib_def.AddGradientDef(grad)); + // Test that adding a duplicate gradient fails grad.set_gradient_func(test::function::XTimes16().signature().name()); Status s = lib_def.AddGradientDef(grad); EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); EXPECT_EQ(s.error_message(), - "Gradient for function 'XTimesTwo' already exists."); + "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " + "it already has gradient function 'XTimesFour'"); } TEST(FunctionLibraryDefinitionTest, AddLibrary) { @@ -998,35 +1006,46 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) { *proto.add_function() = test::function::XTimesTwo(); FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); - // Error if you try to add the same function twice - Status s = lib_def.AddLibrary(lib_def); - EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); - EXPECT_EQ(s.error_message(), - "Function with name: XTimesTwo already exists in function " - "library."); - // Add gradient GradientDef grad; grad.set_function_name(test::function::XTimesTwo().signature().name()); grad.set_gradient_func(test::function::XTimesFour().signature().name()); TF_EXPECT_OK(lib_def.AddGradientDef(grad)); - // Error if you try to add the same library function twice + // Error if you try to add conflicting function proto.Clear(); - *proto.add_gradient() = grad; + FunctionDef fdef = test::function::XTimesFour(); + fdef.mutable_signature()->set_name( + test::function::XTimesTwo().signature().name()); + *proto.add_function() = fdef; FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto); - s = lib_def.AddLibrary(lib_def2); + Status s = lib_def.AddLibrary(lib_def2); EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); EXPECT_EQ(s.error_message(), - "Gradient for function 'XTimesTwo' already exists."); + "Cannot add function 'XTimesTwo' because a different function with " + "the same name already exists."); + + // Error if you try to add conflicting gradient + proto.Clear(); + grad.set_gradient_func(test::function::XTimes16().signature().name()); + *proto.add_gradient() = grad; + FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); + s = lib_def.AddLibrary(lib_def3); + EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); + EXPECT_EQ(s.error_message(), + "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " + "it already has gradient function 'XTimesFour'"); // No conflicting functions or gradients OK proto.Clear(); *proto.add_function() = test::function::XTimesFour(); grad.set_function_name(test::function::XTimes16().signature().name()); *proto.add_gradient() = grad; - FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); - TF_EXPECT_OK(lib_def.AddLibrary(lib_def3)); + FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto); + TF_EXPECT_OK(lib_def.AddLibrary(lib_def4)); + + // OK to add the same functions and gradients twice + TF_EXPECT_OK(lib_def.AddLibrary(lib_def)); } TEST(FunctionLibraryDefinitionTest, ToProto) { |