# C++ gradients

Gradients are currently being ported from
[Python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/ops)
to C++ (in this directory).

Contributions are welcome and much appreciated; please follow the instructions
below.

1.  Create the op gradient function in the `foo_grad.cc` file corresponding to
    the `foo_grad.py` file where the op's gradient originated (e.g.
    `array_grad.py` op gradients should be written in `array_grad.cc`).

2.  Write the op gradient with the following naming scheme:

        Status OpNameGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
          ...
          return scope.status();
        }
        REGISTER_GRADIENT_OP("OpName", OpNameGrad);

3.  Op gradients are implemented using the [C++
    API](https://www.tensorflow.org/api_docs/cc/).

4.  Tests should be included in `foo_grad_test.cc`. Please see
    [`array_grad_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/gradients/array_grad_test.cc)
    for many examples. Writing a test is as simple as creating a placeholder
    for each of the op's inputs and calling `RunTest` (`RunTest` uses a [gradient
    checker](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/framework/gradient_checker.cc)
    to verify that the theoretical gradient matches the numeric gradient). For
    example:

        TEST_F(ArrayGradTest, IdentityGrad) {
          TensorShape shape({5, 2});
          auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
          auto y = Identity(scope_, x);
          RunTest(x, shape, y, shape);
        }
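To make the idea behind the gradient check concrete, here is a minimal, standalone sketch of comparing a theoretical (analytic) gradient against a numeric one. It is not TensorFlow's gradient checker and does not use the TensorFlow API; the names `MaxGradientError`, `SumOfSquares`, and `SumOfSquaresGrad` are hypothetical, and a central-difference estimate with a fixed step is assumed purely for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper (not part of TensorFlow): returns the largest absolute
// difference between an analytic gradient and a central-difference estimate
// of the gradient of the scalar function `f` at the point `x`.
double MaxGradientError(
    const std::function<double(const std::vector<double>&)>& f,
    const std::function<std::vector<double>(const std::vector<double>&)>& grad,
    std::vector<double> x, double delta = 1e-3) {
  const std::vector<double> analytic = grad(x);
  double max_error = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double original = x[i];
    x[i] = original + delta;
    const double f_plus = f(x);
    x[i] = original - delta;
    const double f_minus = f(x);
    x[i] = original;  // Restore before perturbing the next coordinate.
    const double numeric = (f_plus - f_minus) / (2.0 * delta);
    max_error = std::max(max_error, std::fabs(numeric - analytic[i]));
  }
  return max_error;
}

// Example function under test: f(x) = sum_i x_i^2.
double SumOfSquares(const std::vector<double>& x) {
  double sum = 0.0;
  for (double v : x) sum += v * v;
  return sum;
}

// Its hand-written gradient: df/dx_i = 2 * x_i.
std::vector<double> SumOfSquaresGrad(const std::vector<double>& x) {
  std::vector<double> g(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) g[i] = 2.0 * x[i];
  return g;
}
```

A test in this style would assert that `MaxGradientError(SumOfSquares, SumOfSquaresGrad, x)` stays below a small tolerance. A mistake in the hand-written gradient shows up as a large error.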

NOTE: Some ops require features that are not yet implemented in the C++ API.

*   Ops that require PartialTensorShape information cannot yet be implemented.

*   Ops that require SparseTensor or IndexedSlices (currently only in Python)
    cannot yet be implemented.

*   Possibly others.

For questions, please create an issue and assign it to suharshs.