aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-08-06 15:17:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 15:21:01 -0700
commit080e15b9287968d8ae6f0e7aa041593bd4b82c8d (patch)
treec9dc23b2da3cfe762e0d339ceabc388788a48f41 /tensorflow/contrib/eager
parentcb29a6b2217a140c248188015da424669fd08e54 (diff)
Add an optional input signature for functions generated with defun.
An input signature is a possibly nested collection of `TensorSpec` objects declaring the shapes and dtypes of a function's arguments. tfe.defun propagates these shapes and dtypes to the graph function it generates for the traced Python function. Since the shapes may be partially specified, this makes it possible to generate functions with partial shape information. When an input signature is specified, every argument to the `defun`-wrapped function *must* be a Tensor. Input signatures cannot be specified for functions with keyword arguments, but positional arguments that are called in a keyword-like fashion are acceptable. For example, the following code snippet is valid @tfe.defun(input_signature=(TensorSpec([], tf.float32), TensorSpec([], tf.int32)) def foo(b, a): ... foo(b=tf.constant(1.0), a=tf.constant(2)) foo(a=tf.constant(2), b=tf.constant(1.0)) while the next code snippet is invalid @tfe.defun(input_signature=(TensorSpec([], tf.float32), TensorSpec([], tf.int32)) def foo(b, **kwargs): ... # This will fail --- arbitrary kwargs are not allowed. foo(b=tf.constant(1.0), a=tf.constant(2)) This change also adds benchmarks that approximately measure the time taken to compute the cache key and verify the signature (and execute an empty function). Signatures introduce an overhead of ~100 us. The benchmarks for the non-signature path are the same as they were before this change. entry { name: "MicroBenchmarks.benchmark_defun_with_signature" iters: 30000 wall_time: 349.911403656 extras { key: "examples_per_sec" value { double_value: 2857.86627572 } } } entry { name: "MicroBenchmarks.benchmark_defun_with_signature_and_kwargs" iters: 30000 wall_time: 360.46500206 extras { key: "examples_per_sec" value { double_value: 2774.19442743 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature" iters: 30000 wall_time: 259.087236722 extras { key: "examples_per_sec" value { double_value: 3859.70383046 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature_and_with_kwargs" iters: 30000 wall_time: 272.486400604 extras { key: "examples_per_sec" value { double_value: 3669.90792121 } } } PiperOrigin-RevId: 207617442
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/tfe.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 2f0ab616e4..de11d00a1a 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -71,6 +71,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@run_test_in_graph_and_eager_modes
@@run_all_tests_in_graph_and_eager_modes
+@@TensorSpec
+
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
@@DEVICE_PLACEMENT_SILENT
@@ -114,6 +116,7 @@ from tensorflow.python.eager.execution_callbacks import inf_callback
from tensorflow.python.eager.execution_callbacks import inf_nan_callback
from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
+from tensorflow.python.framework.tensor_spec import TensorSpec
from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution
from tensorflow.python.framework.ops import eager_run as run