path: root/tensorflow/contrib/factorization
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-26 15:42:00 -0700
committer Gunhan Gulsoy <gunan@google.com>  2018-06-28 21:37:43 -0700
commit    4bac5adc9c19d5f658491bfb970db9313f38a995 (patch)
tree      a328f7e7175d6ba6984486d0d8635bb1a63cea42 /tensorflow/contrib/factorization
parent    24c1634196129d60568044437b6db225f6f7d721 (diff)
Adding per-element weight support for the WALSComputePartialLhsAndRhsOp operator.
PiperOrigin-RevId: 202208129
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--  tensorflow/contrib/factorization/kernels/wals_solver_ops.cc  44
-rw-r--r--  tensorflow/contrib/factorization/ops/factorization_ops.cc  19
-rw-r--r--  tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py  36
-rw-r--r--  tensorflow/contrib/factorization/python/ops/factorization_ops.py  1
4 files changed, 89 insertions, 11 deletions
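
Before the per-file diffs, a minimal sketch (not part of the commit; tensor names are placeholders) of what the change means at a call site of the generated Python wrapper: entry_weights becomes a new input between input_values and input_block_size, and callers that keep the old row/column weighting simply pass an empty list for it, exactly as the updated test and factorization_ops.py below do.

# Illustrative sketch only; factors, col_weights, row_weights, sp_input,
# unobserved_weight and num_rows are assumed to exist in the caller.
lhs, rhs = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
    factors, col_weights, unobserved_weight, row_weights,
    sp_input.indices, sp_input.values,
    [],  # new entry_weights input, left empty to keep the old behavior
    input_block_size=num_rows,
    input_is_transpose=False)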
diff --git a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
index bb9b835889..7fcae5ad8e 100644
--- a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
+++ b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
@@ -62,10 +62,11 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
public:
explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
: OpKernel(context) {
- OP_REQUIRES_OK(context, context->MatchSignature(
- {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
- DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL},
- {DT_FLOAT, DT_FLOAT}));
+ OP_REQUIRES_OK(context,
+ context->MatchSignature(
+ {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64,
+ DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL},
+ {DT_FLOAT, DT_FLOAT}));
}
void Compute(OpKernelContext* context) override {
@@ -75,8 +76,9 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
const Tensor& input_weights = context->input(3);
const Tensor& input_indices = context->input(4);
const Tensor& input_values = context->input(5);
- const Tensor& input_block_size = context->input(6);
- const Tensor& input_is_transpose = context->input(7);
+ const Tensor& entry_weights = context->input(6);
+ const Tensor& input_block_size = context->input(7);
+ const Tensor& input_is_transpose = context->input(8);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
InvalidArgument("Input factors should be a matrix."));
@@ -89,13 +91,33 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
InvalidArgument("Input input_weights should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
InvalidArgument("Input input_indices should be a matrix."));
+ OP_REQUIRES(
+ context, input_indices.dim_size(1) == 2,
+ InvalidArgument("Input input_indices should have shape (?, 2)."));
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
InvalidArgument("Input input_values should be a vector"));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(entry_weights.shape()),
+ InvalidArgument("Input entry_weights should be a vector"));
+ OP_REQUIRES(context, input_indices.dim_size(0) == input_values.dim_size(0),
+ InvalidArgument("Input input_values' length should match the "
+ "first dimension of Input input_indices "));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
InvalidArgument("Input input_block_size should be a scalar."));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
InvalidArgument("Input input_is_transpose should be a scalar."));
+ OP_REQUIRES(
+ context,
+ ((input_weights.dim_size(0) > 0 &&
+ factor_weights.dim_size(0) == factors.dim_size(0) &&
+ entry_weights.dim_size(0) == 0) ||
+ (input_weights.dim_size(0) == 0 && factor_weights.dim_size(0) == 0 &&
+ entry_weights.dim_size(0) == input_indices.dim_size(0))),
+ InvalidArgument("To specify the weights for observed entries, either "
+ "(1) entry_weights must be set or (2) input_weights "
+ "and factor_weights must be set, but not both."));
+ // TODO(yifanchen): Deprecate the support of input_weights and
+ // factor_weights.
const int64 factor_dim = factors.dim_size(1);
const int64 factors_size = factors.dim_size(0);
@@ -105,6 +127,7 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
const auto& input_weights_vec = input_weights.vec<float>();
const float w_0 = unobserved_weights.scalar<float>()();
const auto& input_values_vec = input_values.vec<float>();
+ const auto& entry_weights_vec = entry_weights.vec<float>();
ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
factor_dim, factors_size);
@@ -134,6 +157,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
};
+ const bool use_entry_weights = entry_weights_vec.size() > 0;
+
// TODO(rmlarsen): In principle, we should be using the SparseTensor class
// and machinery for iterating over groups, but the fact that class
// SparseTensor makes a complete copy of the matrix makes me reluctant to
@@ -195,6 +220,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
// map using the hash of the thread id as the key.
//
// TODO(jpoulson): Switch to try_emplace once C++17 is supported
+ // TODO(b/72952120): Check whether the 3 lock-unlock pairs can be
+ // consolidated into just one.
map_mutex.lock();
const auto key_count = factor_batch_map.count(id_hash);
map_mutex.unlock();
@@ -213,6 +240,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
CHECK_LE(shard.second, perm.size());
CHECK_LE(shard.first, shard.second);
const int64 input_index = get_input_index(perm[shard.first]);
+ const float input_weight =
+ use_entry_weights ? 1.0 : input_weights_vec(input_index);
// Accumulate the rhs and lhs terms in the normal equations
// for the non-zero elements in the row or column of the sparse matrix
// corresponding to input_index.
@@ -228,7 +257,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
const int64 factor_index = get_factor_index(i);
const float input_value = input_values_vec(i);
const float weight =
- input_weights_vec(input_index) * factor_weights_vec(factor_index);
+ use_entry_weights ? entry_weights_vec(i)
+ : input_weight * factor_weights_vec(factor_index);
CHECK_GE(weight, 0);
factor_batch.col(num_batched) =
factors_mat.col(factor_index) * std::sqrt(weight);
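
Summarizing the kernel change above, a rough Python rendering of the per-entry weight selection (a sketch of the logic only, not code from the commit; names follow the C++ locals):

# For each observed entry i, input_index is its row (or column, when
# transposed) and factor_index the opposite axis.
use_entry_weights = len(entry_weights) > 0

def weight_for_entry(i, input_index, factor_index):
    if use_entry_weights:
        # New path: one non-negative weight per observed entry.
        return entry_weights[i]
    # Legacy path: product of the row weight and the column weight.
    return input_weights[input_index] * factor_weights[factor_index]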
diff --git a/tensorflow/contrib/factorization/ops/factorization_ops.cc b/tensorflow/contrib/factorization/ops/factorization_ops.cc
index 11ea36946e..1d31bd38c8 100644
--- a/tensorflow/contrib/factorization/ops/factorization_ops.cc
+++ b/tensorflow/contrib/factorization/ops/factorization_ops.cc
@@ -25,20 +25,33 @@ REGISTER_OP("WALSComputePartialLhsAndRhs")
.Input("input_weights: float32")
.Input("input_indices: int64")
.Input("input_values: float32")
+ .Input("entry_weights: float32")
.Input("input_block_size: int64")
.Input("input_is_transpose: bool")
.Output("partial_lhs: float32")
.Output("partial_rhs: float32")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"(
-Computes the partial left-hand side and right-hand side of WALS update.
+Computes the partial left-hand side and right-hand side of WALS update. For
+observed entry input_indices[i]=[m, n] with value input_values[i]=v, the weight
+should be specified either through (1) entry_weights[i] or (2) through
+input_weights[m] * factor_weights[n] (if input_is_transpose is false) or
+input_weights[n] * factor_weights[m] (if input_is_transpose is true). Note it is
+not allowed to have both (1) and (2) specified at the same time: when one
+approach is used, the input tensors related to the other approach must be kept
+completely empty.
factors: Matrix of size m * k.
-factor_weights: Vector of size m. Corresponds to column weights
+factor_weights: Vector of size m. Corresponds to column weights. Should be empty
+ if entry_weights is used.
unobserved_weights: Scalar. Weight for unobserved input entries.
-input_weights: Vector of size n. Corresponds to row weights.
+input_weights: Vector of size n. Corresponds to row weights. Should be empty if
+ entry_weights is used.
input_indices: Indices for the input SparseTensor.
input_values: Values for the input SparseTensor.
+entry_weights: If not empty, this must be the same length as input_values and is
+  used as the per-entry non-zero weight. If this is used, input_weights and
+  factor_weights must be empty.
input_block_size: Scalar. Number of rows spanned by input.
input_is_transpose: If true, logically transposes the input for processing.
partial_lhs: 3-D tensor with size input_block_size x k x k.
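
The doc above describes two mutually exclusive weighting modes; the shapes the kernel enforces can be sketched as follows (hedged illustration, with nnz denoting the number of observed entries):

# Sketch of the validation added in wals_solver_ops.cc above.
nnz = input_indices.shape[0]
per_entry_mode = (entry_weights.shape[0] == nnz and
                  input_weights.shape[0] == 0 and
                  factor_weights.shape[0] == 0)
row_col_mode = (entry_weights.shape[0] == 0 and
                input_weights.shape[0] > 0 and
                factor_weights.shape[0] == factors.shape[0])
assert per_entry_mode or row_col_mode, (
    "Specify weights via entry_weights or via input_weights/factor_weights, "
    "but not both.")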
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
index ba30fd9977..6c2f1d4608 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -55,7 +55,41 @@ class WalsSolverOpsTest(test.TestCase):
rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
self._column_factors, self._column_weights, self._unobserved_weights,
self._row_weights, sparse_block.indices, sparse_block.values,
- sparse_block.dense_shape[0], False)
+ [],
+ input_block_size=sparse_block.dense_shape[0],
+ input_is_transpose=False)
+ self.assertAllClose(lhs_tensor.eval(), [[
+ [0.014800, 0.017000, 0.019200],
+ [0.017000, 0.019600, 0.022200],
+ [0.019200, 0.022200, 0.025200],
+ ], [
+ [0.0064000, 0.0080000, 0.0096000],
+ [0.0080000, 0.0100000, 0.0120000],
+ [0.0096000, 0.0120000, 0.0144000],
+ ], [
+ [0.0099000, 0.0126000, 0.0153000],
+ [0.0126000, 0.0162000, 0.0198000],
+ [0.0153000, 0.0198000, 0.0243000],
+ ], [
+ [0.058800, 0.067200, 0.075600],
+ [0.067200, 0.076800, 0.086400],
+ [0.075600, 0.086400, 0.097200],
+ ]])
+ self.assertAllClose(rhs_matrix.eval(), [[0.019300, 0.023000, 0.026700],
+ [0.061600, 0.077000, 0.092400],
+ [0.160400, 0.220000, 0.279600],
+ [0.492800, 0.563200, 0.633600]])
+
+ def testWalsSolverLhsEntryWeights(self):
+ sparse_block = SparseBlock3x3()
+ with self.test_session():
+ [lhs_tensor,
+ rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
+ self._column_factors, [], self._unobserved_weights,
+ [], sparse_block.indices, sparse_block.values,
+ [0.01, 0.03, 0.04, 0.03, 0.06, 0.12],
+ input_block_size=sparse_block.dense_shape[0],
+ input_is_transpose=False)
self.assertAllClose(lhs_tensor.eval(), [[
[0.014800, 0.017000, 0.019200],
[0.017000, 0.019600, 0.022200],
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 8f73274c2a..7ab70fbcfd 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -943,6 +943,7 @@ class WALSModel(object):
row_weights_slice,
new_sp_input.indices,
new_sp_input.values,
+ [],
num_rows,
transpose_input,
name="wals_compute_partial_lhs_rhs"))