aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linear_optimizer/python/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/linear_optimizer/python/ops')
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py3
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py2
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py6
3 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 48ac429701..b5099a0bf6 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -152,7 +152,8 @@ class SdcaModel(object):
default_value=[0.0, 0.0, 0.0, 0.0],
# SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe
# empty_key (that will never collide with actual payloads).
- empty_key=[0, 0])
+ empty_key=[0, 0],
+ deleted_key=[1, 1])
summary.scalar('approximate_duality_gap', self.approximate_duality_gap())
summary.scalar('examples_seen', self._hashtable.size())
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
index 5015fb0848..44a869f7c2 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
@@ -48,6 +48,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface):
value_dtype,
default_value,
empty_key,
+ deleted_key,
num_shards=1,
checkpoint=True,
name='ShardedMutableHashTable'):
@@ -62,6 +63,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface):
value_dtype=value_dtype,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
checkpoint=checkpoint,
name='%s-%d-of-%d' % (name, i + 1, num_shards)))
self._table_shards = table_shards
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index 553b116a3b..2b56d0fa3a 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -33,6 +33,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
with self.cached_session():
default_val = -1
empty_key = 0
+ deleted_key = -1
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = ShardedMutableDenseHashTable(
@@ -40,6 +41,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
dtypes.int64,
default_val,
empty_key,
+ deleted_key,
num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
@@ -56,6 +58,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
with self.cached_session():
default_val = [-0.1, 0.2]
empty_key = [0, 1]
+ deleted_key = [1, 0]
keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
dtypes.int64)
values = constant_op.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]],
@@ -65,6 +68,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
dtypes.float32,
default_val,
empty_key,
+ deleted_key,
num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
@@ -81,6 +85,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testExportSharded(self):
with self.cached_session():
empty_key = -2
+ deleted_key = -3
default_val = -1
num_shards = 2
keys = constant_op.constant([10, 11, 12], dtypes.int64)
@@ -90,6 +95,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
dtypes.int64,
default_val,
empty_key,
+ deleted_key,
num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())