aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/function.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-07-06 13:59:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-06 14:03:01 -0700
commit7d5c74a9c89333f4f5b0524856907ac1b9f43a7b (patch)
tree92961ba96c8bfacf8548b3097f87e86ac355e931 /tensorflow/core/framework/function.cc
parent2caec3af18a5c81f4c02f049cb5f39ca71700795 (diff)
Move duplicate detection logic from Graph to FunctionLibraryDefinition
Turns out this is more useful, since there are many function libraries that don't belong to a graph. This will be used in a future change. Note that this maintains the current behavior of Graph. In addition, updates FunctionDefsEqual() to handle unset attr entries (I ran into this when using this in said future change). PiperOrigin-RevId: 161126628
Diffstat (limited to 'tensorflow/core/framework/function.cc')
-rw-r--r--tensorflow/core/framework/function.cc61
1 files changed, 47 insertions, 14 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 9d43aab5a5..c6a20bf3ce 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -724,6 +724,23 @@ string DebugStringWhole(const GraphDef& gdef) {
return ret;
}
+namespace {
+
+// Returns the name -> attr mapping of fdef's attrs that have a value set. In
+// Python, it's possible to access unset attrs, which returns a default value
+// and adds an unset attr to the map.
+std::map<StringPiece, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
+ std::map<StringPiece, AttrValue> set_attrs;
+ for (auto iter : fdef.attr()) {
+ if (iter.second.value_case() != AttrValue::VALUE_NOT_SET) {
+ set_attrs[iter.first] = iter.second;
+ }
+ }
+ return set_attrs;
+}
+
+} // end namespace
+
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
// NOTE(skyewm): Using MessageDifferencer would be better here, but that is
// currently not included in tensorflow/core/platform/default/protobuf.h, so
@@ -736,10 +753,12 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
f2.signature().SerializeToString(&sig2);
if (sig1 != sig2) return false;
- if (f1.attr().size() != f2.attr().size()) return false;
- for (auto iter1 : f1.attr()) {
- auto iter2 = f2.attr().find(iter1.first);
- if (iter2 == f2.attr().end()) return false;
+ std::map<StringPiece, AttrValue> f1_attrs = GetSetAttrs(f1);
+ std::map<StringPiece, AttrValue> f2_attrs = GetSetAttrs(f2);
+ if (f1_attrs.size() != f2_attrs.size()) return false;
+ for (auto iter1 : f1_attrs) {
+ auto iter2 = f2_attrs.find(iter1.first);
+ if (iter2 == f2_attrs.end()) return false;
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
}
@@ -883,11 +902,17 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
}
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
- auto& ptr = function_defs_[fdef.signature().name()];
- if (ptr != nullptr) {
- return errors::InvalidArgument("Function with name: ",
- fdef.signature().name(),
- " already exists in function library.");
+ std::unique_ptr<FunctionDefAndOpRegistration>* entry =
+ &function_defs_[fdef.signature().name()];
+ if (*entry != nullptr) {
+ if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
+ return errors::InvalidArgument(
+ "Cannot add function '", fdef.signature().name(),
+ "' because a different function with the same name already "
+ "exists.");
+ }
+ // Ignore duplicate FunctionDefs
+ return Status::OK();
}
const OpDef* op_def;
if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
@@ -895,19 +920,27 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
"Cannot add function '", fdef.signature().name(),
"' because an op with the same name already exists.");
}
- ptr.reset(new FunctionDefAndOpRegistration(fdef));
+ entry->reset(new FunctionDefAndOpRegistration(fdef));
return Status::OK();
}
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
- if (func_grad_.count(grad.function_name()) > 0) {
- return errors::InvalidArgument("Gradient for function '",
- grad.function_name(), "' already exists.");
+ string* entry = &func_grad_[grad.function_name()];
+ if (!entry->empty()) {
+ if (*entry != grad.gradient_func()) {
+ return errors::InvalidArgument(
+ "Cannot assign gradient function '", grad.gradient_func(), "' to '",
+ grad.function_name(), "' because it already has gradient function ",
+ "'", *entry, "'");
+ }
+ // Ignore duplicate GradientDefs
+ return Status::OK();
}
- func_grad_[grad.function_name()] = grad.gradient_func();
+ *entry = grad.gradient_func();
return Status::OK();
}
+// TODO(skyewm): don't modify FunctionLibraryDefinition in case of error
Status FunctionLibraryDefinition::AddLibrary(
const FunctionLibraryDefinition& other) {
for (auto iter : other.function_defs_) {