diff options
Diffstat (limited to 'tensorflow/contrib/linear_optimizer/python')
3 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 48ac429701..b5099a0bf6 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -152,7 +152,8 @@ class SdcaModel(object): default_value=[0.0, 0.0, 0.0, 0.0], # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe # empty_key (that will never collide with actual payloads). - empty_key=[0, 0]) + empty_key=[0, 0], + deleted_key=[1, 1]) summary.scalar('approximate_duality_gap', self.approximate_duality_gap()) summary.scalar('examples_seen', self._hashtable.size()) diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 5015fb0848..44a869f7c2 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -48,6 +48,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype, default_value, empty_key, + deleted_key, num_shards=1, checkpoint=True, name='ShardedMutableHashTable'): @@ -62,6 +63,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype=value_dtype, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, checkpoint=checkpoint, name='%s-%d-of-%d' % (name, i + 1, num_shards))) self._table_shards = table_shards diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 553b116a3b..2b56d0fa3a 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -33,6 +33,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = -1 empty_key = 0 + deleted_key = -1 keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = ShardedMutableDenseHashTable( @@ -40,6 +41,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -56,6 +58,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = [-0.1, 0.2] empty_key = [0, 1] + deleted_key = [1, 0] keys = constant_op.constant([[11, 12], [13, 14], [15, 16]], dtypes.int64) values = constant_op.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], @@ -65,6 +68,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.float32, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -81,6 +85,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): def testExportSharded(self): with self.cached_session(): empty_key = -2 + deleted_key = -3 default_val = -1 num_shards = 2 keys = constant_op.constant([10, 11, 12], dtypes.int64) @@ -90,6 +95,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) |