author:    2016-10-27 16:46:34 -0800
committer: 2016-10-27 18:03:44 -0700
commit:    87a61b18f83d4f5ec8a796c2a4d665d3010eac91
tree:      51c9285415bc0e5bda9dfa14b250084ffdd7bd99
parent:    33dd5d021939ac84c020a905b65c1e134b840944
Change SDCA to use MutableDenseHashTable.
Change SdcaFprint to return two int64 values instead of one string to save
memory, given that MutableDenseHashTable supports non-scalar keys.
Make corresponding changes to the private sharded hashtable wrapper to
support non-scalar keys.
Change: 137464256
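
As an illustration of the SdcaFprint change, here is a minimal NumPy sketch of the new packing rule: each input string yields a 128-bit fingerprint emitted as a row of two int64 values, and low64 values of 0 and 1 are remapped so that [0, 0] can never collide with a real key and is therefore safe as the dense table's empty_key sentinel. Note the assumptions: md5 is only a stand-in for TensorFlow's Fingerprint128, so the numbers will not match the test vectors in the diff below; only the packing and sentinel logic mirror the kernel. Keeping all 128 bits (instead of the previous 96) also tightens the collision bound quoted in the kernel comment: for roughly 1B example_ids the expected number of collisions is about 2^60 / 2^129, on the order of 10^-21.

import hashlib

import numpy as np


def _to_signed64(x):
  # Reinterpret an unsigned 64-bit value as a signed int64, as the op does.
  return x - (1 << 64) if x >= (1 << 63) else x


def sdca_fprint(strings):
  """Returns an (N, 2) int64 matrix of 128-bit fingerprints (sketch only)."""
  out = np.empty((len(strings), 2), dtype=np.int64)
  for i, s in enumerate(strings):
    d = hashlib.md5(s.encode('utf-8')).digest()  # stand-in for Fingerprint128
    low64 = int.from_bytes(d[:8], 'little')
    high64 = int.from_bytes(d[8:], 'little')
    if low64 < 2:
      # Mirror the kernel: low64 + ~uint64(1) is (low64 - 2) mod 2**64, so
      # 0 and 1 map to 2**64 - 2 and 2**64 - 1 and never reach the output.
      low64 = (low64 - 2) % (1 << 64)
    out[i, 0] = _to_signed64(low64)
    out[i, 1] = _to_signed64(high64)
  return out


print(sdca_fprint(['abc', 'very looooooong string', 'def']))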
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py  67
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py                97
-rw-r--r--  tensorflow/core/kernels/sdca_ops.cc                                       34
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v0.pbtxt                           11
-rw-r--r--  tensorflow/core/ops/sdca_ops.cc                                           14
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt                                       5
6 files changed, 139 insertions, 89 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index b749cd1866..9cd3c6d690 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -22,10 +22,11 @@ from threading import Thread
 
 import tensorflow as tf
 
-from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _ShardedMutableHashTable
+from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _ShardedMutableDenseHashTable
 from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
 from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SparseFeatureColumn
 from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import gen_sdca_ops
 from tensorflow.python.platform import googletest
 
 _MAX_ITERATIONS = 100
@@ -1013,46 +1014,64 @@ class SdcaFprintTest(SdcaModelTest):
 
   def testFprint(self):
     with self._single_threaded_test_session():
       in_data = tf.constant(['abc', 'very looooooong string', 'def'])
-      out_data = tf.sdca.sdca_fprint(in_data)
-      self.assertAllEqual(
-          [b'\x04l\x12\xd2\xaf\xb2\x809E\x9e\x02\x13',
-           b'\x9f\x0f\x91P\x9aG.Ql\xf2Y\xf9',
-           b'"0\xe00"\x18_\x08\x12?\xa0\x17'], out_data.eval())
+      out_data = gen_sdca_ops._sdca_fprint(in_data)
+      self.assertAllEqual([[4143508125394299908, -6879828354153669051],
+                           [5849691694103072671, -4874542629849009556],
+                           [603227410218889250, 8762207001949257490]],
+                          out_data.eval())
 
 
-class ShardedMutableHashTableTest(SdcaModelTest):
+class ShardedMutableDenseHashTableTest(SdcaModelTest):
   """Tests for the _ShardedMutableHashTable class."""
 
   def testShardedMutableHashTable(self):
     for num_shards in [1, 3, 10]:
       with self._single_threaded_test_session():
         default_val = -1
-        keys = tf.constant(['brain', 'salad', 'surgery'])
+        empty_key = 0
+        keys = tf.constant([11, 12, 13], tf.int64)
         values = tf.constant([0, 1, 2], tf.int64)
-        table = _ShardedMutableHashTable(tf.string,
-                                         tf.int64,
-                                         default_val,
-                                         num_shards=num_shards)
+        table = _ShardedMutableDenseHashTable(
+            tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
         self.assertAllEqual(0, table.size().eval())
 
         table.insert(keys, values).run()
         self.assertAllEqual(3, table.size().eval())
 
-        input_string = tf.constant(['brain', 'salad', 'tank'])
+        input_string = tf.constant([11, 12, 14], tf.int64)
         output = table.lookup(input_string)
         self.assertAllEqual([3], output.get_shape())
+        self.assertAllEqual([0, 1, -1], output.eval())
 
-        result = output.eval()
-        self.assertAllEqual([0, 1, -1], result)
+  def testShardedMutableHashTableVectors(self):
+    for num_shards in [1, 3, 10]:
+      with self._single_threaded_test_session():
+        default_val = [-0.1, 0.2]
+        empty_key = [0, 1]
+        keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64)
+        values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32)
+        table = _ShardedMutableDenseHashTable(
+            tf.int64, tf.float32, default_val, empty_key, num_shards=num_shards)
+        self.assertAllEqual(0, table.size().eval())
+
+        table.insert(keys, values).run()
+        self.assertAllEqual(3, table.size().eval())
+
+        input_string = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64)
+        output = table.lookup(input_string)
+        self.assertAllEqual([3, 2], output.get_shape())
+        self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]],
+                            output.eval())
 
   def testExportSharded(self):
     with self._single_threaded_test_session():
+      empty_key = -2
       default_val = -1
       num_shards = 2
-      keys = tf.constant(['a1', 'b1', 'c2'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      table = _ShardedMutableHashTable(
-          tf.string, tf.int64, default_val, num_shards=num_shards)
+      keys = tf.constant([10, 11, 12], tf.int64)
+      values = tf.constant([2, 3, 4], tf.int64)
+      table = _ShardedMutableDenseHashTable(
+          tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
      self.assertAllEqual(0, table.size().eval())
 
      table.insert(keys, values).run()
@@ -1062,10 +1081,12 @@ class ShardedMutableHashTableTest(SdcaModelTest):
       self.assertAllEqual(num_shards, len(keys_list))
       self.assertAllEqual(num_shards, len(values_list))
 
-      self.assertAllEqual(set([b'b1', b'c2']), set(keys_list[0].eval()))
-      self.assertAllEqual([b'a1'], keys_list[1].eval())
-      self.assertAllEqual(set([1, 2]), set(values_list[0].eval()))
-      self.assertAllEqual([0], values_list[1].eval())
+      # Exported keys include empty key buckets set to the empty_key
+      self.assertAllEqual(set([-2, 10, 12]), set(keys_list[0].eval().flatten()))
+      self.assertAllEqual(set([-2, 11]), set(keys_list[1].eval().flatten()))
+      # Exported values include empty value buckets set to 0
+      self.assertAllEqual(set([0, 2, 4]), set(values_list[0].eval().flatten()))
+      self.assertAllEqual(set([0, 3]), set(values_list[1].eval().flatten()))
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index ffb1477ce8..13310a3ec8 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -25,36 +25,31 @@ from six.moves import range
 
 from tensorflow.contrib.lookup import lookup_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework.ops import convert_to_tensor
 from tensorflow.python.framework.ops import name_scope
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_sdca_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables as var_ops
 from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
-from tensorflow.python.ops.sdca_ops import sdca_fprint
-from tensorflow.python.ops.sdca_ops import sdca_optimizer
-from tensorflow.python.ops.sdca_ops import sdca_shrink_l1
 
 __all__ = ['SdcaModel']
 
 
-class _ShardedMutableHashTable(lookup_ops.LookupInterface):
-  """A sharded version of MutableHashTable.
+class _ShardedMutableDenseHashTable(lookup_ops.LookupInterface):
+  """A sharded version of MutableDenseHashTable.
 
   It is designed to be interface compatible with LookupInterface and
-  MutableHashTable, with the exception of the export method, which is replaced
-  by a custom values_reduce_sum method for SDCA needs. The class is not part of
-  lookup ops because it is unclear how to make the device placement general
-  enough to be useful.
-
-  The _ShardedHashTable keeps `num_shards` MutableHashTables internally. If
-  keys are integers, the shard is computed via the modulo operation. If keys
-  are strings, the shard is computed via string_to_hash_bucket_fast.
+  MutableDenseHashTable, with the exception of the export method, which is
+  replaced by an export_sharded method.
+
+  The _ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable
+  internally. The shard is computed via the modulo operation on the key.
   """
 
   # TODO(andreasst): consider moving this to lookup_ops
@@ -63,18 +58,21 @@ class _ShardedMutableHashTable(lookup_ops.LookupInterface):
                key_dtype,
                value_dtype,
                default_value,
+               empty_key,
                num_shards=1,
                name='ShardedMutableHashTable'):
     with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
-      super(_ShardedMutableHashTable, self).__init__(key_dtype, value_dtype,
-                                                     scope)
+      super(_ShardedMutableDenseHashTable, self).__init__(key_dtype,
+                                                          value_dtype, scope)
       table_shards = []
       for i in range(num_shards):
-        table_shards.append(lookup_ops.MutableHashTable(
-            key_dtype=key_dtype,
-            value_dtype=value_dtype,
-            default_value=default_value,
-            name='%s-%d-of-%d' % (name, i + 1, num_shards)))
+        table_shards.append(
+            lookup_ops.MutableDenseHashTable(
+                key_dtype=key_dtype,
+                value_dtype=value_dtype,
+                default_value=default_value,
+                empty_key=empty_key,
+                name='%s-%d-of-%d' % (name, i + 1, num_shards)))
       self._table_shards = table_shards
       # TODO(andreasst): add a value_shape() method to LookupInterface
       # pylint: disable=protected-access
@@ -97,16 +95,28 @@ class _ShardedMutableHashTable(lookup_ops.LookupInterface):
     return math_ops.add_n(sizes)
 
   def _shard_indices(self, keys):
-    if self._key_dtype == dtypes.string:
-      indices = string_ops.string_to_hash_bucket_fast(keys, self._num_shards)
-    else:
-      indices = math_ops.mod(keys, self._num_shards)
+    key_shape = keys.get_shape()
+    if key_shape.ndims > 1:
+      # If keys are a matrix (i.e. a single key is a vector), we use the first
+      # element of each key vector to determine the shard.
+      keys = array_ops.slice(keys, [0, 0], [key_shape[0].value, 1])
+      keys = array_ops.reshape(keys, [-1])
+    indices = math_ops.mod(math_ops.abs(keys), self._num_shards)
     return math_ops.cast(indices, dtypes.int32)
 
+  def _check_keys(self, keys):
+    if not keys.get_shape().is_fully_defined():
+      raise ValueError('Key shape must be fully defined, got %s.' %
+                       keys.get_shape())
+    if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2:
+      raise ValueError('Expected a vector or matrix for keys, got %s.' %
+                       keys.get_shape())
+
   def lookup(self, keys, name=None):
     if keys.dtype != self._key_dtype:
       raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.'
                       % (self._key_dtype, keys.dtype))
+    self._check_keys(keys)
     num_shards = self._num_shards
     if num_shards == 1:
       return self._table_shards[0].lookup(keys, name=name)
@@ -120,15 +130,18 @@
         for i in range(num_shards)
     ]
 
-    original_indices = math_ops.range(array_ops.size(keys))
+    num_keys = keys.get_shape().dims[0]
+    original_indices = math_ops.range(num_keys)
     partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
                                                           shard_indices,
                                                           num_shards)
     result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
-    result.set_shape(keys.get_shape().concatenate(self._value_shape))
+    result.set_shape(
+        tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
     return result
 
   def insert(self, keys, values, name=None):
+    self._check_keys(keys)
     num_shards = self._num_shards
     if num_shards == 1:
       return self._table_shards[0].insert(keys, values, name=name)
@@ -359,11 +372,14 @@ class SdcaModel(object):
     self._variables = variables
     self._options = options
     self._create_slots()
-    self._hashtable = _ShardedMutableHashTable(
-        key_dtype=dtypes.string,
+    self._hashtable = _ShardedMutableDenseHashTable(
+        key_dtype=dtypes.int64,
         value_dtype=dtypes.float32,
         num_shards=self._num_table_shards(),
-        default_value=[0.0, 0.0, 0.0, 0.0])
+        default_value=[0.0, 0.0, 0.0, 0.0],
+        # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe
+        # empty_key (that will never collide with actual payloads).
+        empty_key=[0, 0])
 
     logging_ops.scalar_summary('approximate_duality_gap',
                                self.approximate_duality_gap())
@@ -519,8 +535,10 @@ class SdcaModel(object):
       if sf.feature_values is not None:
         sparse_features_values.append(sf.feature_values)
 
-      example_ids_hashed = sdca_fprint(
+      # pylint: disable=protected-access
+      example_ids_hashed = gen_sdca_ops._sdca_fprint(
           convert_to_tensor(self._examples['example_ids']))
+      # pylint: enable=protected-access
       example_state_data = self._hashtable.lookup(example_ids_hashed)
       # Solver returns example_state_update, new delta sparse_feature_weights
       # and delta dense_feature_weights.
@@ -538,7 +556,8 @@
                 dtypes.int64))
         sparse_weights.append(array_ops.gather(w, sparse_indices[-1]))
 
-      esu, sfw, dfw = sdca_optimizer(
+      # pylint: disable=protected-access
+      esu, sfw, dfw = gen_sdca_ops._sdca_optimizer(
           sparse_example_indices,
           sparse_feature_indices,
           sparse_features_values,
@@ -555,6 +574,7 @@
           l2=self._symmetric_l2_regularization(),
           num_loss_partitions=self._num_loss_partitions(),
           num_inner_iterations=1)
+      # pylint: enable=protected-access
 
       with ops.control_dependencies([esu]):
         update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
@@ -597,8 +617,9 @@
       for name in ['sparse_features_weights', 'dense_features_weights']:
         for var in self._variables[name]:
           with ops.device(var.device):
+            # pylint: disable=protected-access
             update_ops.append(
-                sdca_shrink_l1(
+                gen_sdca_ops._sdca_shrink_l1(
                     self._convert_n_to_tensor(
                         [var], as_ref=True),
                     l1=self._symmetric_l1_regularization(),
@@ -617,8 +638,14 @@
     shard_sums = []
     for values in values_list:
       with ops.device(values.device):
-        shard_sums.append(
-            math_ops.reduce_sum(math_ops.cast(values, dtypes.float64), 0))
+        # For large tables to_double() below allocates a large temporary
+        # tensor that is freed once the sum operation completes. To reduce
+        # peak memory usage in cases where we have multiple large tables on a
+        # single device, we serialize these operations.
+        # Note that we need double precision to get accurate results.
+        with ops.control_dependencies(shard_sums):
+          shard_sums.append(
+              math_ops.reduce_sum(math_ops.to_double(values), 0))
     summed_values = math_ops.add_n(shard_sums)
 
     primal_loss = summed_values[1]
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index d30e7486f5..578466e202 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -1079,7 +1079,7 @@ REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
 // persistent storage, as its implementation may change in the future.
 //
 // The current probability of at least one collision for 1B example_ids is
-// approximately 10^-11 (ie 2^60 / 2^97).
+// approximately 10^-21 (ie 2^60 / 2^129).
 class SdcaFprint : public OpKernel {
  public:
   explicit SdcaFprint(OpKernelConstruction* const context)
@@ -1087,27 +1087,27 @@
 
   void Compute(OpKernelContext* const context) override {
     const Tensor& input = context->input(0);
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+                errors::InvalidArgument("Input must be a vector, got shape ",
+                                        input.shape().DebugString()));
     Tensor* out;
-    OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &out));
+    const int64 num_elements = input.NumElements();
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0, TensorShape({num_elements, 2}), &out));
 
     const auto in_values = input.flat<string>();
-    auto out_values = out->flat<string>();
-
-    for (int64 i = 0; i < in_values.size(); ++i) {
-      out_values(i) = Fp128ToBinaryString(Fingerprint128(in_values(i)));
+    auto out_values = out->matrix<int64>();
+
+    for (int64 i = 0; i < num_elements; ++i) {
+      const Fprint128 fprint = Fingerprint128(in_values(i));
+      // Never return 0 or 1 as the first value of the hash to allow these to
+      // safely be used as sentinel values (e.g. dense hash table empty key).
+      out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
+                             ? fprint.low64
+                             : fprint.low64 + ~static_cast<uint64>(1);
+      out_values(i, 1) = fprint.high64;
     }
   }
-
- private:
-  // Returns a 12 character binary string of the fprint.
-  // We use 12 of the 16 fingerprint bytes to save memory, in particular in
-  // string implementations that use a short string optimization.
-  static string Fp128ToBinaryString(const Fprint128& fprint) {
-    string result;
-    core::PutFixed64(&result, fprint.low64);
-    core::PutFixed32(&result, fprint.high64);
-    return result;
-  }
 };
 
 REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index b5b056e41f..7b35139cce 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -25009,17 +25009,6 @@ op {
   }
 }
 op {
-  name: "SdcaFprint"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-}
-op {
   name: "SdcaOptimizer"
   input_arg {
     name: "sparse_example_indices"
diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc
index 1d1fc51317..0aed75366e 100644
--- a/tensorflow/core/ops/sdca_ops.cc
+++ b/tensorflow/core/ops/sdca_ops.cc
@@ -137,13 +137,21 @@ weights: a list of vectors where each value is the weight associated with a
 
 REGISTER_OP("SdcaFprint")
     .Input("input: string")
-    .Output("output: string")
-    .SetShapeFn(shape_inference::UnchangedShape)
+    .Output("output: int64")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle handle;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+      ShapeHandle output_shape;
+      TF_RETURN_IF_ERROR(c->Concatenate(handle, c->Vector(2), &output_shape));
+      c->set_output(0, output_shape);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes fingerprints of the input strings.
 
 input: vector of strings to compute fingerprints on.
-output: vector containing the computed fingerprints.
+output: a (N,2) shaped matrix where N is the number of elements in the input
+  vector. Each row contains the low and high parts of the fingerprint.
 )doc");
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index fce784de00..362cc0518e 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -229,6 +229,11 @@ TruncatedNormal
 PyFunc
 PyFuncStateless
 
+# sdca_ops
+SdcaFprint
+SdcaOptimizer
+SdcaShrinkL1
+
 # state_ops
 Variable
 TemporaryVariable
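
A closing note on the sharding scheme in _ShardedMutableDenseHashTable above: with string keys gone, shards are always chosen by modulo arithmetic, and a matrix of keys (one vector-valued key per row, such as the [low64, high64] fingerprints) is sharded by the first element of each row. The sketch below is a rough NumPy rendering of _shard_indices under that reading; it deliberately ignores the dynamic_partition/dynamic_stitch plumbing that the real lookup uses to route keys to shards and restore the original row order.

import numpy as np


def shard_indices(keys, num_shards):
  """Sketch of _ShardedMutableDenseHashTable._shard_indices."""
  keys = np.asarray(keys, dtype=np.int64)
  if keys.ndim > 1:
    # A matrix of keys: shard each vector-valued key by its first element.
    keys = keys[:, 0]
  return np.abs(keys) % num_shards


# Scalar int64 keys, as in testShardedMutableHashTable:
print(shard_indices([11, 12, 14], num_shards=3))          # [2 0 2]
# Vector keys, as produced by the new SdcaFprint:
print(shard_indices([[11, 12], [13, 14]], num_shards=3))  # [2 1]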