diff options
Diffstat (limited to 'tensorflow/core/framework/function.h')
-rw-r--r-- | tensorflow/core/framework/function.h | 376 |
1 files changed, 376 insertions, 0 deletions
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h new file mode 100644 index 0000000000..1ef93a0533 --- /dev/null +++ b/tensorflow/core/framework/function.h @@ -0,0 +1,376 @@ +#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_ +#define TENSORFLOW_FRAMEWORK_FUNCTION_H_ + +#include <unordered_map> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +class CancellationManager; +class Node; +class OpKernel; + +// FunctionDefHelper::Define is a convenient helper to construct a +// FunctionDef proto. +// +// E.g., +// FunctionDef my_func = FunctionDefHelper::Define( +// "my_func_name", +// {"x:T", "y:T" /* one string per argument */}, +// {"z:T" /* one string per return value */}, +// {"T: {float, double}" /* one string per attribute */}, +// { +// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} +// /* one entry per function node */ +// }) +// +// NOTE: When we have a TFLang parser, we can add another helper: +// FunctionDef FunctionDefHelper::Define(const string& tf_func); +class FunctionDefHelper { + public: + // AttrValueWrapper has copy constructors for the type T so that + // it's easy to construct a simple AttrValue proto. + // + // If T is a string type (const char*, string, or StringPiece), and + // it starts with "$", we construct a AttrValue of "placeholder". + // + // E.g., + // std::<string, AttrValueWrapper> x = {"T", "$T"} + // is a named attr value placeholder. + struct AttrValueWrapper { + AttrValue proto; + + AttrValueWrapper() {} + + template <typename T> + AttrValueWrapper(T val) { // NOLINT(runtime/explicit) + SetAttrValue(val, &proto); + } + + private: + void InitFromString(StringPiece val); + }; + + // Constructs an AttrValue.func given the "name" and "attrs". + static AttrValueWrapper FunctionRef( + const string& name, + gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs); + static AttrValueWrapper FunctionRef(const string& name) { + return FunctionRef(name, {}); + } + + // Node is used to consturct FunctionDef.Node using initialization + // lists. E.g., + // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y + struct Node { + std::vector<string> ret; + string op; + std::vector<string> arg; + std::vector<std::pair<string, AttrValueWrapper>> attr; + std::vector<string> dep; + + FunctionDef::Node ToProto() const; + }; + + static FunctionDef Define(const string& function_name, + gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def); + + // Defines an anonymous function. I.e., its name is not relevant. + static FunctionDef Define(gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def); + + // Helpers to construct a constant scalar. + template <typename T> + static Node Const(const string& name, const T& val) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum<T>::value; + n.attr.push_back({"dtype", dtype}); + Tensor t(dtype, TensorShape({})); + t.scalar<T>()() = val; + n.attr.push_back({"value", t}); + return n; + } + + template <typename T> + static Node Const(const string& name, gtl::ArraySlice<T> vals) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum<T>::value; + n.attr.push_back({"dtype", dtype}); + int64 num = vals.size(); + Tensor t(dtype, TensorShape({num})); + for (int i = 0; i < vals.size(); ++i) { + t.flat<T>()(i) = vals[i]; + } + n.attr.push_back({"value", t}); + return n; + } +}; + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( + const string& val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { + InitFromString(val); +} + +// Instantiate a function. +// +// "fdef" encodes a TF function with some attrs in fdef.signature.attr +// containing placeholders. InstantiateFunction binds these +// placeholders and produces an instantiated function encoded in +// "result.gdef". The value to substitute a placeholder is given by +// "attr_values", which is a map from a placeholder name to an attr +// value. +// +// InstatiateFunction calls "get_function" to find signatures of other +// functions and primitive ops. + +// Placeholders in "fdef" is substitued based on "attr_values" here. +typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap; +typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>> + InstantiateAttrValueSlice; + +// GetFunctionSignature(func name, opdef) returns OK if the func name is found +// and opdef is filled with a pointer to the corresponding signature +// (a OpDef proto). Otherwise, returns an error. +typedef std::function<Status(const string&, const OpDef**)> + GetFunctionSignature; + +struct InstantiationResult { + DataTypeVector arg_types; + DataTypeVector ret_types; + GraphDef gdef; +}; +Status InstantiateFunction(const FunctionDef& fdef, + const InstantiateAttrValueMap& attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); +Status InstantiateFunction(const FunctionDef& fdef, + InstantiateAttrValueSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); + +// Returns a debug string for a function definition. +// +// The returned text is multiple-line. It is intended to be +// human-readable rather than being friendly to parsers. It is _NOT_ +// intended to be the canonical string representation of "func_def". +// Particularly, it may not include all information presented in +// "func_def" (e.g., comments, description of the function arguments, +// etc.) +string DebugString(const FunctionDef& func_def); +string DebugString(const GraphDef& instantiated_func_def); + +// Returns a debug string for a top level graph (the main program and +// its supporting functions defined in its library). +string DebugStringWhole(const GraphDef& gdef); + +// Returns a canonicalized string for the instantiation of the +// function of the given "name" and attributes "attrs". +// +// The returned string is guaranteed to be stable within one address +// space. But it may be change as the implementation +// evolves. Therefore, it should not be persisted or compared across +// address spaces. +string Canonicalize(const string& funcname, + const InstantiateAttrValueMap& attrs); +string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs); + +// Represents a function call frame. I.e., the data structure used to +// pass arguments to a function and retrieve its results. +// +// Runtime must arrange accesses to one FunctionCallFrame s.t. +// 1. SetArgs() happens before any GetArg(); +// 2. GetRetvals happens after all SetRetval(); +class FunctionCallFrame { + public: + FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types); + ~FunctionCallFrame(); + + // Caller methods. + Status SetArgs(gtl::ArraySlice<Tensor> args); + Status GetRetvals(std::vector<Tensor>* rets) const; + + // Callee methods. + Status GetArg(int index, Tensor* val) const; + Status SetRetval(int index, const Tensor& val); + + private: + DataTypeVector arg_types_; + DataTypeVector ret_types_; + gtl::InlinedVector<Tensor, 4> args_; + struct Retval { + bool has_val = false; + Tensor val; + }; + gtl::InlinedVector<Retval, 4> rets_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame); +}; + +// Helper to maintain a map between function names in a given +// FunctionDefLibrary and function definitions. +class FunctionLibraryDefinition : public OpRegistryInterface { + public: + explicit FunctionLibraryDefinition(const FunctionDefLibrary& lib_def); + ~FunctionLibraryDefinition() override; + + // Returns nullptr if "func" is not defined in "lib_def". Otherwise, + // returns its definition proto. + const FunctionDef* Find(const string& func) const; + + // OpRegistryInterface method. Useful for constructing a Graph. + // + // If "op" is defined in the library, returns its signature. + // Otherwise, assume "op" is a primitive op and returns its op + // signature. + const OpDef* LookUp(const string& op, Status* status) const override; + + private: + std::unordered_map<string, FunctionDef> function_defs_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryDefinition); +}; + +// Forward declare. Defined in common_runtime/function.h +struct FunctionBody; + +class FunctionLibraryRuntime { + public: + virtual ~FunctionLibraryRuntime() {} + + // Instantiate a function with the given "attrs". + // + // Returns OK and fills in "handle" if the instantiation succeeds. + // Otherwise returns an error and "handle" is undefined. + typedef uint64 Handle; + virtual Status Instantiate(const string& function_name, + const InstantiateAttrValueMap& attrs, + Handle* handle) = 0; + Status Instantiate(const string& function_name, + InstantiateAttrValueSlice attrs, Handle* handle); + + // Returns the function body for the instantiated function given its + // handle 'h'. Returns nullptr if "h" is not found. + // + // *this keeps the ownership of the returned object, which remains alive + // as long as *this. + virtual const FunctionBody* GetFunctionBody(Handle h) = 0; + + // Asynchronously invokes the instantiated function identified by + // "handle". + // + // If function execution succeeds, "done" is called with OK and + // "*rets" is filled with the function's return values. Otheriwse, + // "done" is called with an error status. + // + // Does not take ownership of "rets". + struct Options { + CancellationManager* cancellation_manager = nullptr; + }; + typedef std::function<void(const Status&)> DoneCallback; + virtual void Run(const Options& opts, Handle handle, + gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, + DoneCallback done) = 0; + + // Creates a "kernel" for the given node def "ndef". + // + // If succeeds, returns OK and the caller takes the ownership of the + // returned "*kernel". Otherwise, returns an error. + virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0; + + // Return true iff 'function_name' is the name of a defined function. + virtual bool IsDefined(const string& function_name) = 0; +}; + +// To register a gradient function for a builtin op, one should use +// REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>); +// +// Typically, the c++ grad factory is a plan function that can be +// converted into ::tensorflow::gradient::Creator, which is +// std::function<Status(const AttrSlice&, FunctionDef*)>. +// +// A ::tensorflow::gradient::Creator should populate in FunctionDef* with a +// definition of a brain function which computate the gradient for the +// <op_name> when the <op_name> is instantiated with the given attrs. +// +// E.g., +// +// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { +// bool transpose_a; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a)); +// bool transpose_b; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b)); +// DataType dtype; +// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype)); +// if (!transpose_a && !transpose_b) { +// *g = FunctionDefHelper::Define( +// "MatMulGrad", +// {"x:T ", "y:T", "dz:T"}, // Inputs to this function +// {"dx:T", "dy:T"}, // Outputs from this function +// {"T: {float, double}"}, // Attributes needed by this function +// { +// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}}, +// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}}, +// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}}, +// {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}}, +// }); +// } else { +// ... ... +// } +// return Status::OK(); +// } +// +// NOTE: $T is substituted with the type variable "T" when the +// gradient function MatMul is instantiated. +// +// TODO(zhifengc): Better documentation somewhere. + +// Macros to define a gradient function factory for a primitive +// operation. +#define REGISTER_OP_GRADIENT(name, fn) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn) + +#define REGISTER_OP_NO_GRADIENT(name) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr) + +#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \ + REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) + +#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \ + static bool unused_grad_##ctr = ::tensorflow::gradient::RegisterOp(name, fn) + +namespace gradient { +// Register a gradient creator for the "op". +typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator; +bool RegisterOp(const string& op, Creator func); + +// Returns OK the gradient creator for the "op" is found (may be +// nullptr if REGISTER_OP_NO_GRADIENT is used. +Status GetOpGradientCreator(const string& op, Creator* creator); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_ |