diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-31 12:16:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-31 12:19:57 -0700 |
commit | 519189837b77181137505bf83054ddd962600f9b (patch) | |
tree | fc1b56cdca999b1b4da369d73742fcb4ace272ba /tensorflow/contrib/factorization | |
parent | fdf4d0813d4c0321be7b33698d00b165d90365b0 (diff) |
Making the tf.name_scope blocks related to the factor and weight vars configurable. By default they will not be scoped.
PiperOrigin-RevId: 198759754
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/factorization_ops.py | 129 |
1 files changed, 74 insertions, 55 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 09745e2de5..8f73274c2a 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -197,7 +197,8 @@ class WALSModel(object): row_weights=1, col_weights=1, use_factors_weights_cache=True, - use_gramian_cache=True): + use_gramian_cache=True, + use_scoped_vars=False): """Creates model for WALS matrix factorization. Args: @@ -239,6 +240,8 @@ class WALSModel(object): weights cache to take effect. use_gramian_cache: When True, the Gramians will be cached on the workers before the updates start. Defaults to True. + use_scoped_vars: When True, the factor and weight vars will also be nested + in a tf.name_scope. """ self._input_rows = input_rows self._input_cols = input_cols @@ -251,18 +254,36 @@ class WALSModel(object): regularization * linalg_ops.eye(self._n_components) if regularization is not None else None) assert (row_weights is None) == (col_weights is None) - self._row_weights = WALSModel._create_weights( - row_weights, self._input_rows, self._num_row_shards, "row_weights") - self._col_weights = WALSModel._create_weights( - col_weights, self._input_cols, self._num_col_shards, "col_weights") self._use_factors_weights_cache = use_factors_weights_cache self._use_gramian_cache = use_gramian_cache - self._row_factors = self._create_factors( - self._input_rows, self._n_components, self._num_row_shards, row_init, - "row_factors") - self._col_factors = self._create_factors( - self._input_cols, self._n_components, self._num_col_shards, col_init, - "col_factors") + + if use_scoped_vars: + with ops.name_scope("row_weights"): + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + with ops.name_scope("col_weights"): + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + with ops.name_scope("row_factors"): + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, + row_init, "row_factors") + with ops.name_scope("col_factors"): + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, + col_init, "col_factors") + else: + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, row_init, + "row_factors") + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, col_init, + "col_factors") + self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") with ops.name_scope("row_prepare_gramian"): @@ -313,37 +334,36 @@ class WALSModel(object): @classmethod def _create_factors(cls, rows, cols, num_shards, init, name): """Helper function to create row and column factors.""" - with ops.name_scope(name): - if callable(init): - init = init() - if isinstance(init, list): - assert len(init) == num_shards - elif isinstance(init, str) and init == "random": - pass - elif num_shards == 1: - init = [init] - sharded_matrix = [] - sizes = cls._shard_sizes(rows, num_shards) - assert len(sizes) == num_shards - - def make_initializer(i, size): - - def initializer(): - if init == "random": - return random_ops.random_normal([size, cols]) - else: - return init[i] + if callable(init): + init = init() + if isinstance(init, list): + assert len(init) == num_shards + elif isinstance(init, str) and init == "random": + pass + elif num_shards == 1: + init = [init] + sharded_matrix = [] + sizes = cls._shard_sizes(rows, num_shards) + assert len(sizes) == num_shards + + def make_initializer(i, size): - return initializer + def initializer(): + if init == "random": + return random_ops.random_normal([size, cols]) + else: + return init[i] - for i, size in enumerate(sizes): - var_name = "%s_shard_%d" % (name, i) - var_init = make_initializer(i, size) - sharded_matrix.append( - variable_scope.variable( - var_init, dtype=dtypes.float32, name=var_name)) + return initializer - return sharded_matrix + for i, size in enumerate(sizes): + var_name = "%s_shard_%d" % (name, i) + var_init = make_initializer(i, size) + sharded_matrix.append( + variable_scope.variable( + var_init, dtype=dtypes.float32, name=var_name)) + + return sharded_matrix @classmethod def _create_weights(cls, wt_init, num_wts, num_shards, name): @@ -384,26 +404,25 @@ class WALSModel(object): sizes = cls._shard_sizes(num_wts, num_shards) assert len(sizes) == num_shards - with ops.name_scope(name): - def make_wt_initializer(i, size): + def make_wt_initializer(i, size): - def initializer(): - if init_mode == "scalar": - return wt_init * array_ops.ones([size]) - else: - return wt_init[i] + def initializer(): + if init_mode == "scalar": + return wt_init * array_ops.ones([size]) + else: + return wt_init[i] - return initializer + return initializer - sharded_weight = [] - for i, size in enumerate(sizes): - var_name = "%s_shard_%d" % (name, i) - var_init = make_wt_initializer(i, size) - sharded_weight.append( - variable_scope.variable( - var_init, dtype=dtypes.float32, name=var_name)) + sharded_weight = [] + for i, size in enumerate(sizes): + var_name = "%s_shard_%d" % (name, i) + var_init = make_wt_initializer(i, size) + sharded_weight.append( + variable_scope.variable( + var_init, dtype=dtypes.float32, name=var_name)) - return sharded_weight + return sharded_weight @staticmethod def _create_gramian(n_components, name): |