# Specifying return data type for `py_func` calls

The `py_func` op requires specifying the
[data type](https://www.tensorflow.org/guide/tensors#data_types) of each value
it returns.

When wrapping a function with `py_func`, for instance using
`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two
options to specify the returned data type:

 * explicitly, with a `tf.DType` value
 * by matching the data type of an input argument (using `MatchDType`), which
   is then assumed to be a `Tensor`
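One way to picture the second option is as a lookup done before the op is built: each `MatchDType` marker is replaced by the dtype of the argument it points to, and explicit dtypes pass through unchanged. The sketch below is a hypothetical, TensorFlow-free illustration of that idea, not the library's implementation. `MatchDType` is a local stand-in for `autograph.utils.py_func.MatchDType`, `FakeTensor` stands in for a `Tensor`, dtypes are plain strings, and `resolve_return_dtypes` is an illustrative name.

```python
from collections import namedtuple

# Hypothetical stand-in for autograph.utils.py_func.MatchDType: a marker that
# means "use the dtype of positional argument number `arg_number`".
MatchDType = namedtuple('MatchDType', ('arg_number',))

# Hypothetical stand-in for a Tensor; only the `dtype` attribute matters here.
FakeTensor = namedtuple('FakeTensor', ('value', 'dtype'))


def resolve_return_dtypes(return_dtypes, args):
    """Replaces each MatchDType marker with the dtype of the referenced arg."""
    resolved = []
    for spec in return_dtypes:
        if isinstance(spec, MatchDType):
            arg = args[spec.arg_number]
            # The matched argument is assumed to be a tensor; this sketch
            # rejects anything without a `dtype` attribute.
            if not hasattr(arg, 'dtype'):
                raise ValueError(
                    'argument %d must be a tensor to use MatchDType'
                    % spec.arg_number)
            resolved.append(arg.dtype)
        else:
            # An explicit dtype is used as-is.
            resolved.append(spec)
    return resolved


# Usage: the first return value matches argument 0; the second is explicit.
print(resolve_return_dtypes([MatchDType(0), 'int64'],
                            (FakeTensor(1.0, 'float32'), 3)))
# prints ['float32', 'int64']
```

In this sketch, pointing `MatchDType` at a non-tensor argument (the plain `3` above) raises a `ValueError`.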

Examples:

Specify an explicit data type:

```
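  import tensorflow as tf
  from tensorflow.contrib import autograph

  def foo(a):
    return a + 1

  x = tf.constant(1.0)  # tensor argument passed to foo via `args`
  autograph.utils.wrap_py_func(foo, return_dtypes=[tf.float32], args=(x,))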
```

Match the data type of the first argument:

```
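  import tensorflow as tf
  from tensorflow.contrib import autograph

  def foo(a):
    return a + 1

  x = tf.constant(1.0)  # the return dtype is taken from this argument
  autograph.utils.wrap_py_func(
      foo, return_dtypes=[autograph.utils.py_func.MatchDType(0)], args=(x,))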
```