diff options
Diffstat (limited to 'tensorflow/contrib/timeseries/python')
3 files changed, 18 insertions, 10 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 03da2b82e5..9c585fe6a7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -543,20 +543,25 @@ class TupleOfTensorsLookup(lookup.LookupInterface): overhead. """ - def __init__( - self, key_dtype, default_values, empty_key, name, checkpoint=True): + def __init__(self, + key_dtype, + default_values, + empty_key, + deleted_key, + name, + checkpoint=True): default_values_flat = nest.flatten(default_values) - self._hash_tables = nest.pack_sequence_as( - default_values, - [TensorValuedMutableDenseHashTable( + self._hash_tables = nest.pack_sequence_as(default_values, [ + TensorValuedMutableDenseHashTable( key_dtype=key_dtype, value_dtype=default_value.dtype.base_dtype, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name=name + "_{}".format(table_number), checkpoint=checkpoint) - for table_number, default_value - in enumerate(default_values_flat)]) + for table_number, default_value in enumerate(default_values_flat) + ]) self._name = name def lookup(self, keys): diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py index c0de42b15b..91265b9b2e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py @@ -223,10 +223,12 @@ class TestLookupTable(test.TestCase): hash_table = math_utils.TupleOfTensorsLookup( key_dtype=dtypes.int64, default_values=[[ - array_ops.ones([3, 2], dtype=dtypes.float32), array_ops.zeros( - [5], dtype=dtypes.float64) - ], array_ops.ones([7, 7], dtype=dtypes.int64)], + array_ops.ones([3, 2], dtype=dtypes.float32), + array_ops.zeros([5], dtype=dtypes.float64) + ], + array_ops.ones([7, 7], dtype=dtypes.int64)], empty_key=-1, + deleted_key=-2, name="test_lookup") def stack_tensor(base_tensor): return array_ops.stack([base_tensor + 1, base_tensor + 2]) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management.py b/tensorflow/contrib/timeseries/python/timeseries/state_management.py index 13eecd4d82..138406c616 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_management.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_management.py @@ -149,6 +149,7 @@ class ChainingStateManager(_OverridableStateManager): key_dtype=dtypes.int64, default_values=self._start_state, empty_key=-1, + deleted_key=-2, name="cached_states", checkpoint=self._checkpoint_state) |