diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2017-05-16 16:08:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-16 16:12:05 -0700 |
commit | 749e5cc18381f7a5ec174673f76e20aead8529c6 (patch) | |
tree | 4b92d36c9e1d8e59e34fd8d08e7f11fbda1315d9 /tensorflow/core/framework/function.h | |
parent | ed5d05d8b53425ef98aad129a60143a5011a4288 (diff) |
Reduce direct references to NodeDef in favor of Node and AttrSlice
This is one step towards replacing in-memory use of NodeDef with a customized
NodeInfo class. There are still quite a few Node::def() references, but far fewer than before. Those remaining require more work, either because they are part of kernel registration (which is a bunch of functions), copy and modify the NodeDef, etc. Follow-on CLs will remove more.
RELNOTES: n/a
PiperOrigin-RevId: 156244933
Diffstat (limited to 'tensorflow/core/framework/function.h')
-rw-r--r-- | tensorflow/core/framework/function.h | 49 |
1 files changed, 20 insertions, 29 deletions
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 210e5b949a..188c3855c6 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -36,6 +36,7 @@ class CancellationManager; class OpKernel; class ResourceMgr; class ScopedStepContainer; +class Node; // FunctionDefHelper::Create is a convenient helper to construct a // FunctionDef proto. @@ -190,11 +191,6 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { // InstantiateFunction calls "get_function" to find signatures of other // functions and primitive ops. -// Placeholders in "fdef" is substituted based on "attr_values" here. -typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap; -typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>> - InstantiateAttrValueSlice; - // GetFunctionSignature(func name, opdef) returns OK if the func name is found // and opdef is filled with a pointer to the corresponding signature // (a OpDef proto). Otherwise, returns an error. @@ -206,12 +202,7 @@ struct InstantiationResult { DataTypeVector ret_types; GraphDef gdef; }; -Status InstantiateFunction(const FunctionDef& fdef, - const InstantiateAttrValueMap& attr_values, - GetFunctionSignature get_function, - InstantiationResult* result); -Status InstantiateFunction(const FunctionDef& fdef, - InstantiateAttrValueSlice attr_values, +Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, InstantiationResult* result); @@ -241,9 +232,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); // space. But it may be change as the implementation // evolves. Therefore, it should not be persisted or compared across // address spaces. -string Canonicalize(const string& funcname, - const InstantiateAttrValueMap& attrs); -string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs); +string Canonicalize(const string& funcname, AttrSlice attrs); // Represents a function call frame. I.e., the data structure used to // pass arguments to a function and retrieve its results. @@ -330,9 +319,16 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Given a node def 'ndef', inspects attributes of the callee // function to derive the attribute 'value' for 'attr'. Returns OK // iff the attribute is given by the function's definition. + // TODO(irving): Remove; keep only the const Node& version. template <typename T> Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const; + // Given a node, inspects attributes of the callee function to derive the + // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the + // function's definition. + template <typename T> + Status GetAttr(const Node& node, const string& attr, T* value) const; + // Returns a proto representation of the state of this function library. FunctionDefLibrary ToProto() const; @@ -375,11 +371,8 @@ class FunctionLibraryRuntime { // Returns OK and fills in "handle" if the instantiation succeeds. // Otherwise returns an error and "handle" is undefined. typedef uint64 Handle; - virtual Status Instantiate(const string& function_name, - const InstantiateAttrValueMap& attrs, + virtual Status Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) = 0; - Status Instantiate(const string& function_name, - InstantiateAttrValueSlice attrs, Handle* handle); // Returns the function body for the instantiated function given its // handle 'h'. Returns nullptr if "h" is not found. @@ -506,17 +499,15 @@ bool RegisterOp(const string& op, Creator func); Status GetOpGradientCreator(const string& op, Creator* creator); }; -// Implementation details. - -template <typename T> -Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, - const string& attr, T* value) const { - const FunctionDef* fdef = GetAttrImpl(ndef); - if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) { - return Status::OK(); - } - return errors::InvalidArgument("Attr ", attr, " is not defined."); -} +// Declare explicit instantiations of GetAttr +#define GET_ATTR(T) \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const Node&, const string&, T*) const; \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const NodeDef&, const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR } // end namespace tensorflow |