aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-27 16:46:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-27 18:03:44 -0700
commit87a61b18f83d4f5ec8a796c2a4d665d3010eac91 (patch)
tree51c9285415bc0e5bda9dfa14b250084ffdd7bd99
parent33dd5d021939ac84c020a905b65c1e134b840944 (diff)
Change SDCA to use DenseMutableHashtable.
Change SdcaFprint to return a two int64 values instead of one string to save memory given that DenseMutableHashtable supports non-scalar keys. Corresponding changes to the private sharded hashtable wrapper to support non-scalar keys. Change: 137464256
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py67
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py97
-rw-r--r--tensorflow/core/kernels/sdca_ops.cc34
-rw-r--r--tensorflow/core/ops/compat/ops_history.v0.pbtxt11
-rw-r--r--tensorflow/core/ops/sdca_ops.cc14
-rw-r--r--tensorflow/python/ops/hidden_ops.txt5
6 files changed, 139 insertions, 89 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index b749cd1866..9cd3c6d690 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -22,10 +22,11 @@ from threading import Thread
import tensorflow as tf
-from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _ShardedMutableHashTable
+from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _ShardedMutableDenseHashTable
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SparseFeatureColumn
from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.platform import googletest
_MAX_ITERATIONS = 100
@@ -1013,46 +1014,64 @@ class SdcaFprintTest(SdcaModelTest):
def testFprint(self):
with self._single_threaded_test_session():
in_data = tf.constant(['abc', 'very looooooong string', 'def'])
- out_data = tf.sdca.sdca_fprint(in_data)
- self.assertAllEqual(
- [b'\x04l\x12\xd2\xaf\xb2\x809E\x9e\x02\x13',
- b'\x9f\x0f\x91P\x9aG.Ql\xf2Y\xf9',
- b'"0\xe00"\x18_\x08\x12?\xa0\x17'], out_data.eval())
+ out_data = gen_sdca_ops._sdca_fprint(in_data)
+ self.assertAllEqual([[4143508125394299908, -6879828354153669051],
+ [5849691694103072671, -4874542629849009556],
+ [603227410218889250, 8762207001949257490]],
+ out_data.eval())
-class ShardedMutableHashTableTest(SdcaModelTest):
+class ShardedMutableDenseHashTableTest(SdcaModelTest):
"""Tests for the _ShardedMutableHashTable class."""
def testShardedMutableHashTable(self):
for num_shards in [1, 3, 10]:
with self._single_threaded_test_session():
default_val = -1
- keys = tf.constant(['brain', 'salad', 'surgery'])
+ empty_key = 0
+ keys = tf.constant([11, 12, 13], tf.int64)
values = tf.constant([0, 1, 2], tf.int64)
- table = _ShardedMutableHashTable(tf.string,
- tf.int64,
- default_val,
- num_shards=num_shards)
+ table = _ShardedMutableDenseHashTable(
+ tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
- input_string = tf.constant(['brain', 'salad', 'tank'])
+ input_string = tf.constant([11, 12, 14], tf.int64)
output = table.lookup(input_string)
self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([0, 1, -1], output.eval())
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
+ def testShardedMutableHashTableVectors(self):
+ for num_shards in [1, 3, 10]:
+ with self._single_threaded_test_session():
+ default_val = [-0.1, 0.2]
+ empty_key = [0, 1]
+ keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64)
+ values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32)
+ table = _ShardedMutableDenseHashTable(
+ tf.int64, tf.float32, default_val, empty_key, num_shards=num_shards)
+ self.assertAllEqual(0, table.size().eval())
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64)
+ output = table.lookup(input_string)
+ self.assertAllEqual([3, 2], output.get_shape())
+ self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]],
+ output.eval())
def testExportSharded(self):
with self._single_threaded_test_session():
+ empty_key = -2
default_val = -1
num_shards = 2
- keys = tf.constant(['a1', 'b1', 'c2'])
- values = tf.constant([0, 1, 2], tf.int64)
- table = _ShardedMutableHashTable(
- tf.string, tf.int64, default_val, num_shards=num_shards)
+ keys = tf.constant([10, 11, 12], tf.int64)
+ values = tf.constant([2, 3, 4], tf.int64)
+ table = _ShardedMutableDenseHashTable(
+ tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
@@ -1062,10 +1081,12 @@ class ShardedMutableHashTableTest(SdcaModelTest):
self.assertAllEqual(num_shards, len(keys_list))
self.assertAllEqual(num_shards, len(values_list))
- self.assertAllEqual(set([b'b1', b'c2']), set(keys_list[0].eval()))
- self.assertAllEqual([b'a1'], keys_list[1].eval())
- self.assertAllEqual(set([1, 2]), set(values_list[0].eval()))
- self.assertAllEqual([0], values_list[1].eval())
+ # Exported keys include empty key buckets set to the empty_key
+ self.assertAllEqual(set([-2, 10, 12]), set(keys_list[0].eval().flatten()))
+ self.assertAllEqual(set([-2, 11]), set(keys_list[1].eval().flatten()))
+ # Exported values include empty value buckets set to 0
+ self.assertAllEqual(set([0, 2, 4]), set(values_list[0].eval().flatten()))
+ self.assertAllEqual(set([0, 3]), set(values_list[1].eval().flatten()))
if __name__ == '__main__':
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index ffb1477ce8..13310a3ec8 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -25,36 +25,31 @@ from six.moves import range
from tensorflow.contrib.lookup import lookup_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables as var_ops
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
-from tensorflow.python.ops.sdca_ops import sdca_fprint
-from tensorflow.python.ops.sdca_ops import sdca_optimizer
-from tensorflow.python.ops.sdca_ops import sdca_shrink_l1
__all__ = ['SdcaModel']
-class _ShardedMutableHashTable(lookup_ops.LookupInterface):
- """A sharded version of MutableHashTable.
+class _ShardedMutableDenseHashTable(lookup_ops.LookupInterface):
+ """A sharded version of MutableDenseHashTable.
It is designed to be interface compatible with LookupInterface and
- MutableHashTable, with the exception of the export method, which is replaced
- by a custom values_reduce_sum method for SDCA needs. The class is not part of
- lookup ops because it is unclear how to make the device placement general
- enough to be useful.
-
- The _ShardedHashTable keeps `num_shards` MutableHashTables internally. If keys
- are integers, the shard is computed via the modulo operation. If keys are
- strings, the shard is computed via string_to_hash_bucket_fast.
+ MutableDenseHashTable, with the exception of the export method, which is
+ replaced by an export_sharded method.
+
+ The _ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable
+ internally. The shard is computed via the modulo operation on the key.
"""
# TODO(andreasst): consider moving this to lookup_ops
@@ -63,18 +58,21 @@ class _ShardedMutableHashTable(lookup_ops.LookupInterface):
key_dtype,
value_dtype,
default_value,
+ empty_key,
num_shards=1,
name='ShardedMutableHashTable'):
with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
- super(_ShardedMutableHashTable, self).__init__(key_dtype, value_dtype,
- scope)
+ super(_ShardedMutableDenseHashTable, self).__init__(key_dtype,
+ value_dtype, scope)
table_shards = []
for i in range(num_shards):
- table_shards.append(lookup_ops.MutableHashTable(
- key_dtype=key_dtype,
- value_dtype=value_dtype,
- default_value=default_value,
- name='%s-%d-of-%d' % (name, i + 1, num_shards)))
+ table_shards.append(
+ lookup_ops.MutableDenseHashTable(
+ key_dtype=key_dtype,
+ value_dtype=value_dtype,
+ default_value=default_value,
+ empty_key=empty_key,
+ name='%s-%d-of-%d' % (name, i + 1, num_shards)))
self._table_shards = table_shards
# TODO(andreasst): add a value_shape() method to LookupInterface
# pylint: disable=protected-access
@@ -97,16 +95,28 @@ class _ShardedMutableHashTable(lookup_ops.LookupInterface):
return math_ops.add_n(sizes)
def _shard_indices(self, keys):
- if self._key_dtype == dtypes.string:
- indices = string_ops.string_to_hash_bucket_fast(keys, self._num_shards)
- else:
- indices = math_ops.mod(keys, self._num_shards)
+ key_shape = keys.get_shape()
+ if key_shape.ndims > 1:
+ # If keys are a matrix (i.e. a single key is a vector), we use the first
+ # element of each key vector to determine the shard.
+ keys = array_ops.slice(keys, [0, 0], [key_shape[0].value, 1])
+ keys = array_ops.reshape(keys, [-1])
+ indices = math_ops.mod(math_ops.abs(keys), self._num_shards)
return math_ops.cast(indices, dtypes.int32)
+ def _check_keys(self, keys):
+ if not keys.get_shape().is_fully_defined():
+ raise ValueError('Key shape must be fully defined, got %s.' %
+ keys.get_shape())
+ if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2:
+ raise ValueError('Expected a vector or matrix for keys, got %s.' %
+ keys.get_shape())
+
def lookup(self, keys, name=None):
if keys.dtype != self._key_dtype:
raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
(self._key_dtype, keys.dtype))
+ self._check_keys(keys)
num_shards = self._num_shards
if num_shards == 1:
return self._table_shards[0].lookup(keys, name=name)
@@ -120,15 +130,18 @@ class _ShardedMutableHashTable(lookup_ops.LookupInterface):
for i in range(num_shards)
]
- original_indices = math_ops.range(array_ops.size(keys))
+ num_keys = keys.get_shape().dims[0]
+ original_indices = math_ops.range(num_keys)
partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
shard_indices,
num_shards)
result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
- result.set_shape(keys.get_shape().concatenate(self._value_shape))
+ result.set_shape(
+ tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
return result
def insert(self, keys, values, name=None):
+ self._check_keys(keys)
num_shards = self._num_shards
if num_shards == 1:
return self._table_shards[0].insert(keys, values, name=name)
@@ -359,11 +372,14 @@ class SdcaModel(object):
self._variables = variables
self._options = options
self._create_slots()
- self._hashtable = _ShardedMutableHashTable(
- key_dtype=dtypes.string,
+ self._hashtable = _ShardedMutableDenseHashTable(
+ key_dtype=dtypes.int64,
value_dtype=dtypes.float32,
num_shards=self._num_table_shards(),
- default_value=[0.0, 0.0, 0.0, 0.0])
+ default_value=[0.0, 0.0, 0.0, 0.0],
+ # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe
+ # empty_key (that will never collide with actual payloads).
+ empty_key=[0, 0])
logging_ops.scalar_summary('approximate_duality_gap',
self.approximate_duality_gap())
@@ -519,8 +535,10 @@ class SdcaModel(object):
if sf.feature_values is not None:
sparse_features_values.append(sf.feature_values)
- example_ids_hashed = sdca_fprint(
+ # pylint: disable=protected-access
+ example_ids_hashed = gen_sdca_ops._sdca_fprint(
convert_to_tensor(self._examples['example_ids']))
+ # pylint: enable=protected-access
example_state_data = self._hashtable.lookup(example_ids_hashed)
# Solver returns example_state_update, new delta sparse_feature_weights
# and delta dense_feature_weights.
@@ -538,7 +556,8 @@ class SdcaModel(object):
dtypes.int64))
sparse_weights.append(array_ops.gather(w, sparse_indices[-1]))
- esu, sfw, dfw = sdca_optimizer(
+ # pylint: disable=protected-access
+ esu, sfw, dfw = gen_sdca_ops._sdca_optimizer(
sparse_example_indices,
sparse_feature_indices,
sparse_features_values,
@@ -555,6 +574,7 @@ class SdcaModel(object):
l2=self._symmetric_l2_regularization(),
num_loss_partitions=self._num_loss_partitions(),
num_inner_iterations=1)
+ # pylint: enable=protected-access
with ops.control_dependencies([esu]):
update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
@@ -597,8 +617,9 @@ class SdcaModel(object):
for name in ['sparse_features_weights', 'dense_features_weights']:
for var in self._variables[name]:
with ops.device(var.device):
+ # pylint: disable=protected-access
update_ops.append(
- sdca_shrink_l1(
+ gen_sdca_ops._sdca_shrink_l1(
self._convert_n_to_tensor(
[var], as_ref=True),
l1=self._symmetric_l1_regularization(),
@@ -617,8 +638,14 @@ class SdcaModel(object):
shard_sums = []
for values in values_list:
with ops.device(values.device):
- shard_sums.append(
- math_ops.reduce_sum(math_ops.cast(values, dtypes.float64), 0))
+ # For large tables to_double() below allocates a large temporary
+ # tensor that is freed once the sum operation completes. To reduce
+ # peak memory usage in cases where we have multiple large tables on a
+ # single device, we serialize these operations.
+ # Note that we need double precision to get accurate results.
+ with ops.control_dependencies(shard_sums):
+ shard_sums.append(
+ math_ops.reduce_sum(math_ops.to_double(values), 0))
summed_values = math_ops.add_n(shard_sums)
primal_loss = summed_values[1]
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index d30e7486f5..578466e202 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -1079,7 +1079,7 @@ REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
// persistent storage, as its implementation may change in the future.
//
// The current probability of at least one collision for 1B example_ids is
-// approximately 10^-11 (ie 2^60 / 2^97).
+// approximately 10^-21 (ie 2^60 / 2^129).
class SdcaFprint : public OpKernel {
public:
explicit SdcaFprint(OpKernelConstruction* const context)
@@ -1087,27 +1087,27 @@ class SdcaFprint : public OpKernel {
void Compute(OpKernelContext* const context) override {
const Tensor& input = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("Input must be a vector, got shape ",
+ input.shape().DebugString()));
Tensor* out;
- OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &out));
+ const int64 num_elements = input.NumElements();
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({num_elements, 2}), &out));
const auto in_values = input.flat<string>();
- auto out_values = out->flat<string>();
-
- for (int64 i = 0; i < in_values.size(); ++i) {
- out_values(i) = Fp128ToBinaryString(Fingerprint128(in_values(i)));
+ auto out_values = out->matrix<int64>();
+
+ for (int64 i = 0; i < num_elements; ++i) {
+ const Fprint128 fprint = Fingerprint128(in_values(i));
+ // Never return 0 or 1 as the first value of the hash to allow these to
+ // safely be used as sentinel values (e.g. dense hash table empty key).
+ out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
+ ? fprint.low64
+ : fprint.low64 + ~static_cast<uint64>(1);
+ out_values(i, 1) = fprint.high64;
}
}
-
- private:
- // Returns a 12 character binary string of the fprint.
- // We use 12 of the 16 fingerprint bytes to save memory, in particular in
- // string implementations that use a short string optimization.
- static string Fp128ToBinaryString(const Fprint128& fprint) {
- string result;
- core::PutFixed64(&result, fprint.low64);
- core::PutFixed32(&result, fprint.high64);
- return result;
- }
};
REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index b5b056e41f..7b35139cce 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -25009,17 +25009,6 @@ op {
}
}
op {
- name: "SdcaFprint"
- input_arg {
- name: "input"
- type: DT_STRING
- }
- output_arg {
- name: "output"
- type: DT_STRING
- }
-}
-op {
name: "SdcaOptimizer"
input_arg {
name: "sparse_example_indices"
diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc
index 1d1fc51317..0aed75366e 100644
--- a/tensorflow/core/ops/sdca_ops.cc
+++ b/tensorflow/core/ops/sdca_ops.cc
@@ -137,13 +137,21 @@ weights: a list of vectors where each value is the weight associated with a
REGISTER_OP("SdcaFprint")
.Input("input: string")
- .Output("output: string")
- .SetShapeFn(shape_inference::UnchangedShape)
+ .Output("output: int64")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(c->Concatenate(handle, c->Vector(2), &output_shape));
+ c->set_output(0, output_shape);
+ return Status::OK();
+ })
.Doc(R"doc(
Computes fingerprints of the input strings.
input: vector of strings to compute fingerprints on.
-output: vector containing the computed fingerprints.
+output: a (N,2) shaped matrix where N is the number of elements in the input
+ vector. Each row contains the low and high parts of the fingerprint.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index fce784de00..362cc0518e 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -229,6 +229,11 @@ TruncatedNormal
PyFunc
PyFuncStateless
+# sdca_ops
+SdcaFprint
+SdcaOptimizer
+SdcaShrinkL1
+
# state_ops
Variable
TemporaryVariable