diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-10-03 13:45:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 13:55:52 -0700 |
commit | c26b5e9685b05fafc509d8ebc88c8304be5974a4 (patch) | |
tree | ce32a1c8a88ebe1428b140901053f1745db62640 /tensorflow/python/eager | |
parent | ce9a5d143f89a37ab029a29c62433883323987e8 (diff) |
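Some tiny speed improvements for defun (roughly 7% more examples per second on the 2-by-2 matmul CPU microbenchmark; before/after numbers follow).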
Before:
entry {
name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU"
iters: 30000
wall_time: 48.4476327896
extras {
key: "examples_per_sec"
value {
double_value: 20640.8433688
}
}
}
After:
entry {
name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU"
iters: 30000
wall_time: 45.2344338099
extras {
key: "examples_per_sec"
value {
double_value: 22107.0524327
}
}
}
PiperOrigin-RevId: 215619902
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 36 |
1 files changed, 19 insertions, 17 deletions
```diff
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f261d92d64..dd9f5e233c 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1152,23 +1152,22 @@ class PolymorphicFunction(object):
       del args, kwargs
       cache_key = self._flat_input_signature
 
+    ctx = context.context()
     with ops.init_scope():
-      init_graph = ops.get_default_graph()
-
       # The graph, or whether we're executing eagerly, should be a part of the
       # cache key so we don't improperly capture tensors such as variables.
-      executing_eagerly = context.executing_eagerly()
-      execution_context = executing_eagerly or init_graph
-
-    default_graph = ops.get_default_graph()
-    # Putting the device in the cache key ensures that call-site device
-    # annotations are respected.
-    device_functions = _get_device_functions(context.context(), default_graph)
+      executing_eagerly = ctx.executing_eagerly()
+      execution_context = executing_eagerly or ops.get_default_graph()
 
-    # `ops.colocate_with` directives translate into `ops.device` directives when
-    # eager execution is enabled.
-    colocation_stack = (() if executing_eagerly else
-                        tuple(default_graph._colocation_stack.peek_objs()))  # pylint: disable=protected-access
+    if executing_eagerly:
+      device_functions = (pydev.merge_device(ctx.device_name),)
+      colocation_stack = ()
+    else:
+      default_graph = ops.get_default_graph()
+      # Putting the device in the cache key ensures that call-site device
+      # annotations are respected.
+      device_functions = tuple(default_graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
+      colocation_stack = tuple(default_graph._colocation_stack.peek_objs())  # pylint: disable=protected-access
 
     return (cache_key, execution_context, device_functions, colocation_stack)
 
@@ -1195,9 +1194,6 @@ class PolymorphicFunction(object):
     """
     args = self._args_to_prepend + args
     kwargs = dict(kwargs, **self._kwargs_to_include)
-    # Maps from index of arg to its corresponding value, according to `args`
-    # and `kwargs`; seeded with the default values for the named args that
-    # aren't in `args`.
     if not kwargs:
       if self._default_values:
         inputs = args + self._default_values[len(args) -
@@ -1205,6 +1201,9 @@ class PolymorphicFunction(object):
       else:
         inputs = args
     else:
+      # Maps from index of arg to its corresponding value, according to `args`
+      # and `kwargs`; seeded with the default values for the named args that
+      # aren't in `args`.
       arg_indices_to_values = {
           index: default for index, default in six.iteritems(
               self._arg_indices_to_default_values) if index >= len(args)
@@ -1227,9 +1226,12 @@ class PolymorphicFunction(object):
     flat_inputs = nest.flatten(inputs)
 
     # Check for NumPy arrays in arguments and convert them to Tensors.
+    # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
+    # finding a way to store them directly in the cache key (currently not
+    # possible since ndarrays are not hashable).
     need_packing = False
     for index, value in enumerate(flat_inputs):
-      if isinstance(value, np.ndarray):
+      if type(value) == np.ndarray:
         flat_inputs[index] = constant_op.constant(value)
         need_packing = True
     if need_packing:
```