 tensorflow/python/estimator/estimator.py                              |   2
 tensorflow/python/training/checkpoint_ops.py                          |   3
 tensorflow/python/training/warm_starting_util.py                      | 100
 tensorflow/python/training/warm_starting_util_test.py                 | 140
 tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt |   4
 tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt     |   4
 tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt |   4
 tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt     |   4
 8 files changed, 235 insertions(+), 26 deletions(-)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index e44a69b374..0f20acefdf 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -2056,7 +2056,7 @@ class WarmStartSettings(
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
`tf.estimator.VocabInfo`. The variable names should be "full" variables,
not the names of the partitions. If not explicitly provided, the variable
- is assumed to have no vocabulary.
+ is assumed to have no (changes to) vocabulary.
var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
name of the previously-trained variable in `ckpt_to_initialize_from`. If
not explicitly provided, the name of the variable is assumed to be same
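As a hedged illustration of the `var_name_to_vocab_info` setting documented above, a minimal sketch at the Estimator level; the checkpoint path, vocab files, and variable name are hypothetical:

    import tensorflow as tf

    # Describes how the new vocabulary relates to the previously trained one.
    vocab_info = tf.estimator.VocabInfo(
        new_vocab="new_fruit_vocab.txt",     # hypothetical vocab file
        new_vocab_size=100,
        num_oov_buckets=1,
        old_vocab="old_fruit_vocab.txt",     # hypothetical vocab file
        old_vocab_size=80,
        backup_initializer=tf.zeros_initializer(),
        axis=0)                              # the vocabulary indexes rows

    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from="/tmp/prev_model",  # hypothetical checkpoint dir
        vars_to_warm_start=".*fruit_embeddings.*",
        # Any variable not listed here is assumed to have no (changes to)
        # vocabulary, per the docstring above.
        var_name_to_vocab_info={
            "fruit_embeddings/embedding_weights": vocab_info})
    # `ws` is then passed to any canned or custom Estimator through the
    # `warm_start_from` constructor argument.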
diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py
index a6e9662b73..cfd9b39ddc 100644
--- a/tensorflow/python/training/checkpoint_ops.py
+++ b/tensorflow/python/training/checkpoint_ops.py
@@ -268,7 +268,8 @@ def _load_and_remap_matrix_initializer(ckpt_path,
vocab files are the same, and no column remapping is done.
The returned initializer only supports div-partitioning along the row axis. It
- does not support partitioning along the column axis or mod-partitioning.
+ does not support partitioning along the column axis (as this is not common in
+ practice) or mod-partitioning.
NOTE: When this is used to warm-start variables, client code should use
`tf.lookup.index_table_from_tensor()` like
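The row remapping that `_load_and_remap_matrix_initializer` performs can be pictured with a rough NumPy sketch (conceptual only, not the real op; the vocabularies are made up): rows of the old matrix are copied to the positions their entries occupy in the new vocabulary, brand-new entries keep the backup-initialized value (zeros here), and OOV bucket rows are appended at the end.

    import numpy as np

    old_vocab = ["apple", "banana", "guava", "orange"]
    new_vocab = ["orange", "guava", "banana", "apple", "raisin"]  # reordered + 1 new entry
    num_oov_buckets = 1

    old_matrix = np.arange(8, dtype=np.float32).reshape(4, 2)  # one row per old entry

    new_matrix = np.zeros((len(new_vocab) + num_oov_buckets, old_matrix.shape[1]),
                          np.float32)
    old_index = {word: i for i, word in enumerate(old_vocab)}
    for new_row, word in enumerate(new_vocab):
      if word in old_index:                 # entry existed in the old checkpoint
        new_matrix[new_row] = old_matrix[old_index[word]]
      # else: stays at the (zero) backup-initialized value
    # Rows len(new_vocab) .. end are the OOV buckets, also backup-initialized.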
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index c0dd46bfa5..bea9bb6dff 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -41,6 +41,7 @@ class VocabInfo(
"old_vocab",
"old_vocab_size",
"backup_initializer",
+ "axis",
])):
"""Vocabulary information for warm-starting.
@@ -62,6 +63,42 @@ class VocabInfo(
backup_initializer: [Optional] A variable initializer used for variables
corresponding to new vocabulary entries and OOV. If not provided, these
entries will be zero-initialized.
+ axis: [Optional] Denotes what axis the vocabulary corresponds to. The
+ default, 0, corresponds to the most common use case (embeddings or
+ linear weights for binary classification / regression). An axis of 1
+ could be used for warm-starting output layers with class vocabularies.
+
+ For example:
+
+ embeddings_vocab_info = tf.VocabInfo(
+ new_vocab='embeddings_vocab',
+ new_vocab_size=100,
+ num_oov_buckets=1,
+ old_vocab='pretrained_embeddings_vocab',
+ old_vocab_size=10000,
+ backup_initializer=tf.truncated_normal_initializer(
+ mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
+ axis=0)
+
+ softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
+ new_vocab='class_vocab',
+ new_vocab_size=5,
+ num_oov_buckets=0, # No OOV for classes.
+ old_vocab='old_class_vocab',
+ old_vocab_size=8,
+ backup_initializer=tf.glorot_uniform_initializer(),
+ axis=1)
+
+ softmax_output_layer_bias_vocab_info = tf.VocabInfo(
+ new_vocab='class_vocab',
+ new_vocab_size=5,
+ num_oov_buckets=0, # No OOV for classes.
+ old_vocab='old_class_vocab',
+ old_vocab_size=8,
+ backup_initializer=tf.zeros_initializer(),
+ axis=0)
+
+ Currently, only axis=0 and axis=1 are supported.
"""
def __new__(cls,
@@ -70,7 +107,12 @@ class VocabInfo(
num_oov_buckets,
old_vocab,
old_vocab_size=-1,
- backup_initializer=None):
+ backup_initializer=None,
+ axis=0):
+ if axis != 0 and axis != 1:
+ raise ValueError("The only supported values for the axis argument are 0 "
+ "and 1. Provided axis: {}".format(axis))
+
return super(VocabInfo, cls).__new__(
cls,
new_vocab,
@@ -79,6 +121,7 @@ class VocabInfo(
old_vocab,
old_vocab_size,
backup_initializer,
+ axis,
)
@@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var,
previous_vocab_size=-1,
current_oov_buckets=0,
prev_tensor_name=None,
- initializer=None):
+ initializer=None,
+ axis=0):
"""Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
Use this method when the `var` is backed by vocabulary. This method stitches
@@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var,
None, we lookup tensor with same name as given `var`.
initializer: Variable initializer to be used for missing entries. If None,
missing entries will be zero-initialized.
+ axis: Axis of the variable that the provided vocabulary corresponds to.
Raises:
ValueError: If required args are not provided.
@@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var,
# Assume tensor name remains the same.
prev_tensor_name = _infer_var_name(var)
+ # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
+ total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
for v in var:
v_shape = v.get_shape().as_list()
slice_info = v._get_save_slice_info()
@@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var,
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
- # TODO(eddz): Support cases where class vocabularies need remapping too.
+ if axis == 0:
+ new_row_vocab_size = current_vocab_size
+ new_col_vocab_size = v_shape[1]
+ old_row_vocab_size = previous_vocab_size
+ old_row_vocab_file = prev_vocab_path
+ new_row_vocab_file = current_vocab_path
+ old_col_vocab_file = None
+ new_col_vocab_file = None
+ num_row_oov_buckets = current_oov_buckets
+ num_col_oov_buckets = 0
+ elif axis == 1:
+ # Note that we must compute this value across all partitions, whereas
+ # in the axis = 0 case, we can simply use v_shape[1] because we don't
+ # allow partitioning across axis = 1.
+ new_row_vocab_size = total_v_first_axis
+ new_col_vocab_size = current_vocab_size
+ old_row_vocab_size = -1
+ old_row_vocab_file = None
+ new_row_vocab_file = None
+ old_col_vocab_file = prev_vocab_path
+ new_col_vocab_file = current_vocab_path
+ num_row_oov_buckets = 0
+ num_col_oov_buckets = current_oov_buckets
+ else:
+ raise ValueError("The only supported values for the axis argument are 0 "
+ "and 1. Provided axis: {}".format(axis))
+
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
- new_row_vocab_size=current_vocab_size,
- new_col_vocab_size=v_shape[1],
- old_row_vocab_size=previous_vocab_size,
- old_row_vocab_file=prev_vocab_path,
- new_row_vocab_file=current_vocab_path,
- old_col_vocab_file=None,
- new_col_vocab_file=None,
- num_row_oov_buckets=current_oov_buckets,
- num_col_oov_buckets=0,
+ new_row_vocab_size=new_row_vocab_size,
+ new_col_vocab_size=new_col_vocab_size,
+ old_row_vocab_size=old_row_vocab_size,
+ old_row_vocab_file=old_row_vocab_file,
+ new_row_vocab_file=new_row_vocab_file,
+ old_col_vocab_file=old_col_vocab_file,
+ new_col_vocab_file=new_col_vocab_file,
+ num_row_oov_buckets=num_row_oov_buckets,
+ num_col_oov_buckets=num_col_oov_buckets,
initializer=initializer)
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
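The `axis == 1` branch above amounts to a column remap of the old kernel. A small NumPy sketch of the effect, using the same numbers as the column-vocab tests added below: columns are reordered to match the new class vocabulary, and the column for the brand-new class is left to the backup initializer (zeros here).

    import numpy as np

    old_classes = ["apple", "orange"]
    new_classes = ["orange", "apple", "banana"]   # reversed order + one new class

    # Old output-layer kernel: one column per old class.
    old_kernel = np.array([[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]], np.float32)

    new_kernel = np.zeros((old_kernel.shape[0], len(new_classes)), np.float32)
    old_index = {c: i for i, c in enumerate(old_classes)}
    for new_col, c in enumerate(new_classes):
      if c in old_index:
        new_kernel[:, new_col] = old_kernel[:, old_index[c]]
    # new_kernel == [[0.3, 0.5, 0.], [0.8, 1., 0.], [1.2, 1.5, 0.], [2.3, 2., 0.]]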
@@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from,
previous_vocab_size=vocab_info.old_vocab_size,
current_oov_buckets=vocab_info.num_oov_buckets,
prev_tensor_name=prev_var_name,
- initializer=vocab_info.backup_initializer)
+ initializer=vocab_info.backup_initializer,
+ axis=vocab_info.axis)
else:
# For the special value of vars_to_warm_start = None,
# we only warm-start variables with explicitly specified vocabularies.
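Outside the Estimator API, the same plumbing is reachable through `tf.train.warm_start`. A hedged end-to-end sketch for a softmax output layer, mirroring the docstring examples above; the variable names, vocab files, and checkpoint path are hypothetical:

    import tensorflow as tf

    kernel_vocab_info = tf.train.VocabInfo(
        new_vocab="class_vocab.txt",          # hypothetical class vocab file
        new_vocab_size=5,
        num_oov_buckets=0,                    # no OOV for classes
        old_vocab="old_class_vocab.txt",
        old_vocab_size=8,
        backup_initializer=tf.glorot_uniform_initializer(),
        axis=1)                               # classes index the kernel's columns

    bias_vocab_info = tf.train.VocabInfo(
        new_vocab="class_vocab.txt",
        new_vocab_size=5,
        num_oov_buckets=0,
        old_vocab="old_class_vocab.txt",
        old_vocab_size=8,
        backup_initializer=tf.zeros_initializer(),
        axis=0)                               # classes index the bias's only axis

    tf.train.warm_start(
        ckpt_to_initialize_from="/tmp/prev_model",   # hypothetical checkpoint
        vars_to_warm_start=".*output_layer.*",
        var_name_to_vocab_info={
            # Note: rank-1 variables (like the bias) are still marked TODO in
            # _warm_start_var_with_vocab above, so the bias entry is aspirational.
            "output_layer/kernel": kernel_vocab_info,
            "output_layer/bias": bias_vocab_info,
        })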
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 70a84bc3f6..3ee0f6aaa2 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -107,7 +107,7 @@ class WarmStartingUtilTest(test.TestCase):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
- self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+ self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarPrevVarPartitioned(self):
_, weights = self._create_prev_run_var(
@@ -123,7 +123,7 @@ class WarmStartingUtilTest(test.TestCase):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
- self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+ self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarCurrentVarPartitioned(self):
_, prev_val = self._create_prev_run_var(
@@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
- self.assertAllEqual(prev_val, new_val)
+ self.assertAllClose(prev_val, new_val)
def testWarmStartVarBothVarsPartitioned(self):
_, weights = self._create_prev_run_var(
@@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
- self.assertAllEqual(prev_val, new_val)
+ self.assertAllClose(prev_val, new_val)
def testWarmStartVarWithVocab(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
@@ -189,9 +189,34 @@ class WarmStartingUtilTest(test.TestCase):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
- self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+ self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
+ def testWarmStartVarWithColumnVocab(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+ [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase):
previous_vocab_size=2)
sess.run(variables.global_variables_initializer())
# Old vocabulary limited to ['apple', 'banana'].
- self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
+ self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
def testWarmStartVarWithVocabPrevVarPartitioned(self):
@@ -238,9 +263,36 @@ class WarmStartingUtilTest(test.TestCase):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
- self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+ self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
+ def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ shape=[4, 2],
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+ partitioner=lambda shape, dtype: [2, 1])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+ [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
def testWarmStartVarWithVocabCurrentVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
- self.assertAllEqual([[2.], [1.5], [1.]],
+ self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
- self.assertAllEqual([[0.5], [0.], [0.]],
+ self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
+ def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ shape=[4, 3],
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]],
+ partitioner=lambda shape, dtype: [2, 1])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(
+ isinstance(fruit_output_layer, variables.PartitionedVariable))
+ fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+ fruit_output_layer_vars[0].eval(sess))
+ self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+ fruit_output_layer_vars[1].eval(sess))
+
def testWarmStartVarWithVocabBothVarsPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
- self.assertAllEqual([[2.], [1.5], [1.]],
+ self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
- self.assertAllEqual([[0.5], [0.], [0.]],
+ self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
+ def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ shape=[4, 2],
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+ partitioner=lambda shape, dtype: [2, 1])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ shape=[4, 3],
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]],
+ partitioner=lambda shape, dtype: [2, 1])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(
+ isinstance(fruit_output_layer, variables.PartitionedVariable))
+ fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+ fruit_output_layer_vars[0].eval(sess))
+ self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+ fruit_output_layer_vars[1].eval(sess))
+
def testWarmStart_ListOfVariables(self):
# Save checkpoint from which to warm-start.
_, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94eb3..b6942cb7ed 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb1111..39b946b82f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94eb3..b6942cb7ed 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb1111..39b946b82f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}