aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils.py19
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py8
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_management.py1
3 files changed, 18 insertions, 10 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index 03da2b82e5..9c585fe6a7 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -543,20 +543,25 @@ class TupleOfTensorsLookup(lookup.LookupInterface):
overhead.
"""
- def __init__(
- self, key_dtype, default_values, empty_key, name, checkpoint=True):
+ def __init__(self,
+ key_dtype,
+ default_values,
+ empty_key,
+ deleted_key,
+ name,
+ checkpoint=True):
default_values_flat = nest.flatten(default_values)
- self._hash_tables = nest.pack_sequence_as(
- default_values,
- [TensorValuedMutableDenseHashTable(
+ self._hash_tables = nest.pack_sequence_as(default_values, [
+ TensorValuedMutableDenseHashTable(
key_dtype=key_dtype,
value_dtype=default_value.dtype.base_dtype,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name=name + "_{}".format(table_number),
checkpoint=checkpoint)
- for table_number, default_value
- in enumerate(default_values_flat)])
+ for table_number, default_value in enumerate(default_values_flat)
+ ])
self._name = name
def lookup(self, keys):
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
index c0de42b15b..91265b9b2e 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
@@ -223,10 +223,12 @@ class TestLookupTable(test.TestCase):
hash_table = math_utils.TupleOfTensorsLookup(
key_dtype=dtypes.int64,
default_values=[[
- array_ops.ones([3, 2], dtype=dtypes.float32), array_ops.zeros(
- [5], dtype=dtypes.float64)
- ], array_ops.ones([7, 7], dtype=dtypes.int64)],
+ array_ops.ones([3, 2], dtype=dtypes.float32),
+ array_ops.zeros([5], dtype=dtypes.float64)
+ ],
+ array_ops.ones([7, 7], dtype=dtypes.int64)],
empty_key=-1,
+ deleted_key=-2,
name="test_lookup")
def stack_tensor(base_tensor):
return array_ops.stack([base_tensor + 1, base_tensor + 2])
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management.py b/tensorflow/contrib/timeseries/python/timeseries/state_management.py
index 13eecd4d82..138406c616 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_management.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_management.py
@@ -149,6 +149,7 @@ class ChainingStateManager(_OverridableStateManager):
key_dtype=dtypes.int64,
default_values=self._start_state,
empty_key=-1,
+ deleted_key=-2,
name="cached_states",
checkpoint=self._checkpoint_state)