author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-31 12:16:54 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-05-31 12:19:57 -0700
commit     519189837b77181137505bf83054ddd962600f9b (patch)
tree       fc1b56cdca999b1b4da369d73742fcb4ace272ba /tensorflow/contrib/factorization
parent     fdf4d0813d4c0321be7b33698d00b165d90365b0 (diff)
Making the tf.name_scope blocks related to the factor and weight vars configurable. By default they will not be scoped.
PiperOrigin-RevId: 198759754
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--  tensorflow/contrib/factorization/python/ops/factorization_ops.py  129
1 file changed, 74 insertions(+), 55 deletions(-)
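The caller-facing change is just the new constructor flag. As a minimal usage sketch (the argument values are illustrative, and the import path follows this file's location in contrib):

    from tensorflow.contrib.factorization.python.ops import factorization_ops

    # Default (use_scoped_vars=False): factor/weight variables are created
    # unscoped, e.g. "row_factors_shard_0".
    model = factorization_ops.WALSModel(
        input_rows=100, input_cols=50, n_components=10)

    # Opt-in: the same variables are nested in a tf.name_scope,
    # e.g. "row_factors/row_factors_shard_0".
    scoped_model = factorization_ops.WALSModel(
        input_rows=100, input_cols=50, n_components=10,
        use_scoped_vars=True)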
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 09745e2de5..8f73274c2a 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -197,7 +197,8 @@ class WALSModel(object):
               row_weights=1,
               col_weights=1,
               use_factors_weights_cache=True,
-               use_gramian_cache=True):
+               use_gramian_cache=True,
+               use_scoped_vars=False):
    """Creates model for WALS matrix factorization.

    Args:
@@ -239,6 +240,8 @@ class WALSModel(object):
        weights cache to take effect.
      use_gramian_cache: When True, the Gramians will be cached on the workers
        before the updates start. Defaults to True.
+      use_scoped_vars: When True, the factor and weight vars will also be nested
+        in a tf.name_scope.
    """
    self._input_rows = input_rows
    self._input_cols = input_cols
@@ -251,18 +254,36 @@ class WALSModel(object):
        regularization * linalg_ops.eye(self._n_components)
        if regularization is not None else None)
    assert (row_weights is None) == (col_weights is None)
-    self._row_weights = WALSModel._create_weights(
-        row_weights, self._input_rows, self._num_row_shards, "row_weights")
-    self._col_weights = WALSModel._create_weights(
-        col_weights, self._input_cols, self._num_col_shards, "col_weights")
    self._use_factors_weights_cache = use_factors_weights_cache
    self._use_gramian_cache = use_gramian_cache
-    self._row_factors = self._create_factors(
-        self._input_rows, self._n_components, self._num_row_shards, row_init,
-        "row_factors")
-    self._col_factors = self._create_factors(
-        self._input_cols, self._n_components, self._num_col_shards, col_init,
-        "col_factors")
+
+    if use_scoped_vars:
+      with ops.name_scope("row_weights"):
+        self._row_weights = WALSModel._create_weights(
+            row_weights, self._input_rows, self._num_row_shards, "row_weights")
+      with ops.name_scope("col_weights"):
+        self._col_weights = WALSModel._create_weights(
+            col_weights, self._input_cols, self._num_col_shards, "col_weights")
+      with ops.name_scope("row_factors"):
+        self._row_factors = self._create_factors(
+            self._input_rows, self._n_components, self._num_row_shards,
+            row_init, "row_factors")
+      with ops.name_scope("col_factors"):
+        self._col_factors = self._create_factors(
+            self._input_cols, self._n_components, self._num_col_shards,
+            col_init, "col_factors")
+    else:
+      self._row_weights = WALSModel._create_weights(
+          row_weights, self._input_rows, self._num_row_shards, "row_weights")
+      self._col_weights = WALSModel._create_weights(
+          col_weights, self._input_cols, self._num_col_shards, "col_weights")
+      self._row_factors = self._create_factors(
+          self._input_rows, self._n_components, self._num_row_shards, row_init,
+          "row_factors")
+      self._col_factors = self._create_factors(
+          self._input_cols, self._n_components, self._num_col_shards, col_init,
+          "col_factors")
+
    self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
    self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
    with ops.name_scope("row_prepare_gramian"):
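The behavior the new branch relies on can be checked in isolation: in TF1 graph mode, tf.Variable (unlike tf.get_variable) picks up the enclosing tf.name_scope as a prefix on the variable's op name. A minimal sketch, assuming the public TF1 API rather than the internal ops/variable_scope module aliases used above:

    import tensorflow as tf

    # Unscoped: the op name is "row_weights_shard_0".
    v_plain = tf.Variable(tf.ones([4]), name="row_weights_shard_0")

    # Scoped: the name_scope prefixes the op name, giving
    # "row_weights/row_weights_shard_0".
    with tf.name_scope("row_weights"):
      v_scoped = tf.Variable(tf.ones([4]), name="row_weights_shard_0")

    print(v_plain.op.name)   # row_weights_shard_0
    print(v_scoped.op.name)  # row_weights/row_weights_shard_0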
@@ -313,37 +334,36 @@ class WALSModel(object):
  @classmethod
  def _create_factors(cls, rows, cols, num_shards, init, name):
    """Helper function to create row and column factors."""
-    with ops.name_scope(name):
-      if callable(init):
-        init = init()
-      if isinstance(init, list):
-        assert len(init) == num_shards
-      elif isinstance(init, str) and init == "random":
-        pass
-      elif num_shards == 1:
-        init = [init]
-      sharded_matrix = []
-      sizes = cls._shard_sizes(rows, num_shards)
-      assert len(sizes) == num_shards
-
-      def make_initializer(i, size):
-
-        def initializer():
-          if init == "random":
-            return random_ops.random_normal([size, cols])
-          else:
-            return init[i]
-
-        return initializer
-
-      for i, size in enumerate(sizes):
-        var_name = "%s_shard_%d" % (name, i)
-        var_init = make_initializer(i, size)
-        sharded_matrix.append(
-            variable_scope.variable(
-                var_init, dtype=dtypes.float32, name=var_name))
-
-      return sharded_matrix
+    if callable(init):
+      init = init()
+    if isinstance(init, list):
+      assert len(init) == num_shards
+    elif isinstance(init, str) and init == "random":
+      pass
+    elif num_shards == 1:
+      init = [init]
+    sharded_matrix = []
+    sizes = cls._shard_sizes(rows, num_shards)
+    assert len(sizes) == num_shards
+
+    def make_initializer(i, size):
+
+      def initializer():
+        if init == "random":
+          return random_ops.random_normal([size, cols])
+        else:
+          return init[i]
+
+      return initializer
+
+    for i, size in enumerate(sizes):
+      var_name = "%s_shard_%d" % (name, i)
+      var_init = make_initializer(i, size)
+      sharded_matrix.append(
+          variable_scope.variable(
+              var_init, dtype=dtypes.float32, name=var_name))
+
+    return sharded_matrix
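A side note on the make_initializer factory this hunk dedents: binding i and size through a factory call, rather than closing over the loop variables directly, avoids Python's late-binding pitfall, in which every closure created in a loop would see the final loop values. A TensorFlow-free sketch of the difference:

    # Late binding: all three closures share one loop variable, so each
    # sees its final value.
    broken = [lambda: i for i in range(3)]
    print([f() for f in broken])  # [2, 2, 2]

    # Factory pattern (as in make_initializer): each call opens a new
    # scope that captures the current value of i.
    def make_fn(i):
      def fn():
        return i
      return fn

    fixed = [make_fn(i) for i in range(3)]
    print([f() for f in fixed])  # [0, 1, 2]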
  @classmethod
  def _create_weights(cls, wt_init, num_wts, num_shards, name):
@@ -384,26 +404,25 @@ class WALSModel(object):
    sizes = cls._shard_sizes(num_wts, num_shards)
    assert len(sizes) == num_shards

-    with ops.name_scope(name):
-      def make_wt_initializer(i, size):
-
-        def initializer():
-          if init_mode == "scalar":
-            return wt_init * array_ops.ones([size])
-          else:
-            return wt_init[i]
-
-        return initializer
-
-      sharded_weight = []
-      for i, size in enumerate(sizes):
-        var_name = "%s_shard_%d" % (name, i)
-        var_init = make_wt_initializer(i, size)
-        sharded_weight.append(
-            variable_scope.variable(
-                var_init, dtype=dtypes.float32, name=var_name))
-
-      return sharded_weight
+    def make_wt_initializer(i, size):
+
+      def initializer():
+        if init_mode == "scalar":
+          return wt_init * array_ops.ones([size])
+        else:
+          return wt_init[i]
+
+      return initializer
+
+    sharded_weight = []
+    for i, size in enumerate(sizes):
+      var_name = "%s_shard_%d" % (name, i)
+      var_init = make_wt_initializer(i, size)
+      sharded_weight.append(
+          variable_scope.variable(
+              var_init, dtype=dtypes.float32, name=var_name))
+
+    return sharded_weight
  @staticmethod
  def _create_gramian(n_components, name):
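Both helpers split a dimension across shards via cls._shard_sizes and assert that the result has num_shards entries. The helper itself is defined elsewhere in this file; as a hedged sketch of the kind of split those assertions expect (an assumption, not the verbatim implementation), a near-even partition distributes the remainder over the first shards:

    def shard_sizes(dims, num_shards):
      # Split `dims` rows into `num_shards` nearly equal parts; the first
      # `residual` shards each absorb one extra row.
      shard_size, residual = divmod(dims, num_shards)
      return [shard_size + 1] * residual + [shard_size] * (num_shards - residual)

    print(shard_sizes(10, 3))  # [4, 3, 3]; sums to 10, len == num_shards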