author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-06 16:37:32 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-11-06 16:41:36 -0800
commit    34ac5a65f2818fbc77fff39164a13d7ea9c32a34 (patch)
tree      7f9e2073947ac1d8d98322fc8f1732a8f9e03615
parent    e641c402f60aabaca029393355626a54a99809a4 (diff)
Allow for an old_row_vocab_size, in case a subset of the old_row_vocab_file was used during the checkpoint creation (as is allowed in FeatureColumn._VocabularyListCategoricalColumn).
PiperOrigin-RevId: 174781749
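To make the new `old_vocab_size` semantics concrete, here is a minimal pure-Python model of what GenerateVocabRemapping computes after this change. This is an illustrative sketch only (in-memory lists instead of vocab files); the real kernel builds a lookup::HashTable via InitializeTableFromTextFile, as shown in the diff below.

def generate_vocab_remapping(new_vocab, old_vocab, new_vocab_offset,
                             num_new_vocab, old_vocab_size=-1):
  """Pure-Python sketch of the op's contract, over in-memory lists."""
  if old_vocab_size >= 0:
    # The new attr: only the first `old_vocab_size` entries of the old
    # vocabulary participate in the remapping.
    old_vocab = old_vocab[:old_vocab_size]
  old_ids = {token: row for row, token in enumerate(old_vocab)}
  # remapping[i] is the old row for new row `new_vocab_offset + i`, or -1
  # if that entry is absent from the (possibly truncated) old vocabulary.
  remapping = [
      old_ids.get(new_vocab[new_vocab_offset + i], -1)
      for i in range(num_new_vocab)
  ]
  num_present = sum(1 for old_id in remapping if old_id != -1)
  return remapping, num_present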
 tensorflow/core/kernels/generate_vocab_remapping_op.cc | 21
 tensorflow/core/ops/checkpoint_ops.cc                  |  9
 tensorflow/python/estimator/warm_starting_util.py      | 43
 tensorflow/python/estimator/warm_starting_util_test.py | 71
 tensorflow/python/kernel_tests/checkpoint_ops_test.py  | 15
 tensorflow/python/training/checkpoint_ops.py           | 25
 tensorflow/python/training/checkpoint_ops_test.py      | 81
7 files changed, 221 insertions(+), 44 deletions(-)
diff --git a/tensorflow/core/kernels/generate_vocab_remapping_op.cc b/tensorflow/core/kernels/generate_vocab_remapping_op.cc
index 247c1f2457..2b97677e38 100644
--- a/tensorflow/core/kernels/generate_vocab_remapping_op.cc
+++ b/tensorflow/core/kernels/generate_vocab_remapping_op.cc
@@ -41,6 +41,8 @@ class GenerateVocabRemappingOp : public OpKernel {
OP_REQUIRES_OK(context,
context->GetAttr("new_vocab_offset", &new_vocab_offset_));
OP_REQUIRES_OK(context, context->GetAttr("num_new_vocab", &num_new_vocab_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("old_vocab_size", &old_vocab_size_));
}
void Compute(OpKernelContext* context) override {
@@ -92,16 +94,14 @@ class GenerateVocabRemappingOp : public OpKernel {
lookup::HashTable<string, int64>* old_vocab_table =
new lookup::HashTable<string, int64>(context, this);
core::ScopedUnref unref_old(old_vocab_table);
- // Note: we pass -1 (unknown) for vocab_size, which is supposed to be the
- // total elements in file. This is different from num_new_vocab_, which
- // accounts for partitioning.
- OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile(
- old_vocab_filename,
- -1, // vocab_size
- kUnusedLookupDelim,
- -2, // key_index, use the whole line/token.
- -1, // value_index, use the line number.
- context->env(), old_vocab_table));
+ // Note: If old_vocab_size_ is -1 (unknown), we retrieve all elements in
+ // the file (see TextFileLineIterator).
+ OP_REQUIRES_OK(context,
+ lookup::InitializeTableFromTextFile(
+ old_vocab_filename, old_vocab_size_, kUnusedLookupDelim,
+ -2, // key_index, use the whole line/token.
+ -1, // value_index, use the line number.
+ context->env(), old_vocab_table));
// Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
// new_vocab_offset + num_new_vocab_]
@@ -165,6 +165,7 @@ class GenerateVocabRemappingOp : public OpKernel {
private:
int new_vocab_offset_;
int num_new_vocab_;
+ int old_vocab_size_;
};
REGISTER_KERNEL_BUILDER(Name("GenerateVocabRemapping").Device(DEVICE_CPU),
diff --git a/tensorflow/core/ops/checkpoint_ops.cc b/tensorflow/core/ops/checkpoint_ops.cc
index b49d7b4d40..08b00c8255 100644
--- a/tensorflow/core/ops/checkpoint_ops.cc
+++ b/tensorflow/core/ops/checkpoint_ops.cc
@@ -22,6 +22,7 @@ REGISTER_OP("GenerateVocabRemapping")
.Input("old_vocab_file: string")
.Attr("new_vocab_offset: int >= 0")
.Attr("num_new_vocab: int >= 0")
+ .Attr("old_vocab_size: int >= -1 = -1")
.Output("remapping: int64")
.Output("num_present: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
@@ -43,7 +44,11 @@ Given a path to new and old vocabulary files, returns a remapping Tensor of
length `num_new_vocab`, where `remapping[i]` contains the row number in the old
vocabulary that corresponds to row `i` in the new vocabulary (starting at line
`new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
-in the new vocabulary is not in the old vocabulary. `num_vocab_offset` enables
+in the new vocabulary is not in the old vocabulary. The old vocabulary is
+constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
+default value of -1.
+
+`new_vocab_offset` enables
use in the partitioned variable case, and should generally be set through
examining partitioning info. The format of the files should be a text file,
with each line containing a single entity within the vocabulary.
@@ -69,6 +74,8 @@ new_vocab_file: Path to the new vocab file.
old_vocab_file: Path to the old vocab file.
new_vocab_offset: How many entries into the new vocab file to start reading.
num_new_vocab: Number of entries in the new vocab file to remap.
+old_vocab_size: Number of entries in the old vocab file to consider. If -1,
+ use the entire old vocabulary.
remapping: A Tensor of length num_new_vocab where the element at index i
is equal to the old ID that maps to the new ID i. This element is -1 for any
new ID that is not found in the old vocabulary.
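A worked example of the documented contract, using hypothetical vocab file contents and the pure-Python sketch above:

old_vocab = ["knitting", "eminem", "soulja_boy"]  # hypothetical old vocab file
new_vocab = ["stubhub", "knitting", "eminem"]     # hypothetical new vocab file

remapping, num_present = generate_vocab_remapping(
    new_vocab, old_vocab, new_vocab_offset=0, num_new_vocab=3,
    old_vocab_size=2)  # old vocab constrained to ['knitting', 'eminem']

assert remapping == [-1, 0, 1]  # 'stubhub' is unmatched; the rest map through
assert num_present == 2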
diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py
index 3f0218af83..e5655db082 100644
--- a/tensorflow/python/estimator/warm_starting_util.py
+++ b/tensorflow/python/estimator/warm_starting_util.py
@@ -46,10 +46,13 @@ class _WarmStartSettings(
ckpt_to_initialize_from: [Required] A string specifying the directory with
checkpoint file(s) or path to checkpoint from which to warm-start the
model parameters.
- col_to_prev_vocab: [Optional] Dict of `FeatureColumn` to path of the
- vocabulary used for the `FeatureColumn` in `ckpt_to_initialize_from`. If
- not explicitly provided, the vocabularies are assumed to be same between
- previous and present checkpoints.
+ col_to_prev_vocab: [Optional] Dict of `FeatureColumn` to the vocabulary
+ used for the `FeatureColumn` in `ckpt_to_initialize_from`. A vocabulary
+ can be represented either by a string (path to the vocabulary) or by a
+ (string, int) tuple representing (path to the vocabulary, vocab_size),
+ if only the first `vocab_size` entries of the old vocabulary were used
+ in the checkpoint. If the dict is not explicitly provided, the vocabularies
+ are assumed to be the same between previous and present checkpoints.
col_to_prev_tensor: [Optional] Dict of `FeatureColumn` to name of the
variable (corresponding to the `FeatureColumn`) in
`ckpt_to_initialize_from`. If not explicitly provided, the name of the
@@ -76,6 +79,13 @@ class _WarmStartSettings(
col_to_prev_vocab={sc_vocab_file: "old_vocab.txt"})
# Warm-start all weights but the parameters corresponding to "sc_vocab_file"
+ # have a different vocab from the one used in current checkpoint, and only
+ # 100 of those entries were used.
+ ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp",
+ col_to_prev_vocab={sc_vocab_file:
+ ("old_vocab.txt", 100)})
+
+ # Warm-start all weights but the parameters corresponding to "sc_vocab_file"
# have a different vocab from the one used in current checkpoint and the
# parameters corresponding to "sc_vocab_list" have a different name from the
# current checkpoint.
@@ -214,6 +224,7 @@ def _warmstart_var_with_vocab(var,
current_vocab_size,
prev_ckpt,
prev_vocab_path,
+ previous_vocab_size=-1,
current_oov_buckets=0,
prev_tensor_name=None,
initializer=None):
@@ -239,6 +250,8 @@ def _warmstart_var_with_vocab(var,
to checkpoint. The given checkpoint must have tensor with name
`prev_tensor_name` (if not None) or tensor with name same as given `var`.
prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
+ previous_vocab_size: If provided, constrains the previous vocab to its first
+ `previous_vocab_size` entries. -1 means use the entire previous vocab.
current_oov_buckets: An `int` specifying the number of out-of-vocabulary
buckets used for given `var`.
prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
@@ -284,6 +297,7 @@ def _warmstart_var_with_vocab(var,
old_tensor_name=prev_tensor_name,
new_row_vocab_size=current_vocab_size,
new_col_vocab_size=v_shape[1],
+ old_row_vocab_size=previous_vocab_size,
old_row_vocab_file=prev_vocab_path,
new_row_vocab_file=current_vocab_path,
old_col_vocab_file=None,
@@ -373,17 +387,30 @@ def _warmstart_input_layer(cols_to_vars, warmstart_settings):
vocabulary_file = col.vocabulary_file
vocabulary_size = col.vocabulary_size
num_oov_buckets = col.num_oov_buckets
- prev_vocab_path = warmstart_settings.col_to_prev_vocab.get(
+ prev_vocab = warmstart_settings.col_to_prev_vocab.get(
col, vocabulary_file)
- logging.info("Warm-starting column: {}; prev_vocab: {}; prev_tensor: {}".
- format(col.name, prev_vocab_path, (
- prev_tensor_name or "Unchanged")))
+ if isinstance(prev_vocab, str):
+ prev_vocab_path = prev_vocab
+ previous_vocab_size = -1
+ logging.info(
+ "Warm-starting column: {}; prev_vocab: {}; "
+ "prev_tensor: {}".format(col.name, prev_vocab_path,
+ (prev_tensor_name or "Unchanged")))
+ elif isinstance(prev_vocab, tuple):
+ prev_vocab_path = prev_vocab[0]
+ previous_vocab_size = prev_vocab[1]
+ logging.info("Warm-starting column: {}; prev_vocab: {} (first {} "
+ "entries); prev_tensor: {}".format(
+ col.name, prev_vocab_path, previous_vocab_size,
+ (prev_tensor_name or "Unchanged")))
+
_warmstart_var_with_vocab(
var,
current_vocab_path=vocabulary_file,
current_vocab_size=vocabulary_size,
prev_ckpt=warmstart_settings.ckpt_to_initialize_from,
prev_vocab_path=prev_vocab_path,
+ previous_vocab_size=previous_vocab_size,
current_oov_buckets=num_oov_buckets,
prev_tensor_name=prev_tensor_name,
initializer=initializer)
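The str-versus-tuple branching above amounts to a small normalization step. A sketch of that logic (the helper name `_normalize_prev_vocab` is hypothetical, not part of this patch):

def _normalize_prev_vocab(prev_vocab):
  # A bare string means "use the whole old vocab" (size -1), while a
  # (path, size) tuple constrains the old vocab to its first `size` entries.
  if isinstance(prev_vocab, str):
    return prev_vocab, -1
  prev_vocab_path, previous_vocab_size = prev_vocab
  return prev_vocab_path, previous_vocab_size

assert _normalize_prev_vocab("old_vocab.txt") == ("old_vocab.txt", -1)
assert _normalize_prev_vocab(("old_vocab.txt", 100)) == ("old_vocab.txt", 100)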
diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py
index f488957fb4..a05dbfd744 100644
--- a/tensorflow/python/estimator/warm_starting_util_test.py
+++ b/tensorflow/python/estimator/warm_starting_util_test.py
@@ -318,6 +318,32 @@ class WarmStartingUtilTest(test.TestCase):
self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
+ def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
+ prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
+ "old_vocab")
+ _, _ = self._create_prev_run_var(
+ "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(
+ ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_weights = variable_scope.get_variable(
+ "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
+ ws_util._warmstart_var_with_vocab(
+ fruit_weights,
+ new_vocab_path,
+ 5,
+ self.get_temp_dir(),
+ prev_vocab_path,
+ previous_vocab_size=2)
+ sess.run(variables.global_variables_initializer())
+ # Old vocabulary limited to ['apple', 'banana'].
+ self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
+ fruit_weights.eval(sess))
+
def testWarmStartVarWithVocabPrevVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -507,6 +533,51 @@ class WarmStartingUtilTest(test.TestCase):
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
sess)
+ def testWarmStartInputLayer_SparseColumnVocabularyConstrainedVocabSizes(self):
+ # Create old vocabulary, and use a size smaller than the total number of
+ # entries.
+ old_vocab_path = self._write_vocab(["apple", "guava", "banana"],
+ "old_vocab")
+ old_vocab_size = 2 # ['apple', 'guava']
+
+ # Create new vocab for sparse column "sc_vocab".
+ current_vocab_path = self._write_vocab(
+ ["apple", "banana", "guava", "orange"], "current_vocab")
+ # Create feature column. Only use 2 of the actual entries, resulting in
+ # ['apple', 'banana'] for the new vocabulary.
+ sc_vocab = fc.categorical_column_with_vocabulary_file(
+ "sc_vocab", vocabulary_file=current_vocab_path, vocabulary_size=2)
+
+ # Save checkpoint from which to warm-start.
+ self._create_prev_run_var(
+ "linear_model/sc_vocab/weights", shape=[2, 1], initializer=ones())
+
+ partitioner = lambda shape, dtype: [1] * len(shape)
+ # New graph, new session WITHOUT warmstarting.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
+ sess.run(variables.global_variables_initializer())
+ # Without warmstarting, the weights should be initialized using default
+ # initializer (which is init_ops.zeros_initializer).
+ self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([2, 1])]},
+ sess)
+
+ # New graph, new session with warmstarting.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
+ warmstart_settings = ws_util._WarmStartSettings(
+ ckpt_to_initialize_from=self.get_temp_dir(),
+ col_to_prev_vocab={
+ sc_vocab: (old_vocab_path, old_vocab_size)
+ })
+ ws_util._warmstart_input_layer(cols_to_vars, warmstart_settings)
+ sess.run(variables.global_variables_initializer())
+ # Verify weights were correctly warmstarted. 'banana' isn't in the
+ # first two entries of the old vocabulary, so it's newly initialized.
+ self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [[[1], [0]]]}, sess)
+
def testWarmStartInputLayer_BucketizedColumn(self):
# Create feature column.
real = fc.numeric_column("real")
diff --git a/tensorflow/python/kernel_tests/checkpoint_ops_test.py b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
index d2eb3eb801..a786d0a47e 100644
--- a/tensorflow/python/kernel_tests/checkpoint_ops_test.py
+++ b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
@@ -87,6 +87,21 @@ class GenerateVocabRemappingTest(test.TestCase):
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
+ def test_generate_remapping_with_old_vocab_size(self):
+ """Tests where old_vocab_size is specified."""
+ remapping, num_present = gen_checkpoint_ops._generate_vocab_remapping(
+ new_vocab_file=self.new_vocab_file,
+ old_vocab_file=self.old_vocab_file,
+ num_new_vocab=3,
+ new_vocab_offset=0,
+ # Old vocabulary becomes ['knitting', 'eminem'].
+ old_vocab_size=2)
+ expected_remapping = [-1, 0, 1]
+ expected_num_present = 2
+ with self.test_session():
+ self.assertAllEqual(expected_remapping, remapping.eval())
+ self.assertAllEqual(expected_num_present, num_present.eval())
+
class LoadAndRemapMatrixTest(test.TestCase):
"""Tests for the load_and_remap_matrix() op."""
diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py
index 0769ccd3d1..7f92d94d2b 100644
--- a/tensorflow/python/training/checkpoint_ops.py
+++ b/tensorflow/python/training/checkpoint_ops.py
@@ -36,6 +36,7 @@ def _load_and_remap_matrix(ckpt_path,
num_rows_to_load,
new_col_vocab_size,
initializer,
+ old_row_vocab_size=-1,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
@@ -75,6 +76,12 @@ def _load_and_remap_matrix(ckpt_path,
initializer: Callable initializer function that accepts a 1-D tensor as the
arg to specify the shape of the returned tensor. Used to initialize
missing values.
+ old_row_vocab_size: The number of entries to consider in the old vocabulary.
+ With the default value of -1, the entire old row vocabulary file will be
+ used. Otherwise, only the first `old_row_vocab_size` entries will be
+ considered for remapping. Must be smaller than the length of
+ `old_row_vocab_file`. NOTE: we do not provide an equivalent
+ `old_col_vocab_size` for classes.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
@@ -146,7 +153,8 @@ def _load_and_remap_matrix(ckpt_path,
new_vocab_file=new_row_vocab_file,
old_vocab_file=old_row_vocab_file,
new_vocab_offset=new_row_vocab_offset,
- num_new_vocab=num_rows_to_load))
+ num_new_vocab=num_rows_to_load,
+ old_vocab_size=old_row_vocab_size))
else:
# Even when the rows are not being reordered, we still need to generate a
# remapping to account for initializing partitioned Variables (when
@@ -199,6 +207,7 @@ def _load_and_remap_matrix_initializer(ckpt_path,
old_tensor_name,
new_row_vocab_size,
new_col_vocab_size,
+ old_row_vocab_size=-1,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
@@ -280,6 +289,12 @@ def _load_and_remap_matrix_initializer(ckpt_path,
`new_col_vocab_file`. If no column remapping is needed (no column vocab
provided), this should be equal to the number of columns in the old
matrix.
+ old_row_vocab_size: The number of entries to consider in the old vocabulary.
+ With the default value of -1, the entire old row vocabulary file will be
+ used. Otherwise, only the first `old_row_vocab_size` entries will be
+ considered for remapping. Must be smaller than the length of
+ `old_row_vocab_file`. NOTE: we do not provide an equivalent
+ `old_col_vocab_size` for classes.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
@@ -388,6 +403,7 @@ def _load_and_remap_matrix_initializer(ckpt_path,
num_rows_to_load=num_rows_to_load,
new_col_vocab_size=new_col_vocab_size,
initializer=initializer,
+ old_row_vocab_size=old_row_vocab_size,
old_row_vocab_file=old_row_vocab_file,
new_row_vocab_file=new_row_vocab_file,
old_col_vocab_file=old_col_vocab_file,
@@ -405,6 +421,7 @@ def _load_embedding_initializer(ckpt_path,
embedding_dim,
old_vocab_file,
new_vocab_file,
+ old_vocab_size=-1,
num_oov_buckets=0,
initializer=None,
max_rows_in_memory=-1):
@@ -428,6 +445,11 @@ def _load_embedding_initializer(ckpt_path,
path to the old vocabulary file.
new_vocab_file: A scalar `Tensor` of type `string` containing the
path to the new vocabulary file.
+ old_vocab_size: The number of entries to consider in the old vocabulary.
+ With the default value of -1, the entire old vocabulary file will be
+ used. Otherwise, only the first `old_vocab_size` entries will be
+ considered for remapping. Must be smaller than the length of
+ `old_vocab_file`.
num_oov_buckets: `int` specifying the number of out-of-vocabulary
buckets to use. Must be >= 0.
initializer: Initializer function that accepts a 1-D tensor as the arg to
@@ -452,6 +474,7 @@ def _load_embedding_initializer(ckpt_path,
old_tensor_name=embedding_tensor_name,
new_row_vocab_size=new_vocab_size,
new_col_vocab_size=embedding_dim,
+ old_row_vocab_size=old_vocab_size,
old_row_vocab_file=old_vocab_file,
new_row_vocab_file=new_vocab_file,
old_col_vocab_file=None,
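Downstream, the remapping feeds LoadAndRemapMatrix, which copies matched old rows and fills unmatched rows from the initializer. A rough NumPy sketch of the row-axis contract (assumed behavior; ignores partitioning and the column axis):

import numpy as np

def apply_row_remapping(old_matrix, remapping, init_val=0.0):
  # Rows with remapping[i] == -1 were absent from the (possibly
  # size-constrained) old vocabulary and fall back to the initializer.
  out = np.full((len(remapping), old_matrix.shape[1]), init_val,
                dtype=old_matrix.dtype)
  for i, old_id in enumerate(remapping):
    if old_id != -1:
      out[i] = old_matrix[old_id]
  return out

# With remapping == [-1, 0, 1], row 0 stays at init_val while rows 1 and 2
# are copied from old rows 0 and 1.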
diff --git a/tensorflow/python/training/checkpoint_ops_test.py b/tensorflow/python/training/checkpoint_ops_test.py
index b578dde251..00611de862 100644
--- a/tensorflow/python/training/checkpoint_ops_test.py
+++ b/tensorflow/python/training/checkpoint_ops_test.py
@@ -103,7 +103,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
num_col_oov_buckets=1)
# [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset
- # means we read
+ # means we read from the first line.
expected_remapped_matrix = np.concatenate(
[
np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
@@ -132,6 +132,9 @@ class LoadAndRemapWrappersTest(test.TestCase):
num_col_oov_buckets=1,
initializer=self.initializer))
+ # The new weight matrix is of size
+ # [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
+ # partitioned variable to confirm that the offset logic works.
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
@@ -141,10 +144,6 @@ class LoadAndRemapWrappersTest(test.TestCase):
np.reshape([self.init_val] * 6, [6, 1])
],
axis=1)
-
- # The new weight matrix is of size
- # [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
- # partitioned variable to confirm that the offset logic works.
remapped_matrix = variable_scope.get_variable(
name='linear/obtained_weight_matrix',
shape=[6, 5],
@@ -168,6 +167,8 @@ class LoadAndRemapWrappersTest(test.TestCase):
num_col_oov_buckets=1,
initializer=self.initializer))
+ # The new weight matrix is of size
+ # [5-sized input layer, 4 class vocab + 1 class OOV].
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 66], [5, 1]),
@@ -177,9 +178,6 @@ class LoadAndRemapWrappersTest(test.TestCase):
np.reshape([self.init_val] * 5, [5, 1])
],
axis=1)
-
- # The new weight matrix is of size
- # [5-sized input layer, 4 class vocab + 1 class OOV].
remapped_matrix = variable_scope.get_variable(
name='dnn_output/obtained_weight_matrix',
shape=[5, 5],
@@ -206,6 +204,9 @@ class LoadAndRemapWrappersTest(test.TestCase):
num_col_oov_buckets=1,
initializer=self.initializer))
+ # The new weight matrix is of size
+ # [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
+ # second partition has only OOV.
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
@@ -215,10 +216,6 @@ class LoadAndRemapWrappersTest(test.TestCase):
np.reshape([self.init_val] * 10, [10, 1]),
],
axis=1)
-
- # The new weight matrix is of size
- # [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
- # second partition has only OOV.
remapped_matrix = variable_scope.get_variable(
name='linear_all_oov/obtained_weight_matrix',
shape=[10, 5],
@@ -244,6 +241,8 @@ class LoadAndRemapWrappersTest(test.TestCase):
num_row_oov_buckets=1,
num_col_oov_buckets=1))
+ # Same as test_initializer_with_oov_only_partition, but with zero
+ # initialization.
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
@@ -253,7 +252,6 @@ class LoadAndRemapWrappersTest(test.TestCase):
np.reshape([0] * 6, [6, 1])
],
axis=1)
-
remapped_matrix = variable_scope.get_variable(
name='linear_init_fallback/obtained_weight_matrix',
shape=[6, 5],
@@ -277,18 +275,17 @@ class LoadAndRemapWrappersTest(test.TestCase):
num_oov_buckets=1,
initializer=self.initializer))
+ # The new weight matrix is of size
+ # [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
+ # last vocab row (2nd last row) is newly initialized (wasn't found in
+ # previous vocab) and the actual last row is OOV and also newly initialized.
+ # Use a partitioned variable to confirm that the offset logic works.
expected_remapped_embeddings = np.concatenate(
[
np.reshape(range(64), [4, 16]),
np.reshape([self.init_val] * 32, [2, 16]),
],
axis=0)
-
- # The new weight matrix is of size
- # [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
- # last vocab row (2nd last row) is newly initialized (wasn't found in
- # previous vocab) and the actual last row is OOV and also newly initialized.
- # Use a partitioned variable to confirm that the offset logic works.
remapped_embeddings = variable_scope.get_variable(
name='embedding/obtained_embedding_matrix',
shape=[6, 16],
@@ -323,6 +320,11 @@ class LoadAndRemapWrappersTest(test.TestCase):
num_oov_buckets=5,
initializer=self.initializer))
+ # The new weight matrix is of size
+ # [4 feature vocab + 5 feature OOV, 16 (embedding dimension)], where the
+ # 3rd and 4th rows are not found in the old vocabulary and therefore newly
+ # initialized. The last five rows are OOV and also newly initialized.
+ # Use a partitioned variable to confirm that the offset logic works.
expected_remapped_embeddings = np.concatenate(
[
np.reshape(range(16, 32), [1, 16]),
@@ -330,15 +332,47 @@ class LoadAndRemapWrappersTest(test.TestCase):
np.reshape([self.init_val] * 112, [7, 16]),
],
axis=0)
+ remapped_embeddings = variable_scope.get_variable(
+ name='embedding/obtained_embedding_matrix',
+ shape=[9, 16],
+ initializer=embedding_loading_initializer,
+ partitioner=partitioned_variables.fixed_size_partitioner(2))
+
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ self.assertAllClose(expected_remapped_embeddings,
+ remapped_embeddings.as_tensor().eval())
+
+ def test_load_embedding_initializer_old_row_vocab(self):
+ """Tests for load_embedding_initializer where we constrain old vocab."""
+ embedding_loading_initializer = (
+ checkpoint_ops._load_embedding_initializer(
+ new_vocab_file=self.new_feature_vocab_file,
+ old_vocab_file=self.old_feature_vocab_file,
+ # Considered old vocabulary becomes ['zero', 'one', 'two']. This
+ # means 'three' in the new vocabulary is newly initialized.
+ old_vocab_size=3,
+ new_vocab_size=5,
+ embedding_dim=16,
+ embedding_tensor_name='some_scope/embeddings',
+ ckpt_path=[self.checkpoint_file],
+ num_oov_buckets=1,
+ initializer=self.initializer))
# The new weight matrix is of size
- # [4 feature vocab + 5 feature OOV, 16 (embedding dimension)], where the
- # 3rd and 4th rows are not found in the old vocabulary and therefore newly
- # initialized. The last five rows are OOV and also newly initialized.
+ # [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
+ # last two vocab rows are newly initialized (not found in the constrained
+ # previous vocab) and the actual last row is OOV and also newly initialized.
# Use a partitioned variable to confirm that the offset logic works.
+ expected_remapped_embeddings = np.concatenate(
+ [
+ np.reshape(range(48), [3, 16]),
+ np.reshape([self.init_val] * 48, [3, 16]),
+ ],
+ axis=0)
remapped_embeddings = variable_scope.get_variable(
name='embedding/obtained_embedding_matrix',
- shape=[9, 16],
+ shape=[6, 16],
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
@@ -347,6 +381,5 @@ class LoadAndRemapWrappersTest(test.TestCase):
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
-
if __name__ == '__main__':
test.main()