aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-10-01 13:46:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 13:57:32 -0700
commitec900f15e352e4b203b1f0678f7d2ff042df57d5 (patch)
tree2d7a7ffc0f17cb28801c7a9937b6f4e3777592c7 /tensorflow/python/framework
parent3039a4694e22674b502257ae34b0a5b614a631f3 (diff)
Minor speed improvements to defun.
- EncodeArg in C instead of python. - Also caches parsed device specs, and device spec hashes - Adds a common way to register python types in C. - Fastpath canonicalize function inputs when no kwargs are passed - Set the func name attr directly instead of creating an op to wrap it. - Rewrite IsAttrsHelper without caching Before: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 101.803263028 extras { key: "examples_per_sec" value { double_value: 9822.86785562 } } } After: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 47.2899993261 extras { key: "examples_per_sec" value { double_value: 21146.1199884 } } } PiperOrigin-RevId: 215272962
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/device.py12
-rw-r--r--tensorflow/python/framework/sparse_tensor.py2
2 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index 06c653097a..7f6e0a75a5 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -87,6 +87,7 @@ class DeviceSpec(object):
else:
self.device_type = device_type
self.device_index = device_index
+ self._hash = hash(self.to_string())
def _clear(self):
self._job = None
@@ -234,7 +235,7 @@ class DeviceSpec(object):
return self.to_string() == other.to_string()
def __hash__(self):
- return hash(self.to_string())
+ return self._hash
def check_valid(spec):
@@ -266,6 +267,7 @@ def canonical_name(device):
# possible to compare the device function stacks belonging to different
# graphs in a meaningful way.
_cached_device_functions = {}
+_cached_device_specs = {}
_cache_lock = threading.Lock()
@@ -297,7 +299,13 @@ def merge_device(spec):
"""
with _cache_lock:
if not isinstance(spec, DeviceSpec):
- spec = DeviceSpec.from_string(spec or "")
+ cached_device_spec = _cached_device_specs.get(spec, None)
+ if cached_device_spec is None:
+ device_spec = DeviceSpec.from_string(spec or "")
+ _cached_device_specs[spec] = device_spec
+ spec = device_spec
+ else:
+ spec = cached_device_spec
cached_function = _cached_device_functions.get(spec, None)
if cached_function is not None:
return cached_function
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index 41ef2e11d1..440e3a0968 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -245,7 +245,7 @@ class SparseTensor(_TensorLike):
SparseTensorValue = collections.namedtuple(
"SparseTensorValue", ["indices", "values", "dense_shape"])
tf_export("SparseTensorValue")(SparseTensorValue)
-pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue)
+pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
@tf_export("convert_to_tensor_or_sparse_tensor")