aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-07-06 13:59:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-06 14:03:01 -0700
commit7d5c74a9c89333f4f5b0524856907ac1b9f43a7b (patch)
tree92961ba96c8bfacf8548b3097f87e86ac355e931 /tensorflow/core/framework
parent2caec3af18a5c81f4c02f049cb5f39ca71700795 (diff)
Move duplicate detection logic from Graph to FunctionLibraryDefinition
Turns out this is more useful, since there are many function libraries that don't belong to a graph. This will be used in a future change. Note that this maintains the current behavior of Graph. In addition, updates FunctionDefsEqual() to handle unset attr entries (I ran into this when using this in said future change). PiperOrigin-RevId: 161126628
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/function.cc61
-rw-r--r--tensorflow/core/framework/function.h6
-rw-r--r--tensorflow/core/framework/function_test.cc47
3 files changed, 85 insertions, 29 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 9d43aab5a5..c6a20bf3ce 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -724,6 +724,23 @@ string DebugStringWhole(const GraphDef& gdef) {
return ret;
}
+namespace {
+
+// Returns the name -> attr mapping of fdef's attrs that have a value set. In
+// Python, it's possible to access unset attrs, which returns a default value
+// and adds an unset attr to the map.
+std::map<StringPiece, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
+ std::map<StringPiece, AttrValue> set_attrs;
+ for (auto iter : fdef.attr()) {
+ if (iter.second.value_case() != AttrValue::VALUE_NOT_SET) {
+ set_attrs[iter.first] = iter.second;
+ }
+ }
+ return set_attrs;
+}
+
+} // end namespace
+
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
// NOTE(skyewm): Using MessageDifferencer would be better here, but that is
// currently not included in tensorflow/core/platform/default/protobuf.h, so
@@ -736,10 +753,12 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
f2.signature().SerializeToString(&sig2);
if (sig1 != sig2) return false;
- if (f1.attr().size() != f2.attr().size()) return false;
- for (auto iter1 : f1.attr()) {
- auto iter2 = f2.attr().find(iter1.first);
- if (iter2 == f2.attr().end()) return false;
+ std::map<StringPiece, AttrValue> f1_attrs = GetSetAttrs(f1);
+ std::map<StringPiece, AttrValue> f2_attrs = GetSetAttrs(f2);
+ if (f1_attrs.size() != f2_attrs.size()) return false;
+ for (auto iter1 : f1_attrs) {
+ auto iter2 = f2_attrs.find(iter1.first);
+ if (iter2 == f2_attrs.end()) return false;
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
}
@@ -883,11 +902,17 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
}
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
- auto& ptr = function_defs_[fdef.signature().name()];
- if (ptr != nullptr) {
- return errors::InvalidArgument("Function with name: ",
- fdef.signature().name(),
- " already exists in function library.");
+ std::unique_ptr<FunctionDefAndOpRegistration>* entry =
+ &function_defs_[fdef.signature().name()];
+ if (*entry != nullptr) {
+ if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
+ return errors::InvalidArgument(
+ "Cannot add function '", fdef.signature().name(),
+ "' because a different function with the same name already "
+ "exists.");
+ }
+ // Ignore duplicate FunctionDefs
+ return Status::OK();
}
const OpDef* op_def;
if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
@@ -895,19 +920,27 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
"Cannot add function '", fdef.signature().name(),
"' because an op with the same name already exists.");
}
- ptr.reset(new FunctionDefAndOpRegistration(fdef));
+ entry->reset(new FunctionDefAndOpRegistration(fdef));
return Status::OK();
}
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
- if (func_grad_.count(grad.function_name()) > 0) {
- return errors::InvalidArgument("Gradient for function '",
- grad.function_name(), "' already exists.");
+ string* entry = &func_grad_[grad.function_name()];
+ if (!entry->empty()) {
+ if (*entry != grad.gradient_func()) {
+ return errors::InvalidArgument(
+ "Cannot assign gradient function '", grad.gradient_func(), "' to '",
+ grad.function_name(), "' because it already has gradient function ",
+ "'", *entry, "'");
+ }
+ // Ignore duplicate GradientDefs
+ return Status::OK();
}
- func_grad_[grad.function_name()] = grad.gradient_func();
+ *entry = grad.gradient_func();
return Status::OK();
}
+// TODO(skyewm): don't modify FunctionLibraryDefinition in case of error
Status FunctionLibraryDefinition::AddLibrary(
const FunctionLibraryDefinition& other) {
for (auto iter : other.function_defs_) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index b8d5b8797a..2342e08b38 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -287,20 +287,24 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const FunctionDef* Find(const string& func) const;
// Adds function definition 'fdef' to this function library.
- // Returns status 'ok' on success, or error otherwise.
+ // Returns status 'ok' on success, or error otherwise. This is a no-op if
+ // 'fdef' already exists in this function library.
// If 'fdef' is successfully added to the library, it will be accessible
// from 'LookUp' and included in the proto returned by 'ToProto'.
Status AddFunctionDef(const FunctionDef& fdef);
// Adds gradient definition 'grad' to this function library.
+ // This is a no-op if 'grad' already exists in this function library.
// If 'grad' is successfully added, it will be accessible via 'FindGradient'
// and included in the proto returned by 'ToProto'.
Status AddGradientDef(const GradientDef& grad);
// Adds the functions and gradients in 'other' to this function library.
+ // Duplicate functions and gradients are ignored.
Status AddLibrary(const FunctionLibraryDefinition& other);
// Adds the functions and gradients in 'lib_def' to this function library.
+ // Duplicate functions and gradients are ignored.
Status AddLibrary(const FunctionDefLibrary& lib_def);
// If the gradient function for 'func' is specified explicitly in
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index 2ecdc36c11..1173384a1e 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -971,6 +971,10 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
EXPECT_EQ(s.error_message(),
"Cannot add function 'Add' because an op with the same name "
"already exists.");
+
+ // Already-added functions don't produce error
+ TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
+ TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
}
TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
@@ -984,12 +988,16 @@ TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
grad.set_gradient_func(test::function::XTimesFour().signature().name());
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
+ // Already-added gradients don't produce error
+ TF_EXPECT_OK(lib_def.AddGradientDef(grad));
+
// Test that adding a duplicate gradient fails
grad.set_gradient_func(test::function::XTimes16().signature().name());
Status s = lib_def.AddGradientDef(grad);
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
EXPECT_EQ(s.error_message(),
- "Gradient for function 'XTimesTwo' already exists.");
+ "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
+ "it already has gradient function 'XTimesFour'");
}
TEST(FunctionLibraryDefinitionTest, AddLibrary) {
@@ -998,35 +1006,46 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) {
*proto.add_function() = test::function::XTimesTwo();
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
- // Error if you try to add the same function twice
- Status s = lib_def.AddLibrary(lib_def);
- EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
- EXPECT_EQ(s.error_message(),
- "Function with name: XTimesTwo already exists in function "
- "library.");
-
// Add gradient
GradientDef grad;
grad.set_function_name(test::function::XTimesTwo().signature().name());
grad.set_gradient_func(test::function::XTimesFour().signature().name());
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
- // Error if you try to add the same library function twice
+ // Error if you try to add conflicting function
proto.Clear();
- *proto.add_gradient() = grad;
+ FunctionDef fdef = test::function::XTimesFour();
+ fdef.mutable_signature()->set_name(
+ test::function::XTimesTwo().signature().name());
+ *proto.add_function() = fdef;
FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto);
- s = lib_def.AddLibrary(lib_def2);
+ Status s = lib_def.AddLibrary(lib_def2);
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
EXPECT_EQ(s.error_message(),
- "Gradient for function 'XTimesTwo' already exists.");
+ "Cannot add function 'XTimesTwo' because a different function with "
+ "the same name already exists.");
+
+ // Error if you try to add conflicting gradient
+ proto.Clear();
+ grad.set_gradient_func(test::function::XTimes16().signature().name());
+ *proto.add_gradient() = grad;
+ FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
+ s = lib_def.AddLibrary(lib_def3);
+ EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
+ EXPECT_EQ(s.error_message(),
+ "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
+ "it already has gradient function 'XTimesFour'");
// No conflicting functions or gradients OK
proto.Clear();
*proto.add_function() = test::function::XTimesFour();
grad.set_function_name(test::function::XTimes16().signature().name());
*proto.add_gradient() = grad;
- FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
- TF_EXPECT_OK(lib_def.AddLibrary(lib_def3));
+ FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto);
+ TF_EXPECT_OK(lib_def.AddLibrary(lib_def4));
+
+ // OK to add the same functions and gradients twice
+ TF_EXPECT_OK(lib_def.AddLibrary(lib_def));
}
TEST(FunctionLibraryDefinitionTest, ToProto) {