diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2017-05-16 16:08:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-16 16:12:05 -0700 |
commit | 749e5cc18381f7a5ec174673f76e20aead8529c6 (patch) | |
tree | 4b92d36c9e1d8e59e34fd8d08e7f11fbda1315d9 /tensorflow/core/framework | |
parent | ed5d05d8b53425ef98aad129a60143a5011a4288 (diff) |
Reduce direct references to NodeDef in favor of Node and AttrSlice
This is one step towards replacing in-memory use of NodeDef with a customized
NodeInfo class. There are still quite a few Node::def() references, but far fewer than before. Those remaining require more work, either because they are part of kernel registration (which is a bunch of functions), copy and modify the NodeDef, etc. Follow-on CLs will remove more.
RELNOTES: n/a
PiperOrigin-RevId: 156244933
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/function.cc | 102 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 49 | ||||
-rw-r--r-- | tensorflow/core/framework/function_test.cc | 110 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.cc | 92 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.h | 39 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 9 |
7 files changed, 254 insertions, 170 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index f46bb6e2ed..186095201d 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -44,12 +45,11 @@ namespace { // Otherwise (arg_def is a simple type T), *is_type_list is set to // false, and *dtypes is set to a single element vector, whose only // element is T. -Status ArgNumType(const InstantiateAttrValueMap& attrs, - const OpDef::ArgDef& arg_def, bool* is_type_list, - DataTypeVector* dtypes) { +Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, + bool* is_type_list, DataTypeVector* dtypes) { dtypes->clear(); if (!arg_def.type_list_attr().empty()) { - const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_list_attr()); + const AttrValue* v = attrs.Find(arg_def.type_list_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_list_attr()); @@ -64,7 +64,7 @@ Status ArgNumType(const InstantiateAttrValueMap& attrs, *is_type_list = false; int num = 1; if (!arg_def.number_attr().empty()) { - const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr()); + const AttrValue* v = attrs.Find(arg_def.number_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_attr()); } @@ -77,7 +77,7 @@ Status ArgNumType(const InstantiateAttrValueMap& attrs, } else if (arg_def.type_attr().empty()) { dtype = DT_INVALID; } else { - const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr()); + const AttrValue* v = attrs.Find(arg_def.type_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_attr()); } @@ -92,18 +92,17 @@ void AddAttr(const string& name, const T& val, NodeDef* ndef) { SetAttrValue(val, &((*ndef->mutable_attr())[name])); } -Status ValidateSignatureWithAttrs(const OpDef& sig, - const InstantiateAttrValueMap& attr_values) { +Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { // attr_values should specify all attrs defined in fdef. for (const auto& a : sig.attr()) { - auto const iter = attr_values.find(a.name()); - if (iter == attr_values.end()) { + const AttrValue* v = attr_values.Find(a.name()); + if (!v) { return errors::NotFound("Attr ", a.name(), " is not found from ", SummarizeOpDef(sig)); } - Status status = AttrValueHasType(iter->second, a.type()); + Status status = AttrValueHasType(*v, a.type()); if (!status.ok()) { - errors::AppendToMessage(&status, "for attr '", iter->first, "'"); + errors::AppendToMessage(&status, "for attr '", a.name(), "'"); return status; } } @@ -146,7 +145,7 @@ class FunctionInstantiationHelper { // Builds index for nodes that can be used as node's input arguments. Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, - const InstantiateAttrValueMap& attr_values) { + AttrSlice attr_values) { bool is_type_list; DataTypeVector dtypes; TF_RETURN_IF_ERROR( @@ -175,8 +174,7 @@ class FunctionInstantiationHelper { return Status::OK(); } - Status BuildNodeOutputIndex(const NodeDef& node, - const InstantiateAttrValueMap& attrs, + Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, const int arg_index) { const OpDef* node_sig = nullptr; TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig)); @@ -206,8 +204,7 @@ class FunctionInstantiationHelper { return Status::OK(); } - Status InstantiateNode(const NodeDef& fnode, - const InstantiateAttrValueMap& attrs) { + Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { const OpDef* fnode_sig = nullptr; TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig)); NodeDef* gnode = AddNode(fnode.name()); @@ -295,7 +292,7 @@ class FunctionInstantiationHelper { } Status AddReturnNode( - const OpDef::ArgDef& ret_def, const InstantiateAttrValueMap& attrs, + const OpDef::ArgDef& ret_def, AttrSlice attrs, const ::tensorflow::protobuf::Map<string, string>& ret_map, int* ret_index) { auto ret_iter = ret_map.find(ret_def.name()); @@ -604,7 +601,7 @@ string Print(const GraphDef& gdef) { Status AddDefaultAttrs(const string& op, const GetFunctionSignature& get_function, - InstantiateAttrValueMap* attrs) { + AttrValueMap* attrs) { const OpDef* op_def = nullptr; TF_RETURN_IF_ERROR(get_function(op, &op_def)); AttrSlice attr_slice(attrs); @@ -620,8 +617,7 @@ Status AddDefaultAttrs(const string& op, } // end namespace -Status InstantiateFunction(const FunctionDef& fdef, - const InstantiateAttrValueMap& attr_values, +Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, InstantiationResult* result) { VLOG(3) << "Instantiation Function: " << Print(fdef); @@ -639,19 +635,17 @@ Status InstantiateFunction(const FunctionDef& fdef, } } - auto substitute = [&attr_values](const string& name, AttrValue* val) { - auto iter = attr_values.find(name); - if (iter == attr_values.end()) { - return false; - } else { - *val = iter->second; + auto substitute = [attr_values](StringPiece name, AttrValue* val) { + if (const AttrValue* v = attr_values.Find(name)) { + *val = *v; return true; } + return false; }; // Makes a copy of all attrs in fdef and substitutes placeholders. // After this step, every attr is bound to a concrete value. - std::vector<InstantiateAttrValueMap> node_attrs; + std::vector<AttrValueMap> node_attrs; node_attrs.resize(fdef.node_def_size()); for (int i = 0; i < fdef.node_def_size(); ++i) { for (auto attr : fdef.node_def(i).attr()) { @@ -668,7 +662,7 @@ Status InstantiateFunction(const FunctionDef& fdef, } for (int i = 0; i < fdef.node_def_size(); ++i) { - s = helper.BuildNodeOutputIndex(fdef.node_def(i), node_attrs[i], + s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), result->gdef.node_size() + i); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); @@ -677,7 +671,7 @@ Status InstantiateFunction(const FunctionDef& fdef, } // Emits one gdef.node for each fdef.node_def. for (int i = 0; i < fdef.node_def_size(); ++i) { - s = helper.InstantiateNode(fdef.node_def(i), node_attrs[i]); + s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); return s; @@ -748,8 +742,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { return true; } -string Canonicalize(const string& funcname, - const InstantiateAttrValueMap& attrs) { +string Canonicalize(const string& funcname, AttrSlice attrs) { std::vector<string> entries; entries.reserve(attrs.size()); for (auto p : attrs) { @@ -953,8 +946,7 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or // Foo's attributes. const NameAttrList* forward_func_attrs; - if (!GetNodeAttr(AttrSlice(&ndef.attr()), kFuncAttr, &forward_func_attrs) - .ok()) { + if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) { return nullptr; } const string& func_name = forward_func_attrs->name(); @@ -981,34 +973,30 @@ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { return lib; } -Status InstantiateFunction(const FunctionDef& fdef, - InstantiateAttrValueSlice attr_values, - GetFunctionSignature get_function, - InstantiationResult* result) { - InstantiateAttrValueMap m; - for (const auto& aval : attr_values) { - m.insert({aval.first, aval.second.proto}); +template <typename T> +Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, + const string& attr, T* value) const { + const FunctionDef* fdef = GetAttrImpl(ndef); + if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) { + return Status::OK(); } - return InstantiateFunction(fdef, m, std::move(get_function), result); + return errors::InvalidArgument("Attr ", attr, " is not defined."); } -string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) { - InstantiateAttrValueMap m; - for (const auto& aval : attrs) { - m.insert({aval.first, aval.second.proto}); - } - return Canonicalize(funcname, m); +template <typename T> +Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr, + T* value) const { + return GetAttr(node.def(), attr, value); } -Status FunctionLibraryRuntime::Instantiate(const string& function_name, - InstantiateAttrValueSlice attrs, - Handle* handle) { - InstantiateAttrValueMap m; - for (const auto& aval : attrs) { - m.insert({aval.first, aval.second.proto}); - } - return Instantiate(function_name, m, handle); -} +#define GET_ATTR(T) \ + template Status FunctionLibraryDefinition::GetAttr(const Node&, \ + const string&, T*) const; \ + template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \ + const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { if (val.size() >= 2 && val[0] == '$') { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 210e5b949a..188c3855c6 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -36,6 +36,7 @@ class CancellationManager; class OpKernel; class ResourceMgr; class ScopedStepContainer; +class Node; // FunctionDefHelper::Create is a convenient helper to construct a // FunctionDef proto. @@ -190,11 +191,6 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { // InstantiateFunction calls "get_function" to find signatures of other // functions and primitive ops. -// Placeholders in "fdef" is substituted based on "attr_values" here. -typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap; -typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>> - InstantiateAttrValueSlice; - // GetFunctionSignature(func name, opdef) returns OK if the func name is found // and opdef is filled with a pointer to the corresponding signature // (a OpDef proto). Otherwise, returns an error. @@ -206,12 +202,7 @@ struct InstantiationResult { DataTypeVector ret_types; GraphDef gdef; }; -Status InstantiateFunction(const FunctionDef& fdef, - const InstantiateAttrValueMap& attr_values, - GetFunctionSignature get_function, - InstantiationResult* result); -Status InstantiateFunction(const FunctionDef& fdef, - InstantiateAttrValueSlice attr_values, +Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, InstantiationResult* result); @@ -241,9 +232,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); // space. But it may be change as the implementation // evolves. Therefore, it should not be persisted or compared across // address spaces. -string Canonicalize(const string& funcname, - const InstantiateAttrValueMap& attrs); -string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs); +string Canonicalize(const string& funcname, AttrSlice attrs); // Represents a function call frame. I.e., the data structure used to // pass arguments to a function and retrieve its results. @@ -330,9 +319,16 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Given a node def 'ndef', inspects attributes of the callee // function to derive the attribute 'value' for 'attr'. Returns OK // iff the attribute is given by the function's definition. + // TODO(irving): Remove; keep only the const Node& version. template <typename T> Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const; + // Given a node, inspects attributes of the callee function to derive the + // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the + // function's definition. + template <typename T> + Status GetAttr(const Node& node, const string& attr, T* value) const; + // Returns a proto representation of the state of this function library. FunctionDefLibrary ToProto() const; @@ -375,11 +371,8 @@ class FunctionLibraryRuntime { // Returns OK and fills in "handle" if the instantiation succeeds. // Otherwise returns an error and "handle" is undefined. typedef uint64 Handle; - virtual Status Instantiate(const string& function_name, - const InstantiateAttrValueMap& attrs, + virtual Status Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) = 0; - Status Instantiate(const string& function_name, - InstantiateAttrValueSlice attrs, Handle* handle); // Returns the function body for the instantiated function given its // handle 'h'. Returns nullptr if "h" is not found. @@ -506,17 +499,15 @@ bool RegisterOp(const string& op, Creator func); Status GetOpGradientCreator(const string& op, Creator* creator); }; -// Implementation details. - -template <typename T> -Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, - const string& attr, T* value) const { - const FunctionDef* fdef = GetAttrImpl(ndef); - if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) { - return Status::OK(); - } - return errors::InvalidArgument("Attr ", attr, " is not defined."); -} +// Declare explicit instantiations of GetAttr +#define GET_ATTR(T) \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const Node&, const string&, T*) const; \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const NodeDef&, const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR } // end namespace tensorflow diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 07462a575e..c83ecf4e5e 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -29,6 +29,24 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { +namespace { + +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair<string, FunctionDefHelper::AttrValueWrapper>> + attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; typedef FunctionDefHelper FDH; @@ -46,8 +64,6 @@ y: A scalar in type T. )doc"); -static InstantiateAttrValueMap kNoAttrs; - TEST(TFunc, SquarePlusOne) { auto fdef = FDH::Create( // Name @@ -81,7 +97,8 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { // Instantiate one with T=float InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result)); + TF_ASSERT_OK( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); const char* e2 = R"P( (x:float) -> (y:float) { a = Square[T=float](x) @@ -126,7 +143,8 @@ ControlDep(x:int32) -> (y:int32) { // Instantiate one with T=float InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result)); + TF_ASSERT_OK( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); const char* e2 = R"P( (x:int32) -> (y:int32) { a = Identity[T=int32](x) @@ -171,8 +189,7 @@ BackCompat() -> (y:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK( - InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); // Should get T=float from Op's default. const char* e2 = R"P( () -> (a:float) { @@ -209,7 +226,7 @@ NTimesT(x:float, y:float) -> (z:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (x:float, y:float) -> (a:float) { a = AddN[N=2, T=float](x, y) @@ -272,8 +289,8 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { // Instantiate one with T=float InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig, - &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}), + GetOpSig, &result)); const char* e2 = R"P( (x_0:float, x_1:float, x_2:float) -> (y:float) { a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2) @@ -315,7 +332,7 @@ ControlDeps(x:float) -> () { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (x:float) -> () { a = One[T=float]() @ x @@ -395,7 +412,7 @@ Test(i:float) -> (o:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (i:float) -> (o:float) { zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() @@ -467,7 +484,7 @@ MySelect(x:float) -> (z:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (x:float) -> (z:float) { y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) @@ -488,8 +505,9 @@ TEST(InstantiateErrors, Not_Sufficient_Attrs) { auto fdef = FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result), - "Attr T is not found from "); + HasError( + InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result), + "Attr T is not found from "); } #if 0 // TODO(josh11b): Enable this test once having an extra attr is an error. @@ -497,7 +515,7 @@ TEST(InstantiateErrors, Too_Many_Attrs) { auto fdef = FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, {{"T", DT_INT32}, {"U", DT_FLOAT}}, + HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}), GetOpSig, &result), "Attr U is not found in "); } @@ -508,7 +526,7 @@ TEST(InstantiateErrors, AttrValue_Value_Placeholder) { FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); InstantiationResult result; HasError( - InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result), + InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result), "AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'"); } @@ -518,14 +536,15 @@ TEST(InstantiateErrors, Unbounded_Attr) { {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result), - "Failed to bind all placeholders"); + HasError( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result), + "Failed to bind all placeholders"); } TEST(InstantiateErrors, DupArgs) { auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Duplicated arg name"); } @@ -536,7 +555,7 @@ TEST(InstantiateErrors, Dup_Node_Names) { {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Duplicated ret name"); } @@ -547,7 +566,7 @@ TEST(InstantiateErrors, Node_Arg_Notfound) { }, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input z is not found"); } @@ -557,7 +576,7 @@ TEST(InstantiateErrors, Node_Arg_TypeMismatch) { {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input x[0] expected type int32 != float, the type of x[0]"); } @@ -568,7 +587,7 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) { {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input[2] == '^z', is not found."); } @@ -579,7 +598,7 @@ TEST(InstantiateErrors, FuncRet_Missing) { }, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Return y missing"); } @@ -590,7 +609,7 @@ TEST(InstantiateErrors, FuncRet_NotFound) { }, {{"y", "z"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Return y -> z is not found"); } @@ -601,7 +620,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) { }, {{"z", "x:y:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Return y missing"); } @@ -613,7 +632,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) { // }, // {{"y", "x:y:0"}, {"z", "x:y:0"}}); // InstantiationResult result; -// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), +// HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), // "ret is not found"); // } @@ -623,7 +642,7 @@ TEST(InstantiateErrors, FuncRet_TypeMismatch) { {{"y"}, "One", {}, {{"T", DT_DOUBLE}}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Invalid ret types y : float vs. double\n\tIn function output y"); } @@ -649,7 +668,7 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { }, {{"y", "y:output"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "type attr not found: out_types"); } @@ -676,7 +695,7 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { }, {{"y", "y:output"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Invalid ret types"); } @@ -703,7 +722,7 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) { }, {{"y", "y:output"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input unknown is not found"); } @@ -724,7 +743,7 @@ TEST(InstantiateErrors, TooManyInputs) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Expected input[2] == 'x' to be a control input."); } @@ -745,7 +764,7 @@ TEST(InstantiateErrors, TooFewInputs) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Attempt to access beyond input size: 2 >= 2"); } @@ -773,7 +792,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray1) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Expected input[1] == 'y' to be a control input."); } @@ -801,7 +820,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray2) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Input a:output too long for inputs"); } @@ -822,7 +841,7 @@ TEST(InstantiateErrors, TypeMismatch) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input inputs[1] expected type float != int32, the type of y[0]"); } @@ -874,17 +893,17 @@ TEST(FunctionCallFrame, Float_Float_Float) { } TEST(Canonicalize, Basic) { - EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, - {"transpose_a", false}, - {"transpose_b", false}}), + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, + {"transpose_a", false}, + {"transpose_b", false}})), "MatMul[T=float,transpose_a=false,transpose_b=false]"); - EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, - {"transpose_b", false}, - {"transpose_a", false}}), + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, + {"transpose_b", false}, + {"transpose_a", false}})), "MatMul[T=float,transpose_a=false,transpose_b=false]"); - EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE}, - {"transpose_b", true}, - {"transpose_a", false}}), + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE}, + {"transpose_b", true}, + {"transpose_a", false}})), "MatMul[T=double,transpose_a=false,transpose_b=true]"); } @@ -1148,4 +1167,5 @@ TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) { EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); } +} // end namespace } // end namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 36c0842924..9b737e1f72 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/tensor.pb_text.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/scanner.h" @@ -36,18 +37,23 @@ namespace tensorflow { const char* const kColocationAttrName = "_class"; const char* const kColocationGroupPrefix = "loc:@"; +AttrSlice::AttrSlice() : ndef_(nullptr) { + static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap; + attrs_ = kEmptyAttrValueMap; +} + AttrSlice::AttrSlice(const NodeDef& node_def) : ndef_(&node_def), attrs_(&ndef_->attr()) {} AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {} -string SummarizeNodeDef(const NodeDef& node_def) { - string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); +static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) { + string ret; // We sort the attrs so the output is deterministic. std::vector<string> attr_names; - attr_names.reserve(node_def.attr().size()); - for (const auto& attr : node_def.attr()) { + attr_names.reserve(attrs.size()); + for (const auto& attr : attrs) { attr_names.push_back(attr.first); } std::sort(attr_names.begin(), attr_names.end()); @@ -55,20 +61,34 @@ string SummarizeNodeDef(const NodeDef& node_def) { for (const string& attr_name : attr_names) { if (!first) strings::StrAppend(&ret, ", "); first = false; - auto iter = node_def.attr().find(attr_name); - strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second)); + strings::StrAppend(&ret, attr_name, "=", + SummarizeAttrValue(*attrs.Find(attr_name))); } // Consider the device to be a final attr with name "_device". - if (!node_def.device().empty()) { + if (!device.empty()) { if (!first) strings::StrAppend(&ret, ", "); first = false; - strings::StrAppend(&ret, "_device=\"", node_def.device(), "\""); + strings::StrAppend(&ret, "_device=\"", device, "\""); } + return ret; +} + +string AttrSlice::SummarizeNode() const { + return ndef_ ? SummarizeNodeDef(*ndef_) + : strings::StrCat( + "[", SummarizeAttrsHelper(*this, StringPiece()), "]"); +} + +string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); } + +string SummarizeNodeDef(const NodeDef& node_def) { + string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); + strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device())); strings::StrAppend(&ret, "]("); // Output inputs, including control inputs, verbatim. - first = true; + bool first = true; for (const string& input : node_def.input()) { if (!first) strings::StrAppend(&ret, ", "); first = false; @@ -109,12 +129,28 @@ Status AttrSlice::Find(StringPiece attr_name, // Skip AttachDef for internal attrs since it is a little bit // expensive and it is common for them to correctly not be included // in a NodeDef. - if (!StringPiece(attr_name).starts_with("_") && ndef_) { + if (!attr_name.starts_with("_") && ndef_ != nullptr) { s = AttachDef(s, *ndef_); } return s; } +bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const { + if (size() != other.size()) return false; + + for (const auto& attr : *other.attrs_) { + auto iter = attrs_->find(attr.first); + if (iter == attrs_->end()) return false; + // TODO(irving): Comparing AttrValues by proto is slightly buggy, since + // TensorProto is a nonunique representation of Tensor. This bug will go + // away once AttrSlice switches over to NodeInfo. + iter->second.SerializeToString(&scratch->a); + attr.second.SerializeToString(&scratch->b); + if (scratch->a != scratch->b) return false; + } + return true; +} + // The ... is to allow the caller to inject some value validation code. Use // just ; if no additional validation code is needed. #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \ @@ -341,14 +377,14 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { if (StringPiece(input).starts_with("^")) { seen_control = true; if (input.find(':') != string::npos) { - return errors::InvalidArgument("Control input '", input, - "' must not have ':' in NodeDef: ", - SummarizeNodeDef(node_def)); + return errors::InvalidArgument( + "Control input '", input, + "' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def)); } } else if (seen_control) { - return errors::InvalidArgument("Non-control input '", input, - "' after control input in NodeDef: ", - SummarizeNodeDef(node_def)); + return errors::InvalidArgument( + "Non-control input '", input, + "' after control input in NodeDef: ", SummarizeNodeDef(node_def)); } else { ++num_inputs; } @@ -358,8 +394,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { for (const auto& attr : op_def.attr()) { if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) { return errors::InvalidArgument("OpDef has duplicate attr name '", - attr.name(), "': ", - SummarizeOpDef(op_def)); + attr.name(), + "': ", SummarizeOpDef(op_def)); } } for (const auto& attr : node_def.attr()) { @@ -383,8 +419,9 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { "with your GraphDef-generating binary.)."); } TF_RETURN_WITH_CONTEXT_IF_ERROR( - ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ", - SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def)); + ValidateAttrValue(attr.second, *iter->second), + "; NodeDef: ", SummarizeNodeDef(node_def), "; ", + SummarizeOpDef(op_def)); // Keep track of which attr names have (not) been found in the NodeDef. op_attrs.erase(iter); } @@ -431,9 +468,9 @@ Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def, } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) { *num = 1; } else { - return errors::InvalidArgument("Argument '", arg_def.name(), - "' incorrectly specified in op definition: ", - SummarizeOpDef(op_def)); + return errors::InvalidArgument( + "Argument '", arg_def.name(), + "' incorrectly specified in op definition: ", SummarizeOpDef(op_def)); } return Status::OK(); } @@ -465,6 +502,11 @@ Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, return Status::OK(); } +Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { + return NameRangesForNode(node.def(), op_def, inputs, outputs); +} + void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { for (const auto& attr_def : op_def.attr()) { AttrSlice attrs(*node_def); @@ -565,4 +607,8 @@ Status AttachDef(const Status& status, const NodeDef& node_def) { return ret; } +Status AttachDef(const Status& status, const Node& node) { + return AttachDef(status, node.def()); +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 018e4d15f2..1438abdec6 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -29,6 +29,8 @@ limitations under the License. namespace tensorflow { +class Node; + // Name of the attribute used to encode node colocation constraints. // // Nodes can be co-located on the same device. Desire for explicit co-location @@ -39,8 +41,9 @@ extern const char* const kColocationAttrName; // String prefix applied to the operation name for colocation constraints. extern const char* const kColocationGroupPrefix; -// Produce a human-readable version of a NodeDef that is more concise +// Produce a human-readable version of a Node or NodeDef that is more concise // than a text-format proto. +string SummarizeNode(const Node& node); string SummarizeNodeDef(const NodeDef& node_def); typedef protobuf::Map<string, AttrValue> AttrValueMap; @@ -78,8 +81,11 @@ class AttrSlice { public: AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit) + AttrSlice(); // Empty explicit AttrSlice(const AttrValueMap* a); + int size() const { return attrs_->size(); } + // Returns the attr with attr_name if found. Otherwise, returns // nullptr. const AttrValue* Find(StringPiece attr_name) const; @@ -88,6 +94,33 @@ class AttrSlice { // NotFound status. Status Find(StringPiece attr_name, const AttrValue** attr_value) const; + // Helper class to avoid allocations in EqualAttrs. + // TODO(irving): Will go away once NodeInfo is used. + struct Scratch { + string a; + string b; + }; + + // Check if all attrs and attr values match. Does not take defaults into + // account. + // + // TODO(irving): There is a bug in this routine inherited from its + // OptimizerCSE::EqualAttrs precedecessor. The same tensor attr can be + // represented in more than one way as an AttrValue, since TensorProto is + // not 1-1. This bug will go away once I replace everything with NodeInfo, + // which stores a Tensor object directly. The Scratch object will also go + // away. + bool EqualAttrs(AttrSlice other, Scratch* scratch) const; + + // If this AttrSlice has an attached NodeDef, summarize it. This is for + // error messages only: we intentionally do not provide direct access to the + // NodeDef, since it is not always there. + string SummarizeNode() const; + + // Iteration over all attrs + AttrValueMap::const_iterator begin() const { return attrs_->begin(); } + AttrValueMap::const_iterator end() const { return attrs_->end(); } + private: const NodeDef* ndef_; const AttrValueMap* attrs_; @@ -183,9 +216,12 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); // corresponding input/output index range. For example, // input "foo" corresponds to input indices // [ (*inputs)["foo"].first, (*inputs)["foo"].second ). +// TODO(irving): Remove the NodeDef version; keep only the Node version. typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap; Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, NameRangeMap* inputs, NameRangeMap* outputs); +Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); // Adds default values to *node_def for unspecified attrs from op_def. void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def); @@ -206,6 +242,7 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def); // Returns "status" with kernel's NodeDef attached as additional text // in the error message. Status AttachDef(const Status& status, const NodeDef& node_def); +Status AttachDef(const Status& status, const Node& node); } // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 422ee80720..6c3917c686 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -842,13 +843,10 @@ bool InTypeList(DataType dt, const AttrValue& type_list) { return false; } -// Returns whether the attrs in the NodeDef satisfy the constraints in -// the kernel_def. Returns an error if attrs in kernel_def are not -// found, or have a mismatching type. -Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, - bool* match) { +// Returns whether the attrs satisfy the constraints in the kernel_def. Returns +// an error if attrs in kernel_def are not found, or have a mismatching type. +Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) { *match = false; - AttrSlice attrs(node_def); for (const auto& constraint : kernel_def.constraint()) { if (constraint.allowed_values().list().type_size() == 0) { return errors::Unimplemented( @@ -872,7 +870,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, "' that has value '", SummarizeAttrValue(*found), "' that does not have type 'type' or 'list(type)' in NodeDef " "'", - SummarizeNodeDef(node_def), "'"); + attrs.SummarizeNode(), "'"); } for (int t : found->list().type()) { @@ -885,7 +883,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, } else { return errors::InvalidArgument( "OpKernel '", kernel_def.op(), "' has constraint on attr '", - constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def), + constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(), "', KernelDef: '", ProtoShortDebugString(kernel_def), "'"); } } @@ -895,6 +893,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, static const StringPiece kKernelAttr("_kernel"); +// TODO(irving): Replace with const Node& version below. Status FindKernelRegistration(const DeviceType& device_type, const NodeDef& node_def, const KernelRegistration** reg, @@ -927,8 +926,16 @@ Status FindKernelRegistration(const DeviceType& device_type, return Status::OK(); } +Status FindKernelRegistration(const DeviceType& device_type, const Node& node, + const KernelRegistration** reg, + bool* was_attr_mismatch) { + return FindKernelRegistration(device_type, node.def(), reg, + was_attr_mismatch); +} + } // namespace +// TODO(irving): Change const NodeDef& to const Node& Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, const KernelDef** def, string* kernel_class_name) { const KernelRegistration* reg = nullptr; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index cebadcc5b4..d064a8ec4d 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -184,8 +184,8 @@ class InferenceContext { } #ifndef NDEBUG for (int i = 0; i < num_outputs(); ++i) { - DCHECK(output(i).IsSet()) << i << " for " << node_def().name() - << " of type " << node_def().op(); + DCHECK(output(i).IsSet()) + << i << " for " << node_def_.name() << " of type " << node_def_.op(); } #endif // NDEBUG return s; @@ -394,11 +394,6 @@ class InferenceContext { // the value. Status MakeDimForScalarInput(int idx, DimensionHandle* out); - // Returns the NodeDef. The returned reference does not outlive the - // InferenceContext, and it should not be used after InferenceContext is - // destroyed. - const NodeDef& node_def() { return node_def_; } - // Look up the attr for the NodeDef being evaluated with name attr_name and // set *value to its value. If no attr with attr_name is found in def(), or // the attr does not have a matching type, a non-ok status will be returned. |