diff options
author | 2016-03-09 18:01:47 -0800 | |
---|---|---|
committer | 2016-03-09 20:13:30 -0800 | |
commit | 1a4ff3757aa4b46ad9a13d9070430a3b7ae657e7 (patch) | |
tree | 3f5e0d2e21481a85e5f22eb5bbba20e213be748d | |
parent | 016b25521d0d3462579334f68f34e4052e217bb0 (diff) |
Avoid shuffling, and the use of random generator within the sdca optimizer.
Change: 116824486
-rw-r--r-- | tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc | 30 |
1 files changed, 6 insertions, 24 deletions
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc index 76671a47cd..37c372cc00 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc +++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc @@ -22,7 +22,6 @@ limitations under the License. #include <cmath> #include <functional> #include <limits> -#include <random> #include <string> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -379,8 +378,7 @@ WeightsByGroup MakeDeltaWeightsFrom(std::vector<Tensor>* const tensors) { } Status RunTrainStepsForMiniBatch( - const std::vector<int64>& example_indices, - const TTypes<const string>::Vec example_ids, + const int64 num_examples, const TTypes<const string>::Vec example_ids, const TTypes<const float>::Vec example_labels, const TTypes<const float>::Vec example_weights, const DeviceBase::CpuWorkerThreads& worker_threads, @@ -397,13 +395,12 @@ Status RunTrainStepsForMiniBatch( mutex mu; Status train_step_status GUARDED_BY(mu); auto train_step = [&](const int64 begin, const int64 end) { - for (int64 offset = begin; offset < end; ++offset) { + for (int64 example_index = begin; example_index < end; ++example_index) { // Get example id, label, and weight. - const int64 example_index = example_indices[offset]; const DataByExample::Key example_key = DataByExample::MakeKey(example_ids(example_index)); DataByExample::Data data = data_by_example->Get(example_key); - const double example_weight = example_weights(example_index); + const float example_weight = example_weights(example_index); float example_label = example_labels(example_index); const Status conversion_status = loss_updater.ConvertLabel(&example_label); @@ -455,8 +452,8 @@ Status RunTrainStepsForMiniBatch( // but perhaps we can tune it better. const int64 kCostPerUnit = 100000 * (sparse_examples_by_group.size() + dense_features_by_group.size()); - Shard(worker_threads.num_threads, worker_threads.workers, - example_indices.size(), kCostPerUnit, train_step); + Shard(worker_threads.num_threads, worker_threads.workers, num_examples, + kCostPerUnit, train_step); return train_step_status; } @@ -636,21 +633,6 @@ class SdcaSolver : public OpKernel { SetZeroDeltaWeights(&sparse_delta_weights_by_group, &dense_delta_weights_by_group); - // Examples are shuffled below at each iteration and processed in a - // partitioned fashion across multiple threads. - // TODO(rohananil): We may want to avoid shuffling inside the op - // but instead shuffle outside as part of the input reader. Re-evaluate - // once we have more data on how this op is used. - const std::vector<int64> example_indices = [num_examples]() { - std::vector<int64> result(num_examples); - std::iota(result.begin(), result.end(), 0); - std::random_device random_device; - std::mt19937 random_generator(random_device()); - // Randomize the examples across iterations for faster convergence. - std::shuffle(result.begin(), result.end(), random_generator); - return result; - }(); - // TODO(rohananil): Provide emperical evidence for this. It is better to run // more than one iteration on single mini-batch as we want to spend more // time in compute. SDCA works better with larger mini batches and there @@ -660,7 +642,7 @@ class SdcaSolver : public OpKernel { OP_REQUIRES_OK( context, RunTrainStepsForMiniBatch( - example_indices, example_ids, example_labels, example_weights, + num_examples, example_ids, example_labels, example_weights, *context->device()->tensorflow_cpu_worker_threads(), regularizations_, sparse_weights_by_group, sparse_examples_by_group, dense_weights_by_group, |