diff options
Diffstat (limited to 'tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc')
-rw-r--r-- | tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc | 186 |
1 files changed, 66 insertions, 120 deletions
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc index 1e9ca3d256..68146a3dff 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc +++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc @@ -73,11 +73,6 @@ struct Regularizations { float symmetric_l2 = 0; }; -struct RegularizationLoss { - double l1_loss = 0; - double l2_loss = 0; -}; - struct PerExampleData { double wx = 0; double norm = 0; @@ -102,7 +97,7 @@ using DenseFeaturesByGroup = std::vector<TTypes<const float>::Vec>; // indicates that the contents of sparse_examples_by_group cannot be trusted or // used. Status FillSparseExamplesByGroup( - const int64 num_sparse_features, const int64 num_examples, + const int64 num_sparse_features, const int num_examples, const OpInputList& sparse_features_indices_inputs, const OpInputList& sparse_features_values_inputs, const WeightsByGroup& sparse_weights_by_group, @@ -127,7 +122,10 @@ Status FillSparseExamplesByGroup( static const int64 kIndicesDims = 2; gtl::InlinedVector<int64, 8> order(kIndicesDims); std::iota(order.begin(), order.end(), 0); - for (int64 i = begin; i < end; ++i) { + + // The static_cast here is safe since begin and end can be at most + // num_examples which is an int. + for (int i = static_cast<int>(begin); i < end; ++i) { if (sparse_features_indices_inputs[i].shape().dims() != kIndicesDims) { mutex_lock l(mu); result = errors::InvalidArgument(strings::Printf( @@ -147,7 +145,7 @@ Status FillSparseExamplesByGroup( if (example_index < 0 || example_index >= num_examples) { mutex_lock l(mu); result = errors::Internal(strings::Printf( - "Example indices should be in [0, %lld). Encountered: %lld", + "Example indices should be in [0, %d). Encountered: %lld", num_examples, example_index)); return; } @@ -203,35 +201,6 @@ inline double Shrink(const double weight, const double shrink_by) { return 0.0; } -// Compute L1 and L2 regularization loss. -inline RegularizationLoss ComputeRegularizationLoss( - const WeightsByGroup& sparse_weights_by_group, - const WeightsByGroup& dense_weights_by_group, - const Regularizations& regularizations) { - RegularizationLoss result; - - const double shrink_by = ShrinkageFactor(regularizations); - auto accumulate_regularization_loss = [&](const double w) { - const double sw = std::abs(Shrink(w, shrink_by)); - result.l1_loss += sw; - result.l2_loss += sw * sw; - }; - - for (const TTypes<float>::Vec weights : sparse_weights_by_group) { - for (int64 i = 0; i < weights.size(); ++i) { - accumulate_regularization_loss(weights(i)); - } - } - - for (const TTypes<float>::Vec weights : dense_weights_by_group) { - accumulate_regularization_loss(weights(0)); - } - - result.l1_loss *= regularizations.symmetric_l1; - result.l2_loss *= regularizations.symmetric_l2; - return result; -} - // Compute PerExampleData which contains the logits, and weighted example norm // for a given example_id. Norm is weighted by 1/(lambda*N). inline PerExampleData ComputeWxAndWeightedExampleNorm( @@ -380,7 +349,7 @@ WeightsByGroup MakeDeltaWeightsFrom(std::vector<Tensor>* const tensors) { } Status RunTrainStepsForMiniBatch( - const int64 num_examples, const TTypes<const string>::Vec example_ids, + const int num_examples, const TTypes<const string>::Vec example_ids, const TTypes<const float>::Vec example_labels, const TTypes<const float>::Vec example_weights, const DeviceBase::CpuWorkerThreads& worker_threads, @@ -459,6 +428,13 @@ Status RunTrainStepsForMiniBatch( return train_step_status; } +Status FillRegularizations(OpKernelConstruction* const context, + Regularizations* const regularizations) { + TF_RETURN_IF_ERROR(context->GetAttr("l1", ®ularizations->symmetric_l1)); + TF_RETURN_IF_ERROR(context->GetAttr("l2", ®ularizations->symmetric_l2)); + return Status::OK(); +} + } // namespace class SdcaSolver : public OpKernel { @@ -484,25 +460,9 @@ class SdcaSolver : public OpKernel { OP_REQUIRES( context, num_sparse_features_ + num_dense_features_ > 0, errors::InvalidArgument("Requires at least one feature to train.")); - - OP_REQUIRES_OK(context, - context->GetAttr("l1", ®ularizations_.symmetric_l1)); - OP_REQUIRES_OK(context, - context->GetAttr("l2", ®ularizations_.symmetric_l2)); - // We enforce a minimal l2, required by the algorithm. - regularizations_.symmetric_l2 = - std::max(regularizations_.symmetric_l2, 1.0f); - + OP_REQUIRES_OK(context, FillRegularizations(context, ®ularizations_)); OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations", &num_inner_iterations_)); - - // TODO(rohananil): Provide emperical evidence for this. It is better to run - // more than one iteration on single mini-batch as we want to spend more - // time in compute. SDCA works better with larger mini batches and there - // is also recent work that shows its better to reuse old samples than train - // on new samples. See: http://arxiv.org/abs/1602.02136. - num_inner_iterations_ = - std::max(num_inner_iterations_, static_cast<int64>(2)); OP_REQUIRES_OK(context, context->GetAttr("container", &container_)); OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_)); } @@ -533,21 +493,16 @@ class SdcaSolver : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsVector(example_weights_t->shape()), errors::InvalidArgument("example_weights should be a vector.")); const auto example_weights = example_weights_t->vec<float>(); - - Eigen::Tensor<float, 0, Eigen::RowMajor> example_weights_sum; - example_weights_sum.device(context->eigen_cpu_device()) = - example_weights.sum(); - const float weighted_examples = example_weights_sum(); - const int64 num_examples = example_weights.size(); - - OP_REQUIRES(context, weighted_examples > 0, - errors::InvalidArgument("No weighted examples in ", - num_examples, " training examples")); + OP_REQUIRES(context, + example_weights.size() <= std::numeric_limits<int>::max(), + errors::InvalidArgument(strings::Printf( + "Too many examples in a mini-batch: %ld > %d", + example_weights.size(), std::numeric_limits<int>::max()))); + const int num_examples = static_cast<int>(example_weights.size()); OpInputList dense_features_inputs; OP_REQUIRES_OK( context, context->input_list("dense_features", &dense_features_inputs)); - DenseFeaturesByGroup dense_features_by_group; for (const auto& dense_feature : dense_features_inputs) { dense_features_by_group.emplace_back(dense_feature.vec<float>()); @@ -562,7 +517,7 @@ class SdcaSolver : public OpKernel { OP_REQUIRES(context, example_labels.size() == num_examples, errors::InvalidArgument(strings::Printf( "The number of example labels (%ld) should match the " - "number of example weights (%lld).", + "number of example weights (%d).", example_labels.size(), num_examples))); const Tensor* example_ids_t; @@ -573,7 +528,7 @@ class SdcaSolver : public OpKernel { OP_REQUIRES(context, example_labels.size() == num_examples, errors::InvalidArgument(strings::Printf( "The number of example ids (%ld) should match the number " - "of example weights (%lld).", + "of example weights (%d).", example_ids.size(), num_examples))); const int64 num_duplicate_example_ids = [&] { // TODO(katsiapis): Benchmark and/or optimize. @@ -632,12 +587,7 @@ class SdcaSolver : public OpKernel { SetZeroDeltaWeights(&sparse_delta_weights_by_group, &dense_delta_weights_by_group); - // TODO(rohananil): Provide emperical evidence for this. It is better to run - // more than one iteration on single mini-batch as we want to spend more - // time in compute. SDCA works better with larger mini batches and there - // is also recent work that shows its better to reuse old samples than train - // on new samples. See: http://arxiv.org/abs/1602.02136. - for (int64 i = 0; i < num_inner_iterations_; ++i) { + for (int i = 0; i < num_inner_iterations_; ++i) { OP_REQUIRES_OK( context, RunTrainStepsForMiniBatch( @@ -669,7 +619,7 @@ class SdcaSolver : public OpKernel { int64 num_sparse_features_; int64 num_dense_features_; Regularizations regularizations_; - int64 num_inner_iterations_; + int num_inner_iterations_; string container_; string solver_uuid_; }; @@ -678,13 +628,7 @@ REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver); class SdcaShrinkL1 : public OpKernel { public: explicit SdcaShrinkL1(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, - context->GetAttr("l1", ®ularizations_.symmetric_l1)); - OP_REQUIRES_OK(context, - context->GetAttr("l2", ®ularizations_.symmetric_l2)); - // We enforce a minimal l2, required by the algorithm. - regularizations_.symmetric_l2 = - std::max(regularizations_.symmetric_l2, 1.0f); + OP_REQUIRES_OK(context, FillRegularizations(context, ®ularizations_)); } void Compute(OpKernelContext* context) override { @@ -709,19 +653,10 @@ class SdcaShrinkL1 : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1); -class ComputeDualityGap : public OpKernel { +class SdcaTrainingStats : public OpKernel { public: - explicit ComputeDualityGap(OpKernelConstruction* context) + explicit SdcaTrainingStats(OpKernelConstruction* context) : OpKernel(context) { - // TODO(rohananil): Refactor grabbing common attributes across ops related - // to sdca. - OP_REQUIRES_OK(context, - context->GetAttr("l1", ®ularizations_.symmetric_l1)); - OP_REQUIRES_OK(context, - context->GetAttr("l2", ®ularizations_.symmetric_l2)); - // We enforce a minimal l2, required by the algorithm. - regularizations_.symmetric_l2 = - std::max(regularizations_.symmetric_l2, 1.0f); OP_REQUIRES_OK(context, context->GetAttr("container", &container_)); OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_)); } @@ -734,45 +669,56 @@ class ComputeDualityGap : public OpKernel { context, !data_by_example->RefCountIsOne(), errors::Internal("Expected shared-ownership of data_by_example.")); - OpMutableInputList sparse_weights_inputs; - OP_REQUIRES_OK(context, context->mutable_input_list( - "sparse_weights", &sparse_weights_inputs)); - WeightsByGroup sparse_weights_by_group = - MakeWeightsFrom(&sparse_weights_inputs); - - OpMutableInputList dense_weights_inputs; - OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights", - &dense_weights_inputs)); - WeightsByGroup dense_weights_by_group = - MakeWeightsFrom(&dense_weights_inputs); - - double example_weight_sum = 0; - double total_duality_gap = 0; + double total_primal_loss = 0; + double total_dual_loss = 0; + double total_example_weight = 0; OP_REQUIRES_OK(context, data_by_example->Visit([&](const DataByExample::Data& data) { - example_weight_sum += data.example_weight; - total_duality_gap += data.primal_loss + data.dual_loss; + total_primal_loss += data.primal_loss; + total_dual_loss += data.dual_loss; + total_example_weight += data.example_weight; })); - const RegularizationLoss regularization_loss = ComputeRegularizationLoss( - sparse_weights_by_group, dense_weights_by_group, regularizations_); - total_duality_gap += - regularization_loss.l2_loss + regularization_loss.l1_loss; + // TODO(katsiapis): Think about the most arithmetically stable way of + // computing (dual + primal) loss (if it matters). - Tensor* duality_gap_t = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output("duality_gap", {}, &duality_gap_t)); - duality_gap_t->scalar<float>()() = total_duality_gap / example_weight_sum; + { + Tensor* tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("primal_loss", {}, &tensor)); + tensor->scalar<double>()() = total_primal_loss; + } + + { + Tensor* tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("dual_loss", {}, &tensor)); + tensor->scalar<double>()() = total_dual_loss; + } + + { + OP_REQUIRES( + context, total_example_weight > 0, + errors::FailedPrecondition( + "No examples found or all examples have zero weight. Either the " + "optimizer was trained with no instances or perhaps there is a " + "bug in the training data.")); + + Tensor* tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("example_weights", {}, &tensor)); + tensor->scalar<double>()() = total_example_weight; + } // TODO(katsiapis): Use core::ScopedUnref once it's moved out of internal. data_by_example->Unref(); } private: - Regularizations regularizations_; string container_; string solver_uuid_; }; -REGISTER_KERNEL_BUILDER(Name("ComputeDualityGap").Device(DEVICE_CPU), - ComputeDualityGap); +REGISTER_KERNEL_BUILDER(Name("SdcaTrainingStats").Device(DEVICE_CPU), + SdcaTrainingStats); + } // namespace tensorflow |