diff options
author | Akshay Agrawal <akshayka@google.com> | 2018-08-06 15:17:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-06 15:21:01 -0700 |
commit | 080e15b9287968d8ae6f0e7aa041593bd4b82c8d (patch) | |
tree | c9dc23b2da3cfe762e0d339ceabc388788a48f41 /tensorflow/contrib/eager | |
parent | cb29a6b2217a140c248188015da424669fd08e54 (diff) |
Add an optional input signature for functions generated with defun.
An input signature is a possibly nested collection of `TensorSpec` objects declaring the shapes and dtypes of a function's arguments. tfe.defun propagates these shapes and dtypes to the graph function it generates for the traced Python function.
Since the shapes may be partially specified, this makes it possible to generate functions with partial shape information.
When an input signature is specified, every argument to the `defun`-wrapped function *must* be a Tensor. Input signatures cannot be specified for functions with keyword arguments, but positional arguments that are called in a keyword-like fashion are acceptable. For example, the following code snippet is valid
@tfe.defun(input_signature=(TensorSpec([], tf.float32), TensorSpec([], tf.int32))
def foo(b, a):
...
foo(b=tf.constant(1.0), a=tf.constant(2))
foo(a=tf.constant(2), b=tf.constant(1.0))
while the next code snippet is invalid
@tfe.defun(input_signature=(TensorSpec([], tf.float32), TensorSpec([], tf.int32))
def foo(b, **kwargs):
...
# This will fail --- arbitrary kwargs are not allowed.
foo(b=tf.constant(1.0), a=tf.constant(2))
This change also adds benchmarks that approximately measure the time taken to compute the cache key and verify the signature (and execute an empty function). Signatures introduce an overhead of ~100 us. The benchmarks for the non-signature path are the same as they were before this change.
entry {
name: "MicroBenchmarks.benchmark_defun_with_signature"
iters: 30000
wall_time: 349.911403656
extras {
key: "examples_per_sec"
value {
double_value: 2857.86627572
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_with_signature_and_kwargs"
iters: 30000
wall_time: 360.46500206
extras {
key: "examples_per_sec"
value {
double_value: 2774.19442743
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_without_signature"
iters: 30000
wall_time: 259.087236722
extras {
key: "examples_per_sec"
value {
double_value: 3859.70383046
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_without_signature_and_with_kwargs"
iters: 30000
wall_time: 272.486400604
extras {
key: "examples_per_sec"
value {
double_value: 3669.90792121
}
}
}
PiperOrigin-RevId: 207617442
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/tfe.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 2f0ab616e4..de11d00a1a 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -71,6 +71,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@run_test_in_graph_and_eager_modes @@run_all_tests_in_graph_and_eager_modes +@@TensorSpec + @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT @@ -114,6 +116,7 @@ from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution from tensorflow.python.framework.ops import eager_run as run |