aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/docs_src/extend/adding_an_op.md
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/docs_src/extend/adding_an_op.md')
-rw-r--r--tensorflow/docs_src/extend/adding_an_op.md16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md
index f95298d377..4fc4c2faa2 100644
--- a/tensorflow/docs_src/extend/adding_an_op.md
+++ b/tensorflow/docs_src/extend/adding_an_op.md
@@ -1056,7 +1056,7 @@ cuda_op_kernel.cu.o -I $TF_INC -fPIC -lcudart
Note that if your CUDA libraries are not installed in `/usr/local/lib64`,
you'll need to specify the path explicitly in the second (g++) command above.
-For example, add `-L /usr/local/cuda-8.0/lib64/` if your CUDA is installed in
+For example, add `-L /usr/local/cuda-8.0/lib64/` if your CUDA is installed in
`/usr/local/cuda-8.0`.
### Implement the gradient in Python {#implement-gradient}
@@ -1160,7 +1160,9 @@ for ZeroOut:
```
`c->set_output(0, c->input(0));` declares that the first output's shape should
-be set to the first input's shape. There are a number of common shape functions
+be set to the first input's shape. If the output is selected by its index as in the above example, the second parameter of `set_output` should be a `ShapeHandle` object. You can create an empty `ShapeHandle` object by its default constructor. The `ShapeHandle` object for an input with index `idx` can be obtained by `c->input(idx)`.
+
+There are a number of common shape functions
that apply to many ops, such as `shape_inference::UnchangedShape` which can be
found in [common_shape_fns.h](https://www.tensorflow.org/code/tensorflow/core/framework/common_shape_fns.h) and used as follows:
@@ -1220,7 +1222,15 @@ particular dimension has a very specific value using `InferenceContext::Dim` and
`InferenceContext::WithValue`; you can specify that an output dimension is the
sum / product of two input dimensions using `InferenceContext::Add` and
`InferenceContext::Multiply`. See the `InferenceContext` class for
-all of the various shape manipulations you can specify.
+all of the various shape manipulations you can specify. The following example sets
+shape of the first output to (n, 3), where first input has shape (n, ...)
+
+```c++
+.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
+ return Status::OK();
+});
+```
If you have a complicated shape function, you should consider adding a test for
validating that various input shape combinations produce the expected output