diff options
Diffstat (limited to 'tensorflow/contrib/timeseries/python/timeseries/math_utils.py')
-rw-r--r-- | tensorflow/contrib/timeseries/python/timeseries/math_utils.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 03da2b82e5..9c585fe6a7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -543,20 +543,25 @@ class TupleOfTensorsLookup(lookup.LookupInterface): overhead. """ - def __init__( - self, key_dtype, default_values, empty_key, name, checkpoint=True): + def __init__(self, + key_dtype, + default_values, + empty_key, + deleted_key, + name, + checkpoint=True): default_values_flat = nest.flatten(default_values) - self._hash_tables = nest.pack_sequence_as( - default_values, - [TensorValuedMutableDenseHashTable( + self._hash_tables = nest.pack_sequence_as(default_values, [ + TensorValuedMutableDenseHashTable( key_dtype=key_dtype, value_dtype=default_value.dtype.base_dtype, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name=name + "_{}".format(table_number), checkpoint=checkpoint) - for table_number, default_value - in enumerate(default_values_flat)]) + for table_number, default_value in enumerate(default_values_flat) + ]) self._name = name def lookup(self, keys): |