author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-04-18 15:28:27 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-04-18 16:32:37 -0700
commit     ad2122584430ca4e4cc97fa642c1536357402eb6 (patch)
tree       7d196451b88fa48e73846b91392ddbce5c25de0d /tensorflow/contrib/linear_optimizer
parent     5c9e7d39187321f27cfc1ebb32064f49bb337313 (diff)
Refactoring/unification of weights & features within sdca_ops.cc.
Combines the 6 entities (sparse, dense) x (features, weights, delta-weights)
into a single class (FeaturesAndWeights) which tracks all features and weights,
hiding their underlying representation. This class is built by composing other
classes:

FeaturesAndWeights
  SparseFeaturesAndWeights (examples_by_group_)
    WeightsAndDeltas
      WeightsByGroup (delta_weights_by_group_)
  DenseFeaturesAndWeights (features_by_group_)
    WeightsAndDeltas (same as above)

Also adds a microbenchmark.

Change: 120173207
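For orientation, here is a simplified mirror of that composition using only
standard-library types. This is a sketch, not the commit's code: the real
classes in sdca_ops.cc below hold TensorFlow tensor views and have
Status-returning Initialize() methods.

    #include <atomic>
    #include <memory>
    #include <vector>

    // Sparse indices/values of one example's feature group.
    struct PerExampleSparseIndicesValues {
      std::vector<long long> feature_indices;  // N x 1 indices.
      std::vector<float> feature_values;       // N x 1 values.
      double squared_norm = 0.0;
    };

    class WeightsByGroup {
     private:
      // Weights addressable by [group_num][feature_num].
      std::vector<std::vector<float>> weights_by_group_;
    };

    class WeightsAndDeltas {
     private:
      WeightsByGroup weights_by_group_;
      // Deltas mirror the weights; atomics let concurrent shards update
      // them transactionally during a mini-batch.
      std::vector<std::vector<std::atomic<double>>> delta_weights_by_group_;
    };

    class SparseFeaturesAndWeights {
     private:
      using SparseExamples =
          std::vector<std::unique_ptr<const PerExampleSparseIndicesValues>>;
      std::vector<SparseExamples> examples_by_group_;
      WeightsAndDeltas weights_and_deltas_;
    };

    class DenseFeaturesAndWeights {
     private:
      std::vector<std::vector<float>> features_by_group_;  // [group][example]
      WeightsAndDeltas weights_and_deltas_;
    };

    class FeaturesAndWeights {
     private:
      SparseFeaturesAndWeights sparse_features_and_weights_;
      DenseFeaturesAndWeights dense_features_and_weights_;
    };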
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/BUILD              23
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc       731
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/sdca_ops_test.cc  215
3 files changed, 697 insertions, 272 deletions
diff --git a/tensorflow/contrib/linear_optimizer/kernels/BUILD b/tensorflow/contrib/linear_optimizer/kernels/BUILD
index 8a5466b34e..ac9a33ccc4 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/BUILD
+++ b/tensorflow/contrib/linear_optimizer/kernels/BUILD
@@ -6,6 +6,12 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_kernel_tests_linkstatic",
+)
+
cc_library(
name = "loss_updaters",
hdrs = [
@@ -70,6 +76,23 @@ cc_library(
alwayslink = 1,
)
+tf_cc_test(
+ name = "sdca_ops_test",
+ size = "small",
+ linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
+ deps = [
+ "//tensorflow/contrib/linear_optimizer:sdca_op_kernels",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
index 7500f38be3..d8772e40e5 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
@@ -56,66 +56,405 @@ limitations under the License.
namespace tensorflow {
namespace {
-// A feature group of a single example by this struct.
-struct PerExampleSparseIndicesWeights {
- // N X 1 vector with feature indices.
- Eigen::Tensor</*const*/ int64, 1, Eigen::RowMajor> feature_indices;
+struct PerExampleData {
+ // feature_weights dot feature_values for the example
+ double wx = 0;
+ // sum of squared feature values occurring in the example divided by
+ // (L2 * N)
+ double normalized_squared_norm = 0;
+};
- // N X 1 vector with feature values.
- TTypes</*const*/ float>::UnalignedVec feature_values;
+PerExampleData AddPerExampleData(const PerExampleData& data1,
+ const PerExampleData& data2) {
+ PerExampleData result;
+ result.wx = data1.wx + data2.wx;
+ result.normalized_squared_norm =
+ data1.normalized_squared_norm + data2.normalized_squared_norm;
+ return result;
+}
- // sum squared norm of the features.
- double norm;
+class Regularizations {
+ public:
+ Regularizations() {}
+
+ // Initialize() must be called immediately after construction.
+ Status Initialize(OpKernelConstruction* const context) {
+ TF_RETURN_IF_ERROR(context->GetAttr("l1", &symmetric_l1_));
+ TF_RETURN_IF_ERROR(context->GetAttr("l2", &symmetric_l2_));
+ shrinkage_factor_ = symmetric_l1_ / symmetric_l2_;
+ return Status::OK();
+ }
+
+ // Proximal SDCA shrinking for L1 regularization.
+ double Shrink(const double weight) const {
+ const double shrink_weight =
+ std::max(std::abs(weight) - shrinkage_factor_, 0.0);
+ if (shrink_weight > 0.0) {
+ return std::copysign(shrink_weight, weight);
+ }
+ return 0.0;
+ }
+
+ float symmetric_l2() const { return symmetric_l2_; }
+
+ private:
+ float symmetric_l1_ = 0;
+ float symmetric_l2_ = 0;
+
+ // L1 divided by L2, precomputed for use during weight shrinking.
+ double shrinkage_factor_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Regularizations);
};
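In math form, the Shrink() method above is the soft-thresholding operator of
proximal SDCA. With L1 strength $\lambda_1$ and L2 strength $\lambda_2$ (so
that shrinkage_factor_ $= \lambda_1 / \lambda_2$):

$$\mathrm{Shrink}(w) = \operatorname{sign}(w)\,\max\!\left(|w| - \frac{\lambda_1}{\lambda_2},\; 0\right)$$

Correspondingly, PerExampleData::normalized_squared_norm accumulates
$\sum_j x_j^2$ over an example's feature values and is divided by
$\lambda_2$ at the end of ComputeWxAndWeightedExampleNorm() below.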
-struct Regularizations {
- float symmetric_l1 = 0;
- float symmetric_l2 = 0;
+// Tracks feature weights for groups of features which are input as lists
+// of weight tensors. Each list element becomes a "group" allowing us to
+// refer to an individual feature by [group_num][feature_num].
+class WeightsByGroup {
+ public:
+ WeightsByGroup() {}
+
+ // Initialize() must be called immediately after construction.
+ Status Initialize(OpKernelContext* const context,
+ const string& input_list_name) {
+ OpMutableInputList weights_inputs;
+ TF_RETURN_IF_ERROR(
+ context->mutable_input_list(input_list_name, &weights_inputs));
+ for (int i = 0; i < weights_inputs.size(); ++i) {
+ weights_by_group_.emplace_back(
+ weights_inputs.at(i, /*lock_held=*/true).flat<float>());
+ }
+
+ return Status::OK();
+ }
+
+ // Adds the given 'delta' to the feature indexed by 'group' and 'feature'.
+ void AddDelta(const size_t group, const size_t feature, const float delta) {
+ weights_by_group_[group](feature) += delta;
+ }
+
+ // Modifies all weights according to the shrinkage factor determined by
+ // 'regularizations'.
+ void Shrink(const Regularizations& regularizations) {
+ for (TTypes<float>::Vec weights : weights_by_group_) {
+ for (int64 i = 0; i < weights.size(); ++i) {
+ weights(i) = regularizations.Shrink(weights(i));
+ }
+ }
+ }
+
+ // Returns an error if these weights do not appear to come from dense
+ // features. Currently this means that each group contains a single feature.
+ // TODO(sibyl-Mooth6ku): Support arbitrary dimensional dense weights and remove
+ // this.
+ Status ValidateAsDense() const {
+ for (const TTypes<float>::Vec weights : weights_by_group_) {
+ if (weights.size() != 1) {
+ return errors::InvalidArgument(strings::Printf(
+ "Dense weight vectors should have exactly one entry. Found (%ld). "
+ "This is probably due to a misconfiguration in the optimizer "
+ "setup.",
+ weights.size()));
+ }
+ }
+ return Status::OK();
+ }
+
+ size_t NumGroups() const { return weights_by_group_.size(); }
+
+ const TTypes<float>::Vec& WeightsOfGroup(size_t group) const {
+ return weights_by_group_[group];
+ }
+
+ private:
+ // Weights associated with a (sparse or dense) feature group, such that the
+ // size of weights_by_group_ is the number of feature groups.
+ std::vector<TTypes<float>::Vec> weights_by_group_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WeightsByGroup);
};
-struct PerExampleData {
- double wx = 0;
- double norm = 0;
+// Tracks weights and delta-weights for either sparse or dense features.
+// As we process a mini-batch, weights are read from tensors and delta-weights
+// are initialized to 0. During processing, delta-weights are modified and
+// at the completion of processing the mini-batch, the delta-weights are added
+// into the original weights and then discarded.
+class WeightsAndDeltas {
+ public:
+ WeightsAndDeltas() {}
+
+ // Initialize() must be called immediately after construction.
+ Status Initialize(OpKernelContext* const context,
+ const string& input_list_name) {
+ TF_RETURN_IF_ERROR(weights_by_group_.Initialize(context, input_list_name));
+ InitializeDeltaWeightsToZero();
+ return Status::OK();
+ }
+
+ // Adds all of the delta weights which were computed during processing
+ // of this mini-batch into the feature-weights. Must be called once
+ // at the end of mini-batch processing.
+ void AddDeltaWeights() {
+ // TODO(sibyl-Aix6ihai): Parallelize this.
+ for (size_t group = 0; group < delta_weights_by_group_.size(); ++group) {
+ for (size_t i = 0; i < delta_weights_by_group_[group].size(); ++i) {
+ weights_by_group_.AddDelta(group, i,
+ delta_weights_by_group_[group][i].load());
+ }
+ }
+ }
+
+ std::vector<std::atomic<double>>* DeltaWeightsOfGroup(size_t group) {
+ return &delta_weights_by_group_[group];
+ }
+
+ const std::vector<std::atomic<double>>& DeltaWeightsOfGroup(
+ size_t group) const {
+ return delta_weights_by_group_[group];
+ }
+
+ const TTypes<float>::Vec& WeightsOfGroup(size_t group) const {
+ return weights_by_group_.WeightsOfGroup(group);
+ }
+
+ size_t NumGroups() const { return delta_weights_by_group_.size(); }
+
+ size_t NumFeaturesOfGroup(int group) const {
+ return delta_weights_by_group_[group].size();
+ }
+
+ // Returns an error if these weights do not appear to come from dense
+ // features. Currently this means that each group contains a single feature.
+ Status ValidateAsDense() const { return weights_by_group_.ValidateAsDense(); }
+
+ private:
+ void InitializeDeltaWeightsToZero() {
+ // TODO(sibyl-Mooth6ku): Maybe parallelize this.
+ for (size_t group = 0; group < weights_by_group_.NumGroups(); ++group) {
+ const TTypes<float>::Vec weights =
+ weights_by_group_.WeightsOfGroup(group);
+ delta_weights_by_group_.emplace_back(weights.size());
+ std::fill(delta_weights_by_group_.back().begin(),
+ delta_weights_by_group_.back().end(), 0);
+ }
+ }
+
+ WeightsByGroup weights_by_group_;
+
+ // Delta weights associated with each of the weights in weights_by_group_,
+ // indexed by [group_num][feature_num]. Atomicity is required when changing
+ // the delta weights in order to have transactional updates.
+ std::vector<std::vector<std::atomic<double>>> delta_weights_by_group_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WeightsAndDeltas);
};
-// Weights associated with a (sparse or dense) feature group, such that the size
-// of WeightsByGroup is the number of feature groups.
-using WeightsByGroup = std::vector<TTypes<float>::Vec>;
+// Atomically add a double to a std::atomic<double>.
+inline void AtomicAdd(const double src, std::atomic<double>* const dst) {
+ // We use the strong version of compare-exchange, as the weak version can
+ // fail spuriously.
+ for (double c = dst->load(); !dst->compare_exchange_strong(c, c + src);) {
+ }
+}
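As a point of reference outside this diff, here is a minimal, self-contained
demonstration of the same compare-exchange idiom (illustrative only, not part
of the commit). std::atomic<double> only gains a floating-point fetch_add in
C++20, which is why pre-C++20 code like AtomicAdd above must loop on CAS;
a plain non-atomic += from multiple threads would lose updates:

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
      std::atomic<double> total{0.0};
      std::vector<std::thread> workers;
      for (int t = 0; t < 4; ++t) {
        workers.emplace_back([&total] {
          for (int i = 0; i < 1000; ++i) {
            // Same CAS loop as AtomicAdd: on failure, 'c' is reloaded with
            // the current value, so each retry adds to the freshest total.
            for (double c = total.load();
                 !total.compare_exchange_strong(c, c + 0.5);) {
            }
          }
        });
      }
      for (std::thread& w : workers) w.join();
      // 4 threads x 1000 iterations x 0.5 == 2000.0, with no lost updates.
      std::printf("total = %f\n", total.load());
      return 0;
    }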
+
+// Tracks all of the information related to the dense features: weights
+// and delta weights, as well as feature occurrences in the current mini-batch.
+class DenseFeaturesAndWeights {
+ public:
+ DenseFeaturesAndWeights() {}
+
+ // Initialize() must be called immediately after construction.
+ Status Initialize(OpKernelContext* const context) {
+ OpInputList dense_features_inputs;
+ TF_RETURN_IF_ERROR(
+ context->input_list("dense_features", &dense_features_inputs));
+ for (const auto& dense_feature : dense_features_inputs) {
+ features_by_group_.emplace_back(dense_feature.vec<float>());
+ }
+
+ TF_RETURN_IF_ERROR(
+ weights_and_deltas_.Initialize(context, "dense_weights"));
+ TF_RETURN_IF_ERROR(weights_and_deltas_.ValidateAsDense());
+ return Status::OK();
+ }
-// DeltaWeights associated with a (sparse or dense) feature group, such that the
-// size of DeltaWeightsByGroup is the number of feature groups. Atomicity is
-// required when changing the weights in order to have transactional updates.
-using DeltaWeightsByGroup = std::vector<std::vector<std::atomic<double>>>;
+ // Computes PerExampleData for 'example_id'.
+ PerExampleData ComputeWxAndWeightedExampleNorm(
+ const int64 example_id, const Regularizations& regularizations) const {
+ PerExampleData result;
+ for (size_t group = 0; group < features_by_group_.size(); ++group) {
+ const double weight = weights_and_deltas_.WeightsOfGroup(group)(0);
+ const std::atomic<double>& delta_weight =
+ weights_and_deltas_.DeltaWeightsOfGroup(group)[0];
+ const double value = features_by_group_[group](example_id);
+ result.wx += regularizations.Shrink(weight + delta_weight.load()) * value;
+ result.normalized_squared_norm += value * value;
+ }
+ result.normalized_squared_norm /= regularizations.symmetric_l2();
+ return result;
+ }
-// SparseExamples represent sparse feature groups of each example.
-using SparseExamples =
- std::vector<std::unique_ptr<const PerExampleSparseIndicesWeights>>;
+ // Updates the delta weight for each feature occurring in 'example_id',
+ // given the weighted change in the dual for this example
+ // (bounded_dual_delta), and the 'l2_regularization'.
+ void UpdateDeltaWeights(const int64 example_id,
+ const double bounded_dual_delta,
+ const double l2_regularization) {
+ for (size_t group = 0; group < features_by_group_.size(); ++group) {
+ std::atomic<double>* const delta_weight =
+ &(*weights_and_deltas_.DeltaWeightsOfGroup(group))[0];
+ const double value = features_by_group_[group](example_id);
+ AtomicAdd(bounded_dual_delta * value / l2_regularization, delta_weight);
+ }
+ }
-// SparseExamples associated with each sparse feature group.
-using SparseExamplesByGroup = std::vector<SparseExamples>;
+ // Adds all of the delta weights which were computed during processing
+ // of this mini-batch into the feature-weights. Must be called once
+ // at the end of mini-batch processing.
+ void AddDeltaWeights() { weights_and_deltas_.AddDeltaWeights(); }
-// Dense features associated with each dense feature group.
-using DenseFeaturesByGroup = std::vector<TTypes<const float>::Vec>;
+ size_t NumGroups() const { return features_by_group_.size(); }
-// Go through the entire training set once, in a parallel and partitioned
+ private:
+ // Dense features associated with each dense feature group.
+ std::vector<TTypes<const float>::Vec> features_by_group_;
+
+ WeightsAndDeltas weights_and_deltas_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DenseFeaturesAndWeights);
+};
+
+// Tracks all of the information related to the sparse features: weights
+// and delta weights, as well as feature occurrences in the current mini-batch.
+class SparseFeaturesAndWeights {
+ public:
+ SparseFeaturesAndWeights() {}
+
+ // Initialize() must be called immediately after construction.
+ Status Initialize(OpKernelContext* const context,
+ const int64 num_sparse_features, const int num_examples,
+ const DeviceBase::CpuWorkerThreads& worker_threads) {
+ TF_RETURN_IF_ERROR(
+ weights_and_deltas_.Initialize(context, "sparse_weights"));
+ TF_RETURN_IF_ERROR(FillExamples(context, num_sparse_features, num_examples,
+ worker_threads));
+ return Status::OK();
+ }
+
+ // Computes PerExampleData for 'example_id'.
+ PerExampleData ComputeWxAndWeightedExampleNorm(
+ const int64 example_id, const Regularizations& regularizations) const {
+ PerExampleData result;
+ for (size_t group = 0; group < examples_by_group_.size(); ++group) {
+ const TTypes<float>::Vec weights =
+ weights_and_deltas_.WeightsOfGroup(group);
+ const std::vector<std::atomic<double>>& delta_weights =
+ weights_and_deltas_.DeltaWeightsOfGroup(group);
+
+ const SparseExamples& sparse_indices_values = examples_by_group_[group];
+ if (sparse_indices_values[example_id]) {
+ const auto indices = sparse_indices_values[example_id]->feature_indices;
+ const auto values = sparse_indices_values[example_id]->feature_values;
+ for (int64 dim = 0; dim < indices.dimension(0); ++dim) {
+ const int64 index = internal::SubtleMustCopy(indices(dim));
+ const double weight = weights(index);
+ const std::atomic<double>& delta_weight = delta_weights[index];
+ const double value = values(dim);
+ result.wx +=
+ regularizations.Shrink(weight + delta_weight.load()) * value;
+ }
+ result.normalized_squared_norm +=
+ sparse_indices_values[example_id]->squared_norm;
+ }
+ }
+ result.normalized_squared_norm /= regularizations.symmetric_l2();
+ return result;
+ }
+
+ // Updates the delta weight for each feature occurring in 'example_id',
+ // given the weighted change in the dual for this example
+ // (bounded_dual_delta), and the 'l2_regularization'.
+ void UpdateDeltaWeights(const int64 example_id,
+ const double bounded_dual_delta,
+ const double l2_regularization) {
+ for (size_t group = 0; group < examples_by_group_.size(); ++group) {
+ std::vector<std::atomic<double>>& delta_weights =
+ *weights_and_deltas_.DeltaWeightsOfGroup(group);
+
+ const SparseExamples& sparse_indices_values = examples_by_group_[group];
+ if (sparse_indices_values[example_id]) {
+ const auto indices = sparse_indices_values[example_id]->feature_indices;
+ const auto values = sparse_indices_values[example_id]->feature_values;
+ for (int64 dim = 0; dim < indices.dimension(0); ++dim) {
+ const int64 index = internal::SubtleMustCopy(indices(dim));
+ std::atomic<double>* const delta_weight = &delta_weights[index];
+ const double value = values(dim);
+ AtomicAdd(bounded_dual_delta * value / l2_regularization,
+ delta_weight);
+ }
+ }
+ }
+ }
+
+ // Adds all of the delta weights which were computed during processing
+ // of this mini-batch into the feature-weights. Must be called once
+ // at the end of mini-batch processing.
+ void AddDeltaWeights() { weights_and_deltas_.AddDeltaWeights(); }
+
+ size_t NumGroups() const { return examples_by_group_.size(); }
+
+ private:
+ // A feature group of a single example is represented by this struct.
+ struct PerExampleSparseIndicesValues {
+ // N X 1 vector with feature indices.
+ Eigen::Tensor</*const*/ int64, 1, Eigen::RowMajor> feature_indices;
+
+ // N X 1 vector with feature values.
+ TTypes</*const*/ float>::UnalignedVec feature_values;
+
+ // sum squared norm of the features.
+ double squared_norm;
+ };
+
+ Status FillExamples(OpKernelContext* const context,
+ const size_t num_sparse_features, const int num_examples,
+ const DeviceBase::CpuWorkerThreads& worker_threads);
+
+ // SparseExamples represent sparse feature groups of each example.
+ using SparseExamples =
+ std::vector<std::unique_ptr<const PerExampleSparseIndicesValues>>;
+
+ // SparseExamples associated with each sparse feature group.
+ std::vector<SparseExamples> examples_by_group_;
+
+ WeightsAndDeltas weights_and_deltas_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SparseFeaturesAndWeights);
+};
+
+// Goes through the entire training set once, in a parallel and partitioned
// fashion, so that we create per-example structures. A non-OK return status
-// indicates that the contents of sparse_examples_by_group cannot be trusted or
+// indicates that the contents of SparseFeaturesAndWeights cannot be trusted or
// used.
-Status FillSparseExamplesByGroup(
- const int64 num_sparse_features, const int num_examples,
- const OpInputList& sparse_features_indices_inputs,
- const OpInputList& sparse_features_values_inputs,
- const WeightsByGroup& sparse_weights_by_group,
- const DeviceBase::CpuWorkerThreads& worker_threads,
- SparseExamplesByGroup* const sparse_examples_by_group) {
+Status SparseFeaturesAndWeights::FillExamples(
+ OpKernelContext* const context, const size_t num_sparse_features,
+ const int num_examples,
+ const DeviceBase::CpuWorkerThreads& worker_threads) {
+ OpInputList sparse_features_indices_inputs;
+ TF_RETURN_IF_ERROR(context->input_list("sparse_features_indices",
+ &sparse_features_indices_inputs));
+ OpInputList sparse_features_values_inputs;
+ TF_RETURN_IF_ERROR(context->input_list("sparse_features_values",
+ &sparse_features_values_inputs));
+
if (sparse_features_indices_inputs.size() != num_sparse_features ||
sparse_features_values_inputs.size() != num_sparse_features ||
- sparse_weights_by_group.size() != num_sparse_features) {
+ weights_and_deltas_.NumGroups() != num_sparse_features) {
return errors::Internal("Unaligned sparse features.");
}
- sparse_examples_by_group->clear();
- sparse_examples_by_group->resize(num_sparse_features);
+ examples_by_group_.clear();
+ examples_by_group_.resize(num_sparse_features);
mutex mu;
Status result GUARDED_BY(mu);
@@ -142,7 +481,7 @@ Status FillSparseExamplesByGroup(
sparse::SparseTensor st(
sparse_features_indices_inputs[i], sparse_features_values_inputs[i],
sparse_features_indices_inputs[i].shape(), order);
- (*sparse_examples_by_group)[i] = SparseExamples(num_examples);
+ examples_by_group_[i] = SparseExamples(num_examples);
for (const auto& example_group : st.group({0})) {
const TTypes<int64>::UnalignedConstMatrix indices =
example_group.indices();
@@ -161,21 +500,23 @@ Status FillSparseExamplesByGroup(
const Eigen::Tensor<int64, 0, Eigen::RowMajor> max_feature_index =
feature_indices.maximum();
if (min_feature_index() < 0 ||
- max_feature_index() >= sparse_weights_by_group[i].size()) {
+ static_cast<size_t>(max_feature_index()) >=
+ weights_and_deltas_.NumFeaturesOfGroup(i)) {
mutex_lock l(mu);
result = errors::InvalidArgument(strings::Printf(
"Feature indices should be in [0, %ld). Encountered "
"min:%lld max:%lld for example:%lld",
- sparse_weights_by_group[i].size(), min_feature_index(),
+ weights_and_deltas_.NumFeaturesOfGroup(i), min_feature_index(),
max_feature_index(), example_index));
return;
}
- const Eigen::Tensor<float, 0, Eigen::RowMajor> norm =
+ const Eigen::Tensor<float, 0, Eigen::RowMajor> squared_norm =
example_group.values<float>().square().sum();
- (*sparse_examples_by_group)[i][example_index].reset(
- new PerExampleSparseIndicesWeights{
- feature_indices, example_group.values<float>(), norm()});
+ examples_by_group_[i][example_index].reset(
+ new PerExampleSparseIndicesValues{feature_indices,
+ example_group.values<float>(),
+ squared_norm()});
}
}
};
@@ -192,145 +533,75 @@ Status FillSparseExamplesByGroup(
return result;
}
-// Atomically add a double to a std::atomic<double>.
-inline void AtomicAdd(const double src, std::atomic<double>* const dst) {
- // We use a strong version of compare-exchange, as weak version can spuriously
- // fail.
- for (double c = dst->load(); !dst->compare_exchange_strong(c, c + src);) {
+// FeaturesAndWeights provides a unified view of training features and their
+// weights, abstracting away the differences between sparse and dense
+// feature representations.
+class FeaturesAndWeights {
+ public:
+ FeaturesAndWeights() {}
+
+ // Initialize() must be called immediately after construction.
+ Status Initialize(OpKernelContext* const context,
+ const int64 num_sparse_features, const int num_examples,
+ const DeviceBase::CpuWorkerThreads& worker_threads) {
+ TF_RETURN_IF_ERROR(sparse_features_and_weights_.Initialize(
+ context, num_sparse_features, num_examples, worker_threads));
+ TF_RETURN_IF_ERROR(dense_features_and_weights_.Initialize(context));
+ return Status::OK();
}
-}
-// Compute the shrinkage factor for proximal sdca.
-inline double ShrinkageFactor(const Regularizations& regularizations) {
- return regularizations.symmetric_l1 / regularizations.symmetric_l2;
-}
-
-// Proximal SDCA shrinking for L1 regularization.
-inline double Shrink(const double weight, const double shrink_by) {
- const double shrink_weight = std::max(std::abs(weight) - shrink_by, 0.0);
- if (shrink_weight > 0.0) {
- return std::copysign(shrink_weight, weight);
+ // Computes PerExampleData for 'example_id'.
+ PerExampleData ComputeWxAndWeightedExampleNorm(
+ const int64 example_id, const Regularizations& regularizations) const {
+ const PerExampleData sparse_data =
+ sparse_features_and_weights_.ComputeWxAndWeightedExampleNorm(
+ example_id, regularizations);
+ const PerExampleData dense_data =
+ dense_features_and_weights_.ComputeWxAndWeightedExampleNorm(
+ example_id, regularizations);
+
+ return AddPerExampleData(sparse_data, dense_data);
}
- return 0.0;
-}
-// Compute PerExampleData which contains the logits, and weighted example norm
-// for a given example_id. Norm is weighted by 1/(lambda*N).
-inline PerExampleData ComputeWxAndWeightedExampleNorm(
- const int64 example_id, //
- const WeightsByGroup& sparse_weights_by_group,
- const DeltaWeightsByGroup& sparse_delta_weights_by_group,
- const SparseExamplesByGroup& sparse_examples_by_group,
- const WeightsByGroup& dense_weights_by_group,
- const DeltaWeightsByGroup& dense_delta_weights_by_group,
- const DenseFeaturesByGroup& dense_features_by_group,
- const Regularizations& regularizations) {
- PerExampleData result;
- const double shrink_by = ShrinkageFactor(regularizations);
- for (size_t i = 0; i < sparse_examples_by_group.size(); ++i) {
- const SparseExamples& sparse_indices_values = sparse_examples_by_group[i];
- const TTypes<float>::Vec weights = sparse_weights_by_group[i];
- const std::vector<std::atomic<double>>& delta_weights =
- sparse_delta_weights_by_group[i];
- if (sparse_indices_values[example_id]) {
- const auto indices = sparse_indices_values[example_id]->feature_indices;
- const auto values = sparse_indices_values[example_id]->feature_values;
- for (int64 dim = 0; dim < indices.dimension(0); ++dim) {
- const int64 index = internal::SubtleMustCopy(indices(dim));
- const double weight = weights(index);
- const double value = values(dim);
- result.wx +=
- Shrink(weight + delta_weights[index].load(), shrink_by) * value;
- }
- result.norm += sparse_indices_values[example_id]->norm;
- }
+ // Updates the delta weight for each feature occurring in 'example_id',
+ // given the weighted change in the dual for this example
+ // (bounded_dual_delta), and the 'l2_regularization'.
+ void UpdateDeltaWeights(const int64 example_id,
+ const double bounded_dual_delta,
+ const double l2_regularization) {
+ sparse_features_and_weights_.UpdateDeltaWeights(
+ example_id, bounded_dual_delta, l2_regularization);
+ dense_features_and_weights_.UpdateDeltaWeights(
+ example_id, bounded_dual_delta, l2_regularization);
}
- for (size_t i = 0; i < dense_features_by_group.size(); ++i) {
- // (0) and [0] access guaranteed to be ok due to ValidateDenseWeights().
- const double weight = dense_weights_by_group[i](0);
- const double value = dense_features_by_group[i](example_id);
- result.wx +=
- Shrink(weight + dense_delta_weights_by_group[i][0].load(), shrink_by) *
- value;
- result.norm += value * value;
- }
- result.norm /= regularizations.symmetric_l2;
- return result;
-}
-// Add delta weights to original weights.
-void AddDeltaWeights(const DeltaWeightsByGroup& src,
- WeightsByGroup* const dst) {
- // TODO(sibyl-Aix6ihai): Parallelize this.
- for (size_t group = 0; group < src.size(); ++group) {
- for (size_t i = 0; i < src[group].size(); ++i) {
- (*dst)[group](i) += src[group][i].load();
- }
+ // Adds all of the delta weights which were computed during processing
+ // of this mini-batch into the feature-weights. Must be called once
+ // at the end of mini-batch processing.
+ void AddDeltaWeights() {
+ sparse_features_and_weights_.AddDeltaWeights();
+ dense_features_and_weights_.AddDeltaWeights();
}
-}
-void UpdateDeltaWeights(
- const int64 example_id,
- const SparseExamplesByGroup& sparse_examples_by_group,
- const DenseFeaturesByGroup& dense_features_by_group,
- const double bounded_dual_delta, const double l2_regularization,
- DeltaWeightsByGroup* const sparse_delta_weights_by_group,
- DeltaWeightsByGroup* const dense_delta_weights_by_group) {
- for (size_t i = 0; i < sparse_examples_by_group.size(); ++i) {
- const SparseExamples& sparse_examples = sparse_examples_by_group[i];
- std::vector<std::atomic<double>>& delta_weights =
- (*sparse_delta_weights_by_group)[i];
- if (sparse_examples[example_id]) {
- const auto indices = sparse_examples[example_id]->feature_indices;
- const auto values = sparse_examples[example_id]->feature_values;
- for (int64 dim = 0; dim < indices.dimension(0); ++dim) {
- AtomicAdd(bounded_dual_delta * values(dim) / l2_regularization,
- &delta_weights[indices(dim)]);
- }
- }
- }
- for (size_t i = 0; i < dense_features_by_group.size(); ++i) {
- const auto values = dense_features_by_group[i];
- std::vector<std::atomic<double>>& delta_weights =
- (*dense_delta_weights_by_group)[i];
- AtomicAdd(bounded_dual_delta * values(example_id) / l2_regularization,
- // [0] access guaranteed to be ok due to ValidateDenseWeights().
- &delta_weights[0]);
+ size_t NumGroups() const {
+ return sparse_features_and_weights_.NumGroups() +
+ dense_features_and_weights_.NumGroups();
}
-}
-WeightsByGroup MakeWeightsFrom(OpMutableInputList* const input_list) {
- WeightsByGroup result;
- for (int i = 0; i < input_list->size(); ++i) {
- result.emplace_back(input_list->at(i, /*lock_held=*/true).flat<float>());
- }
- return result;
-}
+ private:
+ SparseFeaturesAndWeights sparse_features_and_weights_;
+ DenseFeaturesAndWeights dense_features_and_weights_;
-DeltaWeightsByGroup MakeZeroDeltaWeightsLike(
- const WeightsByGroup& weights_by_group) {
- // TODO(sibyl-Mooth6ku): Maybe parallelize this.
- DeltaWeightsByGroup result;
- for (const TTypes<float>::Vec weights : weights_by_group) {
- result.emplace_back(weights.size());
- std::fill(result.back().begin(), result.back().end(), 0);
- }
- return result;
-}
+ TF_DISALLOW_COPY_AND_ASSIGN(FeaturesAndWeights);
+};
Status RunTrainStepsForMiniBatch(
const int num_examples, const TTypes<const string>::Vec example_ids,
const TTypes<const float>::Vec example_labels,
const TTypes<const float>::Vec example_weights,
const DeviceBase::CpuWorkerThreads& worker_threads,
- const Regularizations& regularizations,
- const WeightsByGroup& sparse_weights_by_group,
- const SparseExamplesByGroup& sparse_examples_by_group,
- const WeightsByGroup& dense_weights_by_group,
- const DenseFeaturesByGroup& dense_features_by_group,
- const DualLossUpdater& loss_updater,
- DeltaWeightsByGroup* const sparse_delta_weights_by_group,
- DeltaWeightsByGroup* const dense_delta_weights_by_group,
+ const Regularizations& regularizations, const DualLossUpdater& loss_updater,
+ FeaturesAndWeights* const features_and_weights,
DataByExample* const data_by_example) {
// Process examples in parallel, in a partitioned fashion.
mutex mu;
@@ -355,11 +626,9 @@ Status RunTrainStepsForMiniBatch(
// Compute wx, example norm weighted by regularization, dual loss,
// primal loss.
- const PerExampleData per_example_data = ComputeWxAndWeightedExampleNorm(
- example_index, sparse_weights_by_group,
- *sparse_delta_weights_by_group, sparse_examples_by_group,
- dense_weights_by_group, *dense_delta_weights_by_group,
- dense_features_by_group, regularizations);
+ const PerExampleData per_example_data =
+ features_and_weights->ComputeWxAndWeightedExampleNorm(
+ example_index, regularizations);
const double primal_loss = loss_updater.ComputePrimalLoss(
per_example_data.wx, example_label, example_weight);
@@ -369,14 +638,12 @@ Status RunTrainStepsForMiniBatch(
const double new_dual = loss_updater.ComputeUpdatedDual(
example_label, example_weight, data.dual, per_example_data.wx,
- per_example_data.norm, primal_loss, dual_loss);
+ per_example_data.normalized_squared_norm, primal_loss, dual_loss);
// Compute new weights.
const double bounded_dual_delta = (new_dual - data.dual) * example_weight;
- UpdateDeltaWeights(
- example_index, sparse_examples_by_group, dense_features_by_group,
- bounded_dual_delta, regularizations.symmetric_l2,
- sparse_delta_weights_by_group, dense_delta_weights_by_group);
+ features_and_weights->UpdateDeltaWeights(
+ example_index, bounded_dual_delta, regularizations.symmetric_l2());
// Update example data.
data.dual = new_dual;
@@ -388,33 +655,12 @@ Status RunTrainStepsForMiniBatch(
};
// TODO(sibyl-Aix6ihai): Current multiplier 100000 works well empirically
// but perhaps we can tune it better.
- const int64 kCostPerUnit = 100000 * (sparse_examples_by_group.size() +
- dense_features_by_group.size());
+ const int64 kCostPerUnit = 100000 * features_and_weights->NumGroups();
Shard(worker_threads.num_threads, worker_threads.workers, num_examples,
kCostPerUnit, train_step);
return train_step_status;
}
-Status FillRegularizations(OpKernelConstruction* const context,
- Regularizations* const regularizations) {
- TF_RETURN_IF_ERROR(context->GetAttr("l1", &regularizations->symmetric_l1));
- TF_RETURN_IF_ERROR(context->GetAttr("l2", &regularizations->symmetric_l2));
- return Status::OK();
-}
-
-// TODO(sibyl-Mooth6ku): Support arbitrary dimensional dense weights and remove this.
-Status ValidateDenseWeights(const WeightsByGroup& weights_by_group) {
- for (const TTypes<float>::Vec weights : weights_by_group) {
- if (weights.size() != 1) {
- return errors::InvalidArgument(strings::Printf(
- "Dense weight vectors should have exactly one entry. Found (%ld). "
- "This is probably due to a misconfiguration in the optimizer setup.",
- weights.size()));
- }
- }
- return Status::OK();
-}
-
} // namespace
class SdcaSolver : public OpKernel {
@@ -440,7 +686,7 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES(
context, num_sparse_features_ + num_dense_features_ > 0,
errors::InvalidArgument("Requires at least one feature to train."));
- OP_REQUIRES_OK(context, FillRegularizations(context, &regularizations_));
+ OP_REQUIRES_OK(context, regularizations_.Initialize(context));
OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
&num_inner_iterations_));
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
@@ -480,14 +726,6 @@ class SdcaSolver : public OpKernel {
example_weights.size(), std::numeric_limits<int>::max())));
const int num_examples = static_cast<int>(example_weights.size());
- OpInputList dense_features_inputs;
- OP_REQUIRES_OK(
- context, context->input_list("dense_features", &dense_features_inputs));
- DenseFeaturesByGroup dense_features_by_group;
- for (const auto& dense_feature : dense_features_inputs) {
- dense_features_by_group.emplace_back(dense_feature.vec<float>());
- }
-
const Tensor* example_labels_t;
OP_REQUIRES_OK(context,
context->input("example_labels", &example_labels_t));
@@ -511,39 +749,11 @@ class SdcaSolver : public OpKernel {
"of example weights (%d).",
example_ids.size(), num_examples)));
- OpMutableInputList sparse_weights_inputs;
- OP_REQUIRES_OK(context, context->mutable_input_list(
- "sparse_weights", &sparse_weights_inputs));
- WeightsByGroup sparse_weights_by_group =
- MakeWeightsFrom(&sparse_weights_inputs);
- DeltaWeightsByGroup sparse_delta_weights_by_group =
- MakeZeroDeltaWeightsLike(sparse_weights_by_group);
-
- OpMutableInputList dense_weights_inputs;
- OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights",
- &dense_weights_inputs));
- WeightsByGroup dense_weights_by_group =
- MakeWeightsFrom(&dense_weights_inputs);
- OP_REQUIRES_OK(context, ValidateDenseWeights(dense_weights_by_group));
- DeltaWeightsByGroup dense_delta_weights_by_group =
- MakeZeroDeltaWeightsLike(dense_weights_by_group);
-
- OpInputList sparse_features_indices_inputs;
+ FeaturesAndWeights features_and_weights;
OP_REQUIRES_OK(context,
- context->input_list("sparse_features_indices",
- &sparse_features_indices_inputs));
- OpInputList sparse_features_values_inputs;
- OP_REQUIRES_OK(context,
- context->input_list("sparse_features_values",
- &sparse_features_values_inputs));
- SparseExamplesByGroup sparse_examples_by_group;
- OP_REQUIRES_OK(
- context,
- FillSparseExamplesByGroup(
- num_sparse_features_, num_examples, sparse_features_indices_inputs,
- sparse_features_values_inputs, sparse_weights_by_group,
- *context->device()->tensorflow_cpu_worker_threads(),
- &sparse_examples_by_group));
+ features_and_weights.Initialize(
+ context, num_sparse_features_, num_examples,
+ *context->device()->tensorflow_cpu_worker_threads()));
for (int i = 0; i < num_inner_iterations_; ++i) {
OP_REQUIRES_OK(
@@ -551,14 +761,10 @@ class SdcaSolver : public OpKernel {
RunTrainStepsForMiniBatch(
num_examples, example_ids, example_labels, example_weights,
*context->device()->tensorflow_cpu_worker_threads(),
- regularizations_, sparse_weights_by_group,
- sparse_examples_by_group, dense_weights_by_group,
- dense_features_by_group, *loss_updater_,
- &sparse_delta_weights_by_group, &dense_delta_weights_by_group,
+ regularizations_, *loss_updater_, &features_and_weights,
data_by_example));
}
- AddDeltaWeights(sparse_delta_weights_by_group, &sparse_weights_by_group);
- AddDeltaWeights(dense_delta_weights_by_group, &dense_weights_by_group);
+ features_and_weights.AddDeltaWeights();
// TODO(sibyl-Mooth6ku): Use core::ScopedUnref once it's moved out of internal.
data_by_example->Unref();
@@ -581,33 +787,14 @@ REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);
class SdcaShrinkL1 : public OpKernel {
public:
explicit SdcaShrinkL1(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, FillRegularizations(context, &regularizations_));
+ OP_REQUIRES_OK(context, regularizations_.Initialize(context));
}
void Compute(OpKernelContext* context) override {
- OpMutableInputList sparse_weights_inputs;
- OP_REQUIRES_OK(context, context->mutable_input_list(
- "sparse_weights", &sparse_weights_inputs));
- WeightsByGroup sparse_weights_by_group =
- MakeWeightsFrom(&sparse_weights_inputs);
-
- OpMutableInputList dense_weights_inputs;
- OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights",
- &dense_weights_inputs));
- WeightsByGroup dense_weights_by_group =
- MakeWeightsFrom(&dense_weights_inputs);
- OP_REQUIRES_OK(context, ValidateDenseWeights(dense_weights_by_group));
-
- // TODO(sibyl-Aix6ihai): Parallelize this.
- const double shrink_by = ShrinkageFactor(regularizations_);
- for (TTypes<float>::Vec weights : sparse_weights_by_group) {
- for (int64 i = 0; i < weights.size(); ++i) {
- weights(i) = Shrink(weights(i), shrink_by);
- }
- }
- for (TTypes<float>::Vec weights : dense_weights_by_group) {
- // (0) access guaranteed to be ok due to ValidateDenseWeights().
- weights(0) = Shrink(weights(0), shrink_by);
+ for (const string& list_name : {"sparse_weights", "dense_weights"}) {
+ WeightsByGroup weights_by_group;
+ OP_REQUIRES_OK(context, weights_by_group.Initialize(context, list_name));
+ weights_by_group.Shrink(regularizations_);
}
}
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops_test.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops_test.cc
new file mode 100644
index 0000000000..d585553300
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops_test.cc
@@ -0,0 +1,215 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+namespace {
+
+SessionOptions* GetOptions() {
+ static SessionOptions* options = []() {
+ // We focus on the single thread performance of training ops.
+ SessionOptions* const result = new SessionOptions();
+ result->config.set_intra_op_parallelism_threads(1);
+ result->config.set_inter_op_parallelism_threads(1);
+ return result;
+ }();
+ return options;
+}
+
+Node* Var(Graph* const g, const int n) {
+ return test::graph::Var(g, DT_FLOAT, TensorShape({n}));
+}
+
+// Returns a vector of size 'nodes' with each node being of size 'node_size'.
+std::vector<Node*> VarVector(Graph* const g, const int nodes,
+ const int node_size) {
+ std::vector<Node*> result;
+ for (int i = 0; i < nodes; ++i) {
+ result.push_back(Var(g, node_size));
+ }
+ return result;
+}
+
+Node* Zeros(Graph* const g, const int n) {
+ Tensor data(DT_FLOAT, TensorShape({n}));
+ data.flat<float>().setZero();
+ return test::graph::Constant(g, data);
+}
+
+Node* Ones(Graph* const g, const int n) {
+ Tensor data(DT_FLOAT, TensorShape({n}));
+ test::FillFn<float>(&data, [](const int i) { return 1.0f; });
+ return test::graph::Constant(g, data);
+}
+
+Node* StringIota(Graph* const g, const int n) {
+ Tensor data(DT_STRING, TensorShape({n}));
+ test::FillFn<string>(
+ &data, [](const int i) { return strings::StrCat(strings::Hex(i)); });
+ return test::graph::Constant(g, data);
+}
+
+Node* SparseIndices(Graph* const g, const int sparse_features_per_group,
+ const int num_examples) {
+ const int x_size = num_examples * 4;
+ const int y_size = 2;
+ Tensor data(DT_INT64, TensorShape({x_size, y_size}));
+ test::FillFn<int64>(&data, [&](const int i) {
+ // Convert FillFn index 'i' to (x, y) for this tensor.
+ const int x = i / y_size;
+ const int y = i % y_size;
+ if (y == 0) {
+ // Populate example index with 4 features per example.
+ return x / 4;
+ } else {
+ // Assign feature indices sequentially - 0,1,2,3 for example 0,
+ // 4,5,6,7 for example 1,.... Wrap back around when we hit
+ // sparse_features_per_group.
+ return x % sparse_features_per_group;
+ }
+ });
+ return test::graph::Constant(g, data);
+}
+
+Node* RandomZeroOrOne(Graph* const g, const int n) {
+ Tensor data(DT_FLOAT, TensorShape({n}));
+ test::FillFn<float>(&data, [](const int i) {
+ // Fill with 0.0 or 1.0 at random.
+ return (random::New64() % 2) == 0 ? 0.0f : 1.0f;
+ });
+ return test::graph::Constant(g, data);
+}
+
+void GetGraphs(const int32 num_examples, const int32 sparse_feature_groups,
+ const int32 sparse_features_per_group,
+ const int32 dense_feature_groups, Graph** const init_g,
+ Graph** train_g) {
+ {
+ // Build initialization graph
+ Graph* g = new Graph(OpRegistry::Global());
+
+ // These nodes have to be created first, and in the same way as the
+ // nodes in the graph below.
+ std::vector<Node*> sparse_weight_nodes =
+ VarVector(g, sparse_feature_groups, sparse_features_per_group);
+ std::vector<Node*> dense_weight_nodes =
+ VarVector(g, dense_feature_groups, 1);
+ Node* const multi_zero = Zeros(g, sparse_features_per_group);
+ for (Node* n : sparse_weight_nodes) {
+ test::graph::Assign(g, n, multi_zero);
+ }
+ Node* const zero = Zeros(g, 1);
+ for (Node* n : dense_weight_nodes) {
+ test::graph::Assign(g, n, zero);
+ }
+
+ *init_g = g;
+ }
+
+ {
+ // Build execution graph
+ Graph* g = new Graph(OpRegistry::Global());
+
+ // These nodes have to be created first, and in the same way as the
+ // nodes in the graph above.
+ std::vector<Node*> sparse_weight_nodes =
+ VarVector(g, sparse_feature_groups, sparse_features_per_group);
+ std::vector<Node*> dense_weight_nodes =
+ VarVector(g, dense_feature_groups, 1);
+
+ std::vector<NodeBuilder::NodeOut> sparse_weights;
+ for (Node* n : sparse_weight_nodes) {
+ sparse_weights.push_back(NodeBuilder::NodeOut(n));
+ }
+ std::vector<NodeBuilder::NodeOut> dense_weights;
+ for (Node* n : dense_weight_nodes) {
+ dense_weights.push_back(NodeBuilder::NodeOut(n));
+ }
+
+ std::vector<NodeBuilder::NodeOut> sparse_indices;
+ std::vector<NodeBuilder::NodeOut> sparse_values;
+ for (int i = 0; i < sparse_feature_groups; ++i) {
+ sparse_indices.push_back(NodeBuilder::NodeOut(
+ SparseIndices(g, sparse_features_per_group, num_examples)));
+ }
+ for (int i = 0; i < sparse_feature_groups; ++i) {
+ sparse_values.push_back(
+ NodeBuilder::NodeOut(RandomZeroOrOne(g, num_examples * 4)));
+ }
+
+ std::vector<NodeBuilder::NodeOut> dense_features;
+ for (int i = 0; i < dense_feature_groups; ++i) {
+ dense_features.push_back(
+ NodeBuilder::NodeOut(RandomZeroOrOne(g, num_examples)));
+ }
+
+ Node* const weights = Ones(g, num_examples);
+ Node* const labels = RandomZeroOrOne(g, num_examples);
+ Node* const ids = StringIota(g, num_examples);
+
+ Node* sdca = nullptr;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("sdca"), "SdcaSolver")
+ .Attr("loss_type", "logistic_loss")
+ .Attr("num_sparse_features", sparse_feature_groups)
+ .Attr("num_dense_features", dense_feature_groups)
+ .Attr("l1", 0.0)
+ .Attr("l2", 1.0)
+ .Attr("num_inner_iterations", 2)
+ .Attr("container", strings::StrCat(strings::Hex(random::New64())))
+ .Attr("solver_uuid", strings::StrCat(strings::Hex(random::New64())))
+ .Input(sparse_indices)
+ .Input(sparse_values)
+ .Input(dense_features)
+ .Input(weights)
+ .Input(labels)
+ .Input(ids)
+ .Input(sparse_weights)
+ .Input(dense_weights)
+ .Finalize(g, &sdca));
+
+ *train_g = g;
+ }
+}
+
+void BM_SDCA(const int iters, const int num_examples) {
+ testing::StopTiming();
+ Graph* init = nullptr;
+ Graph* train = nullptr;
+ GetGraphs(num_examples, 20 /* sparse feature groups */,
+ 5 /* sparse features per group */, 20 /* dense feature groups */, &init,
+ &train);
+ testing::StartTiming();
+ test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+ // TODO(sibyl-toe9oF2e): Each call to Run() currently creates a container which
+ // gets deleted as the context gets deleted. It would be nicer to
+ // explicitly clean up the container ourselves at this point (after calling
+ // testing::StopTiming).
+}
+
+} // namespace
+
+BENCHMARK(BM_SDCA)->Arg(128)->Arg(256)->Arg(512)->Arg(1024);
+
+} // namespace tensorflow
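Assuming the TensorFlow benchmark conventions of this era (the exact flag is
an assumption, not shown in this commit), the BM_SDCA microbenchmark would
typically be run from an optimized build of the test target added in the BUILD
file above, along the lines of:

    bazel run -c opt //tensorflow/contrib/linear_optimizer/kernels:sdca_ops_test -- --benchmarks=BM_SDCA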