about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <nobody@tensorflow.org> 2016-03-09 18:01:47 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-03-09 20:13:30 -0800
commit: 1a4ff3757aa4b46ad9a13d9070430a3b7ae657e7 (patch)
tree: 3f5e0d2e21481a85e5f22eb5bbba20e213be748d
parent: 016b25521d0d3462579334f68f34e4052e217bb0 (diff)
Avoid shuffling, and the use of random generator within the sdca optimizer.
Change: 116824486
-rw-r--r-- tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc 30
1 file changed, 6 insertions, 24 deletions
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
index 76671a47cd..37c372cc00 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <cmath>
#include <functional>
#include <limits>
-#include <random>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -379,8 +378,7 @@ WeightsByGroup MakeDeltaWeightsFrom(std::vector<Tensor>* const tensors) {
}
Status RunTrainStepsForMiniBatch(
- const std::vector<int64>& example_indices,
- const TTypes<const string>::Vec example_ids,
+ const int64 num_examples, const TTypes<const string>::Vec example_ids,
const TTypes<const float>::Vec example_labels,
const TTypes<const float>::Vec example_weights,
const DeviceBase::CpuWorkerThreads& worker_threads,
@@ -397,13 +395,12 @@ Status RunTrainStepsForMiniBatch(
mutex mu;
Status train_step_status GUARDED_BY(mu);
auto train_step = [&](const int64 begin, const int64 end) {
- for (int64 offset = begin; offset < end; ++offset) {
+ for (int64 example_index = begin; example_index < end; ++example_index) {
// Get example id, label, and weight.
- const int64 example_index = example_indices[offset];
const DataByExample::Key example_key =
DataByExample::MakeKey(example_ids(example_index));
DataByExample::Data data = data_by_example->Get(example_key);
- const double example_weight = example_weights(example_index);
+ const float example_weight = example_weights(example_index);
float example_label = example_labels(example_index);
const Status conversion_status =
loss_updater.ConvertLabel(&example_label);
@@ -455,8 +452,8 @@ Status RunTrainStepsForMiniBatch(
// but perhaps we can tune it better.
const int64 kCostPerUnit = 100000 * (sparse_examples_by_group.size() +
dense_features_by_group.size());
- Shard(worker_threads.num_threads, worker_threads.workers,
- example_indices.size(), kCostPerUnit, train_step);
+ Shard(worker_threads.num_threads, worker_threads.workers, num_examples,
+ kCostPerUnit, train_step);
return train_step_status;
}
@@ -636,21 +633,6 @@ class SdcaSolver : public OpKernel {
SetZeroDeltaWeights(&sparse_delta_weights_by_group,
&dense_delta_weights_by_group);
- // Examples are shuffled below at each iteration and processed in a
- // partitioned fashion across multiple threads.
- // TODO(rohananil): We may want to avoid shuffling inside the op
- // but instead shuffle outside as part of the input reader. Re-evaluate
- // once we have more data on how this op is used.
- const std::vector<int64> example_indices = [num_examples]() {
- std::vector<int64> result(num_examples);
- std::iota(result.begin(), result.end(), 0);
- std::random_device random_device;
- std::mt19937 random_generator(random_device());
- // Randomize the examples across iterations for faster convergence.
- std::shuffle(result.begin(), result.end(), random_generator);
- return result;
- }();
-
// TODO(rohananil): Provide empirical evidence for this. It is better to run
// more than one iteration on single mini-batch as we want to spend more
// time in compute. SDCA works better with larger mini batches and there
@@ -660,7 +642,7 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES_OK(
context,
RunTrainStepsForMiniBatch(
- example_indices, example_ids, example_labels, example_weights,
+ num_examples, example_ids, example_labels, example_weights,
*context->device()->tensorflow_cpu_worker_threads(),
regularizations_, sparse_weights_by_group,
sparse_examples_by_group, dense_weights_by_group,