aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-10-03 13:45:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 13:55:52 -0700
commitc26b5e9685b05fafc509d8ebc88c8304be5974a4 (patch)
treece32a1c8a88ebe1428b140901053f1745db62640 /tensorflow/python/eager
parentce9a5d143f89a37ab029a29c62433883323987e8 (diff)
Some tiny speed improvements for defun.
Before: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 48.4476327896 extras { key: "examples_per_sec" value { double_value: 20640.8433688 } } } After: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 45.2344338099 extras { key: "examples_per_sec" value { double_value: 22107.0524327 } } } PiperOrigin-RevId: 215619902
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py36
1 file changed, 19 insertions, 17 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f261d92d64..dd9f5e233c 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1152,23 +1152,22 @@ class PolymorphicFunction(object):
del args, kwargs
cache_key = self._flat_input_signature
+ ctx = context.context()
with ops.init_scope():
- init_graph = ops.get_default_graph()
-
# The graph, or whether we're executing eagerly, should be a part of the
# cache key so we don't improperly capture tensors such as variables.
- executing_eagerly = context.executing_eagerly()
- execution_context = executing_eagerly or init_graph
-
- default_graph = ops.get_default_graph()
- # Putting the device in the cache key ensures that call-site device
- # annotations are respected.
- device_functions = _get_device_functions(context.context(), default_graph)
+ executing_eagerly = ctx.executing_eagerly()
+ execution_context = executing_eagerly or ops.get_default_graph()
- # `ops.colocate_with` directives translate into `ops.device` directives when
- # eager execution is enabled.
- colocation_stack = (() if executing_eagerly else
- tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+ if executing_eagerly:
+ device_functions = (pydev.merge_device(ctx.device_name),)
+ colocation_stack = ()
+ else:
+ default_graph = ops.get_default_graph()
+ # Putting the device in the cache key ensures that call-site device
+ # annotations are respected.
+ device_functions = tuple(default_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+ colocation_stack = tuple(default_graph._colocation_stack.peek_objs()) # pylint: disable=protected-access
return (cache_key, execution_context, device_functions, colocation_stack)
@@ -1195,9 +1194,6 @@ class PolymorphicFunction(object):
"""
args = self._args_to_prepend + args
kwargs = dict(kwargs, **self._kwargs_to_include)
- # Maps from index of arg to its corresponding value, according to `args`
- # and `kwargs`; seeded with the default values for the named args that
- # aren't in `args`.
if not kwargs:
if self._default_values:
inputs = args + self._default_values[len(args) -
@@ -1205,6 +1201,9 @@ class PolymorphicFunction(object):
else:
inputs = args
else:
+ # Maps from index of arg to its corresponding value, according to `args`
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
arg_indices_to_values = {
index: default for index, default in six.iteritems(
self._arg_indices_to_default_values) if index >= len(args)
@@ -1227,9 +1226,12 @@ class PolymorphicFunction(object):
flat_inputs = nest.flatten(inputs)
# Check for NumPy arrays in arguments and convert them to Tensors.
+ # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
+ # finding a way to store them directly in the cache key (currently not
+ # possible since ndarrays are not hashable).
need_packing = False
for index, value in enumerate(flat_inputs):
- if isinstance(value, np.ndarray):
+ if type(value) == np.ndarray:
flat_inputs[index] = constant_op.constant(value)
need_packing = True
if need_packing: