Diffstat (limited to 'tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc')
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc | 186
1 file changed, 66 insertions(+), 120 deletions(-)
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
index 1e9ca3d256..68146a3dff 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
@@ -73,11 +73,6 @@ struct Regularizations {
float symmetric_l2 = 0;
};
-struct RegularizationLoss {
- double l1_loss = 0;
- double l2_loss = 0;
-};
-
struct PerExampleData {
double wx = 0;
double norm = 0;
@@ -102,7 +97,7 @@ using DenseFeaturesByGroup = std::vector<TTypes<const float>::Vec>;
// indicates that the contents of sparse_examples_by_group cannot be trusted or
// used.
Status FillSparseExamplesByGroup(
- const int64 num_sparse_features, const int64 num_examples,
+ const int64 num_sparse_features, const int num_examples,
const OpInputList& sparse_features_indices_inputs,
const OpInputList& sparse_features_values_inputs,
const WeightsByGroup& sparse_weights_by_group,
@@ -127,7 +122,10 @@ Status FillSparseExamplesByGroup(
static const int64 kIndicesDims = 2;
gtl::InlinedVector<int64, 8> order(kIndicesDims);
std::iota(order.begin(), order.end(), 0);
- for (int64 i = begin; i < end; ++i) {
+
+ // The static_cast here is safe since begin and end can be at most
+ // num_examples, which is an int.
+ for (int i = static_cast<int>(begin); i < end; ++i) {
if (sparse_features_indices_inputs[i].shape().dims() != kIndicesDims) {
mutex_lock l(mu);
result = errors::InvalidArgument(strings::Printf(
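
The cast above is safe only because of a bound established earlier: this commit makes num_examples an int by checking the example count against std::numeric_limits<int>::max() before narrowing (see the SdcaSolver::Compute hunk below). A minimal self-contained sketch of that check-then-cast pattern, with hypothetical names:

    #include <cstdint>
    #include <limits>

    // Hypothetical helper mirroring the guard this commit adds before
    // deriving num_examples: narrow an int64 count to int only after
    // verifying that it fits.
    bool NarrowToInt(const int64_t count, int* const out) {
      if (count < 0 || count > std::numeric_limits<int>::max()) {
        return false;  // caller reports an InvalidArgument error
      }
      *out = static_cast<int>(count);
      return true;
    }
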
@@ -147,7 +145,7 @@ Status FillSparseExamplesByGroup(
if (example_index < 0 || example_index >= num_examples) {
mutex_lock l(mu);
result = errors::Internal(strings::Printf(
- "Example indices should be in [0, %lld). Encountered: %lld",
+ "Example indices should be in [0, %d). Encountered: %lld",
num_examples, example_index));
return;
}
@@ -203,35 +201,6 @@ inline double Shrink(const double weight, const double shrink_by) {
return 0.0;
}
-// Compute L1 and L2 regularization loss.
-inline RegularizationLoss ComputeRegularizationLoss(
- const WeightsByGroup& sparse_weights_by_group,
- const WeightsByGroup& dense_weights_by_group,
- const Regularizations& regularizations) {
- RegularizationLoss result;
-
- const double shrink_by = ShrinkageFactor(regularizations);
- auto accumulate_regularization_loss = [&](const double w) {
- const double sw = std::abs(Shrink(w, shrink_by));
- result.l1_loss += sw;
- result.l2_loss += sw * sw;
- };
-
- for (const TTypes<float>::Vec weights : sparse_weights_by_group) {
- for (int64 i = 0; i < weights.size(); ++i) {
- accumulate_regularization_loss(weights(i));
- }
- }
-
- for (const TTypes<float>::Vec weights : dense_weights_by_group) {
- accumulate_regularization_loss(weights(0));
- }
-
- result.l1_loss *= regularizations.symmetric_l1;
- result.l2_loss *= regularizations.symmetric_l2;
- return result;
-}
-
// Compute PerExampleData which contains the logits, and weighted example norm
// for a given example_id. Norm is weighted by 1/(lambda*N).
inline PerExampleData ComputeWxAndWeightedExampleNorm(
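
Here lambda is the L2 regularization strength and N the number of examples, so norm carries the 1/(lambda*N) factor used by the SDCA update. A minimal dense-only sketch of the quantities this function produces, with hypothetical names and assuming weights and features have equal length (the real kernel walks sparse and dense feature groups):

    #include <cstddef>
    #include <vector>

    struct PerExampleData {
      double wx = 0;    // logit: dot product of weights and features
      double norm = 0;  // squared feature norm, scaled by 1/(lambda * N)
    };

    // Hypothetical single-example sketch of the computation named above.
    PerExampleData ComputeForExample(const std::vector<double>& weights,
                                     const std::vector<double>& features,
                                     const double lambda,
                                     const int num_examples) {
      PerExampleData data;
      double squared_norm = 0;
      for (std::size_t i = 0; i < features.size(); ++i) {
        data.wx += weights[i] * features[i];
        squared_norm += features[i] * features[i];
      }
      data.norm = squared_norm / (lambda * num_examples);
      return data;
    }
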
@@ -380,7 +349,7 @@ WeightsByGroup MakeDeltaWeightsFrom(std::vector<Tensor>* const tensors) {
}
Status RunTrainStepsForMiniBatch(
- const int64 num_examples, const TTypes<const string>::Vec example_ids,
+ const int num_examples, const TTypes<const string>::Vec example_ids,
const TTypes<const float>::Vec example_labels,
const TTypes<const float>::Vec example_weights,
const DeviceBase::CpuWorkerThreads& worker_threads,
@@ -459,6 +428,13 @@ Status RunTrainStepsForMiniBatch(
return train_step_status;
}
+Status FillRegularizations(OpKernelConstruction* const context,
+ Regularizations* const regularizations) {
+ TF_RETURN_IF_ERROR(context->GetAttr("l1", &regularizations->symmetric_l1));
+ TF_RETURN_IF_ERROR(context->GetAttr("l2", &regularizations->symmetric_l2));
+ return Status::OK();
+}
+
} // namespace
class SdcaSolver : public OpKernel {
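
Note that the new helper only reads the two attributes: the minimal-l2 floor that the constructors below previously enforced inline (symmetric_l2 = max(symmetric_l2, 1.0f)) is not part of the shared path. A hypothetical caller that wanted to keep that floor could clamp after the fill; a sketch for a function returning Status (requires <algorithm>):

    // Hypothetical: reinstate the minimal-l2 floor that this commit
    // removes from the kernel constructors.
    Regularizations regularizations;
    TF_RETURN_IF_ERROR(FillRegularizations(context, &regularizations));
    regularizations.symmetric_l2 =
        std::max(regularizations.symmetric_l2, 1.0f);
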
@@ -484,25 +460,9 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES(
context, num_sparse_features_ + num_dense_features_ > 0,
errors::InvalidArgument("Requires at least one feature to train."));
-
- OP_REQUIRES_OK(context,
- context->GetAttr("l1", &regularizations_.symmetric_l1));
- OP_REQUIRES_OK(context,
- context->GetAttr("l2", &regularizations_.symmetric_l2));
- // We enforce a minimal l2, required by the algorithm.
- regularizations_.symmetric_l2 =
- std::max(regularizations_.symmetric_l2, 1.0f);
-
+ OP_REQUIRES_OK(context, FillRegularizations(context, &regularizations_));
OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
&num_inner_iterations_));
-
- // TODO(rohananil): Provide emperical evidence for this. It is better to run
- // more than one iteration on single mini-batch as we want to spend more
- // time in compute. SDCA works better with larger mini batches and there
- // is also recent work that shows its better to reuse old samples than train
- // on new samples. See: http://arxiv.org/abs/1602.02136.
- num_inner_iterations_ =
- std::max(num_inner_iterations_, static_cast<int64>(2));
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
}
@@ -533,21 +493,16 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsVector(example_weights_t->shape()),
errors::InvalidArgument("example_weights should be a vector."));
const auto example_weights = example_weights_t->vec<float>();
-
- Eigen::Tensor<float, 0, Eigen::RowMajor> example_weights_sum;
- example_weights_sum.device(context->eigen_cpu_device()) =
- example_weights.sum();
- const float weighted_examples = example_weights_sum();
- const int64 num_examples = example_weights.size();
-
- OP_REQUIRES(context, weighted_examples > 0,
- errors::InvalidArgument("No weighted examples in ",
- num_examples, " training examples"));
+ OP_REQUIRES(context,
+ example_weights.size() <= std::numeric_limits<int>::max(),
+ errors::InvalidArgument(strings::Printf(
+ "Too many examples in a mini-batch: %ld > %d",
+ example_weights.size(), std::numeric_limits<int>::max())));
+ const int num_examples = static_cast<int>(example_weights.size());
OpInputList dense_features_inputs;
OP_REQUIRES_OK(
context, context->input_list("dense_features", &dense_features_inputs));
-
DenseFeaturesByGroup dense_features_by_group;
for (const auto& dense_feature : dense_features_inputs) {
dense_features_by_group.emplace_back(dense_feature.vec<float>());
@@ -562,7 +517,7 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES(context, example_labels.size() == num_examples,
errors::InvalidArgument(strings::Printf(
"The number of example labels (%ld) should match the "
- "number of example weights (%lld).",
+ "number of example weights (%d).",
example_labels.size(), num_examples)));
const Tensor* example_ids_t;
@@ -573,7 +528,7 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES(context, example_ids.size() == num_examples,
errors::InvalidArgument(strings::Printf(
"The number of example ids (%ld) should match the number "
- "of example weights (%lld).",
+ "of example weights (%d).",
example_ids.size(), num_examples)));
const int64 num_duplicate_example_ids = [&] {
// TODO(katsiapis): Benchmark and/or optimize.
@@ -632,12 +587,7 @@ class SdcaSolver : public OpKernel {
SetZeroDeltaWeights(&sparse_delta_weights_by_group,
&dense_delta_weights_by_group);
- // TODO(rohananil): Provide emperical evidence for this. It is better to run
- // more than one iteration on single mini-batch as we want to spend more
- // time in compute. SDCA works better with larger mini batches and there
- // is also recent work that shows its better to reuse old samples than train
- // on new samples. See: http://arxiv.org/abs/1602.02136.
- for (int64 i = 0; i < num_inner_iterations_; ++i) {
+ for (int i = 0; i < num_inner_iterations_; ++i) {
OP_REQUIRES_OK(
context,
RunTrainStepsForMiniBatch(
@@ -669,7 +619,7 @@ class SdcaSolver : public OpKernel {
int64 num_sparse_features_;
int64 num_dense_features_;
Regularizations regularizations_;
- int64 num_inner_iterations_;
+ int num_inner_iterations_;
string container_;
string solver_uuid_;
};
@@ -678,13 +628,7 @@ REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);
class SdcaShrinkL1 : public OpKernel {
public:
explicit SdcaShrinkL1(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context,
- context->GetAttr("l1", &regularizations_.symmetric_l1));
- OP_REQUIRES_OK(context,
- context->GetAttr("l2", &regularizations_.symmetric_l2));
- // We enforce a minimal l2, required by the algorithm.
- regularizations_.symmetric_l2 =
- std::max(regularizations_.symmetric_l2, 1.0f);
+ OP_REQUIRES_OK(context, FillRegularizations(context, &regularizations_));
}
void Compute(OpKernelContext* context) override {
@@ -709,19 +653,10 @@ class SdcaShrinkL1 : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
-class ComputeDualityGap : public OpKernel {
+class SdcaTrainingStats : public OpKernel {
public:
- explicit ComputeDualityGap(OpKernelConstruction* context)
+ explicit SdcaTrainingStats(OpKernelConstruction* context)
: OpKernel(context) {
- // TODO(rohananil): Refactor grabbing common attributes across ops related
- // to sdca.
- OP_REQUIRES_OK(context,
- context->GetAttr("l1", &regularizations_.symmetric_l1));
- OP_REQUIRES_OK(context,
- context->GetAttr("l2", &regularizations_.symmetric_l2));
- // We enforce a minimal l2, required by the algorithm.
- regularizations_.symmetric_l2 =
- std::max(regularizations_.symmetric_l2, 1.0f);
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
}
@@ -734,45 +669,56 @@ class ComputeDualityGap : public OpKernel {
context, !data_by_example->RefCountIsOne(),
errors::Internal("Expected shared-ownership of data_by_example."));
- OpMutableInputList sparse_weights_inputs;
- OP_REQUIRES_OK(context, context->mutable_input_list(
- "sparse_weights", &sparse_weights_inputs));
- WeightsByGroup sparse_weights_by_group =
- MakeWeightsFrom(&sparse_weights_inputs);
-
- OpMutableInputList dense_weights_inputs;
- OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights",
- &dense_weights_inputs));
- WeightsByGroup dense_weights_by_group =
- MakeWeightsFrom(&dense_weights_inputs);
-
- double example_weight_sum = 0;
- double total_duality_gap = 0;
+ double total_primal_loss = 0;
+ double total_dual_loss = 0;
+ double total_example_weight = 0;
OP_REQUIRES_OK(context,
data_by_example->Visit([&](const DataByExample::Data& data) {
- example_weight_sum += data.example_weight;
- total_duality_gap += data.primal_loss + data.dual_loss;
+ total_primal_loss += data.primal_loss;
+ total_dual_loss += data.dual_loss;
+ total_example_weight += data.example_weight;
}));
- const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
- sparse_weights_by_group, dense_weights_by_group, regularizations_);
- total_duality_gap +=
- regularization_loss.l2_loss + regularization_loss.l1_loss;
+ // TODO(katsiapis): Think about the most arithmetically stable way of
+ // computing (dual + primal) loss (if it matters).
- Tensor* duality_gap_t = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output("duality_gap", {}, &duality_gap_t));
- duality_gap_t->scalar<float>()() = total_duality_gap / example_weight_sum;
+ {
+ Tensor* tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("primal_loss", {}, &tensor));
+ tensor->scalar<double>()() = total_primal_loss;
+ }
+
+ {
+ Tensor* tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("dual_loss", {}, &tensor));
+ tensor->scalar<double>()() = total_dual_loss;
+ }
+
+ {
+ OP_REQUIRES(
+ context, total_example_weight > 0,
+ errors::FailedPrecondition(
+ "No examples found or all examples have zero weight. Either the "
+ "optimizer was trained with no instances or perhaps there is a "
+ "bug in the training data."));
+
+ Tensor* tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("example_weights", {}, &tensor));
+ tensor->scalar<double>()() = total_example_weight;
+ }
// TODO(katsiapis): Use core::ScopedUnref once it's moved out of internal.
data_by_example->Unref();
}
private:
- Regularizations regularizations_;
string container_;
string solver_uuid_;
};
-REGISTER_KERNEL_BUILDER(Name("ComputeDualityGap").Device(DEVICE_CPU),
- ComputeDualityGap);
+REGISTER_KERNEL_BUILDER(Name("SdcaTrainingStats").Device(DEVICE_CPU),
+ SdcaTrainingStats);
+
} // namespace tensorflow
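
On the arithmetic-stability TODO in SdcaTrainingStats::Compute above: if the per-example loss accumulation in the Visit callback ever needs hardening, compensated (Kahan) summation is the usual remedy. A minimal sketch, not part of this commit:

    // Kahan summation: carry a compensation term so that many small
    // per-example losses are not absorbed by a large running sum.
    struct KahanSum {
      double sum = 0;
      double compensation = 0;
      void Add(const double value) {
        const double y = value - compensation;
        const double t = sum + y;
        compensation = (t - sum) - y;  // low-order bits lost in computing t
        sum = t;
      }
    };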