aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-10-05 10:31:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 10:35:34 -0700
commitd493a7f2fdbbc29a292741135f4c1598352e876b (patch)
treead894183d0c747c76b4dd53d7a96180515ea6b14 /tensorflow/core
parent8b7c789e7401fe56b4f648a04f675a3cb69119e5 (diff)
When running a native/builtin op via eager C API, automatically fill in default
attr values that are not overridden e.g. transpose_a in the matmul op). This is required for backward compatibility (a binary built via an older version of TF should still run on a newer version of TF, where some ops may have added attrs). For non-eager graph building, the default attr values of graph ops are added by tensorflow::AddDefaultsToNodeDef(). We ran into this issue when running the same S4TF test cases via eager APIs -- some tests failed due to "missing attrs", but are fixed by this patch. PiperOrigin-RevId: 215927271
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.h6
2 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index cf1cd4134e..5c8369de87 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -136,6 +136,22 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
m->insert(*it);
}
}
+ // For any attr-value pairs that exist in the op def (from op registry) but
+ // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
+ // specify all the default attr values (e.g. for matmul, the `transpose_a`
+ // attr defaults to false).
+ const OpDef* op_def = nullptr;
+ Status s = OpDefForOp(op_name_.c_str(), &op_def);
+ // This is expected, if this op is a custom function, and is therefore not
+ // present in the op registry.
+ if (!s.ok()) return;
+
+ DCHECK(op_def);
+ for (const auto& attr_def : op_def->attr()) {
+ if (attr_def.has_default_value() && !m->count(attr_def.name())) {
+ SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
+ }
+ }
}
const NodeDef& AttrBuilder::BuildNodeDef() {
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index cbe6a1cb50..c114ea4ba0 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -110,6 +110,12 @@ class AttrBuilder {
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
void MayBeInitializeNodeDef();
+ // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
+ // well as any default attr-value pairs from the associated op_def, if there
+ // is one.
+ //
+ // If `include_those_in_node_def` is true, also include any attr-value pairs
+ // from `node_def_`.
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
template <class T>