// Source path: tensorflow/core/framework/function.proto
// (blob 4003943a8526e46d867368b6f7dde53bfc506b0d)
syntax = "proto3";

package tensorflow;

// OpDef and AttrValue are used by the messages in this file but are not
// defined here; they are expected to come from these imports.
import "tensorflow/core/framework/attr_value.proto";
import "tensorflow/core/framework/op_def.proto";

// C++ code generation: enable arena allocation for generated messages.
option cc_enable_arenas = true;

// Java code generation: one top-level class per message, placed in
// org.tensorflow.framework, with FunctionProtos as the outer wrapper class.
option java_multiple_files = true;
option java_outer_classname = "FunctionProtos";
option java_package = "org.tensorflow.framework";

// A library is a set of named functions, together with the gradient
// registrations that relate some of those functions to each other.
message FunctionDefLibrary {
  // The functions defined by this library.
  repeated FunctionDef function = 1;

  // Gradient registrations. Each entry names a function and the function
  // that serves as its gradient (see GradientDef).
  repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have a binding for every attr defined in the signature.
//
// The body is a list of nodes. Nodes refer to each other's outputs, and
// to the signature's inputs, by name (see Node.arg and Node.dep).
//
// TODO(zhifengc):
//   * device spec, etc.
message FunctionDef {
  // The definition of the function's name, arguments, return values,
  // attrs etc.
  OpDef signature = 1;

  // The body of the function.
  repeated Node node = 2;  // function.node.ret[*] are unique.

  // A node is a multi-value assignment:
  //   (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
  // where "func" is the op/function named by the 'op' field.
  //
  // By convention, "func" is resolved by consulting with a user-defined
  // library first. If not resolved, "func" is assumed to be a builtin op.
  message Node {
    // Names of the values this node produces. A node may produce
    // multiple outputs; they are named ret[0], ret[1], ..., etc.
    //
    // REQUIRES: function.node.ret[*] are unique across all nodes.
    // REQUIRES: ret.size == func/op def's number of output args.
    repeated string ret = 1;

    // The op/function name.
    string op = 2;

    // Arguments passed to this func/op, given by name.
    //
    // arg[i] must be either one of
    // function.signature.input_args[*].name or one of
    // function.node[*].ret[*].
    //
    // NOTE(review): "input_args" presumably refers to the input argument
    // list declared by OpDef; confirm the exact field name in op_def.proto.
    //
    // REQUIRES: arg.size == func/op def's number of input args.
    repeated string arg = 3;

    // Control dependencies, given by name.
    //
    // dep[i] must be one of function.node[*].ret[*] or one of
    // function.signature.input_args[*].name.
    repeated string dep = 4;

    // Attrs.
    //
    // 'attr' maps names defined by 'func's attr defs to attr values.
    // attr values may have placeholders which are substituted
    // recursively by concrete values when this node is instantiated.
    // These placeholders must name an attr listed in the FunctionDef's
    // signature.
    map<string, AttrValue> attr = 5;
  }
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must satisfy the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g' is a function
// taking N + M inputs and producing N outputs.
//
// I.e. if we have
//    (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
//    (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
//                                      dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-valued function of (x1, x2, ..., x_N) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i; likewise, dL/dy_j is the partial derivative of L with respect
// to y_j.
message GradientDef {
  string function_name = 1;  // Name of the function 'f'.
  string gradient_func = 2;  // Name of the gradient function 'g' of 'f'.
}