Diffstat (limited to 'tensorflow/python/training/warm_starting_util.py')
-rw-r--r--  tensorflow/python/training/warm_starting_util.py  100
1 file changed, 87 insertions(+), 13 deletions(-)
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index c0dd46bfa5..bea9bb6dff 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -41,6 +41,7 @@ class VocabInfo(
"old_vocab",
"old_vocab_size",
"backup_initializer",
+ "axis",
])):
"""Vocabulary information for warm-starting.
@@ -62,6 +63,42 @@ class VocabInfo(
backup_initializer: [Optional] A variable initializer used for variables
corresponding to new vocabulary entries and OOV. If not provided, these
entries will be zero-initialized.
+ axis: [Optional] Denotes what axis the vocabulary corresponds to. The
+ default, 0, corresponds to the most common use case (embeddings or
+ linear weights for binary classification / regression). An axis of 1
+ could be used for warm-starting output layers with class vocabularies.
+
+ For example:
+
+ embeddings_vocab_info = tf.VocabInfo(
+ new_vocab='embeddings_vocab',
+ new_vocab_size=100,
+ num_oov_buckets=1,
+ old_vocab='pretrained_embeddings_vocab',
+ old_vocab_size=10000,
+ backup_initializer=tf.truncated_normal_initializer(
+ mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
+ axis=0)
+
+ softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
+ new_vocab='class_vocab',
+ new_vocab_size=5,
+ num_oov_buckets=0, # No OOV for classes.
+ old_vocab='old_class_vocab',
+ old_vocab_size=8,
+ backup_initializer=tf.glorot_uniform_initializer(),
+ axis=1)
+
+ softmax_output_layer_bias_vocab_info = tf.VocabInfo(
+ new_vocab='class_vocab',
+ new_vocab_size=5,
+ num_oov_buckets=0, # No OOV for classes.
+ old_vocab='old_class_vocab',
+ old_vocab_size=8,
+ backup_initializer=tf.zeros_initializer(),
+ axis=0)
+
+ Currently, only axis=0 and axis=1 are supported.
"""
def __new__(cls,
@@ -70,7 +107,12 @@ class VocabInfo(
num_oov_buckets,
old_vocab,
old_vocab_size=-1,
- backup_initializer=None):
+ backup_initializer=None,
+ axis=0):
+ if axis != 0 and axis != 1:
+ raise ValueError("The only supported values for the axis argument are 0 "
+ "and 1. Provided axis: {}".format(axis))
+
return super(VocabInfo, cls).__new__(
cls,
new_vocab,
@@ -79,6 +121,7 @@ class VocabInfo(
old_vocab,
old_vocab_size,
backup_initializer,
+ axis,
)
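The constructor change above can be exercised directly. The following is a minimal sketch, assuming the TF 1.x export path `tf.VocabInfo` shown in the docstring; it demonstrates a valid `axis=1` construction and the new `ValueError` for any other axis:

  import tensorflow as tf

  # Valid: remap the class (column) vocabulary of an output-layer kernel.
  class_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.zeros_initializer(),
      axis=1)

  # Invalid: any axis other than 0 or 1 is rejected at construction time.
  try:
    tf.VocabInfo('vocab', 5, 0, 'old_vocab', axis=2)
  except ValueError as e:
    print(e)  # The only supported values for the axis argument are 0 and 1...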
@@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var,
previous_vocab_size=-1,
current_oov_buckets=0,
prev_tensor_name=None,
- initializer=None):
+ initializer=None,
+ axis=0):
"""Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
Use this method when the `var` is backed by vocabulary. This method stitches
@@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var,
None, we lookup tensor with same name as given `var`.
initializer: Variable initializer to be used for missing entries. If None,
missing entries will be zero-initialized.
+ axis: Axis of the variable that the provided vocabulary corresponds to.
Raises:
ValueError: If required args are not provided.
@@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var,
# Assume tensor name remains the same.
prev_tensor_name = _infer_var_name(var)
+ # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
+ total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
for v in var:
v_shape = v.get_shape().as_list()
slice_info = v._get_save_slice_info()
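The `total_v_first_axis` value introduced above sums the first-axis dimension over every shard of a possibly partitioned variable. A plain-Python illustration with hypothetical shapes (no TensorFlow needed):

  # A [100, 10] variable split into four [25, 10] row shards: summing each
  # shard's first-axis size recovers the full first-axis size of 100.
  shard_shapes = [[25, 10], [25, 10], [25, 10], [25, 10]]
  total_v_first_axis = sum(shape[0] for shape in shard_shapes)
  assert total_v_first_axis == 100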
@@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var,
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
- # TODO(eddz): Support cases where class vocabularies need remapping too.
+ if axis == 0:
+ new_row_vocab_size = current_vocab_size
+ new_col_vocab_size = v_shape[1]
+ old_row_vocab_size = previous_vocab_size
+ old_row_vocab_file = prev_vocab_path
+ new_row_vocab_file = current_vocab_path
+ old_col_vocab_file = None
+ new_col_vocab_file = None
+ num_row_oov_buckets = current_oov_buckets
+ num_col_oov_buckets = 0
+ elif axis == 1:
+ # Note that we must compute this value across all partitions, whereas
+ # in the axis = 0 case, we can simply use v_shape[1] because we don't
+ # allow partitioning across axis = 1.
+ new_row_vocab_size = total_v_first_axis
+ new_col_vocab_size = current_vocab_size
+ old_row_vocab_size = -1
+ old_row_vocab_file = None
+ new_row_vocab_file = None
+ old_col_vocab_file = prev_vocab_path
+ new_col_vocab_file = current_vocab_path
+ num_row_oov_buckets = 0
+ num_col_oov_buckets = current_oov_buckets
+ else:
+ raise ValueError("The only supported values for the axis argument are 0 "
+ "and 1. Provided axis: {}".format(axis))
+
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
- new_row_vocab_size=current_vocab_size,
- new_col_vocab_size=v_shape[1],
- old_row_vocab_size=previous_vocab_size,
- old_row_vocab_file=prev_vocab_path,
- new_row_vocab_file=current_vocab_path,
- old_col_vocab_file=None,
- new_col_vocab_file=None,
- num_row_oov_buckets=current_oov_buckets,
- num_col_oov_buckets=0,
+ new_row_vocab_size=new_row_vocab_size,
+ new_col_vocab_size=new_col_vocab_size,
+ old_row_vocab_size=old_row_vocab_size,
+ old_row_vocab_file=old_row_vocab_file,
+ new_row_vocab_file=new_row_vocab_file,
+ old_col_vocab_file=old_col_vocab_file,
+ new_col_vocab_file=new_col_vocab_file,
+ num_row_oov_buckets=num_row_oov_buckets,
+ num_col_oov_buckets=num_col_oov_buckets,
initializer=initializer)
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
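The branch above only changes which side of the checkpoint matrix is driven by the vocabulary files; `_load_and_remap_matrix_initializer` itself is unchanged. A small sketch with hypothetical shapes, showing the row/column roles each axis value selects:

  # axis=0: an embedding of shape [vocab_size + oov, embedding_dim] has its
  # rows remapped by the token vocabulary; columns are passed through.
  # axis=1: an output kernel of shape [hidden_units, num_classes] has its
  # columns remapped by the class vocabulary; rows are passed through.
  def vocab_dims(axis, v_shape, total_v_first_axis, current_vocab_size):
    if axis == 0:
      return dict(new_row_vocab_size=current_vocab_size,
                  new_col_vocab_size=v_shape[1])
    return dict(new_row_vocab_size=total_v_first_axis,
                new_col_vocab_size=current_vocab_size)

  print(vocab_dims(0, [100, 16], 100, 100))  # rows follow the vocabulary
  print(vocab_dims(1, [64, 5], 64, 5))       # columns follow the vocabulary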
@@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from,
previous_vocab_size=vocab_info.old_vocab_size,
current_oov_buckets=vocab_info.num_oov_buckets,
prev_tensor_name=prev_var_name,
- initializer=vocab_info.backup_initializer)
+ initializer=vocab_info.backup_initializer,
+ axis=vocab_info.axis)
else:
# For the special value of vars_to_warm_start = None,
# we only warm-start variables with explicitly specified vocabularies.
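End to end, the new field flows from `VocabInfo` through `warm_start` as shown above. The following usage sketch assumes the TF 1.x `tf.estimator.WarmStartSettings` wrapper and a canned `DNNClassifier`; the checkpoint path and the `'dnn/logits/kernel'` variable name are illustrative, not taken from this patch:

  import tensorflow as tf

  class_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,           # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.glorot_uniform_initializer(),
      axis=1)                      # Remap output-layer columns by class.

  ws = tf.estimator.WarmStartSettings(
      ckpt_to_initialize_from='/tmp/old_model',  # Hypothetical path.
      vars_to_warm_start='.*',
      var_name_to_vocab_info={'dnn/logits/kernel': class_vocab_info})

  estimator = tf.estimator.DNNClassifier(
      hidden_units=[64],
      feature_columns=[tf.feature_column.numeric_column('x')],
      n_classes=5,
      warm_start_from=ws)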