diff options
Diffstat (limited to 'tensorflow/python/training/warm_starting_util_test.py')
-rw-r--r-- | tensorflow/python/training/warm_starting_util_test.py | 140 |
1 files changed, 129 insertions, 11 deletions
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py index 70a84bc3f6..3ee0f6aaa2 100644 --- a/tensorflow/python/training/warm_starting_util_test.py +++ b/tensorflow/python/training/warm_starting_util_test.py @@ -107,7 +107,7 @@ class WarmStartingUtilTest(test.TestCase): "fruit_weights", initializer=[[0.], [0.], [0.], [0.]]) ws_util._warm_start_var(fruit_weights, self.get_temp_dir()) sess.run(variables.global_variables_initializer()) - self.assertAllEqual(prev_val, fruit_weights.eval(sess)) + self.assertAllClose(prev_val, fruit_weights.eval(sess)) def testWarmStartVarPrevVarPartitioned(self): _, weights = self._create_prev_run_var( @@ -123,7 +123,7 @@ class WarmStartingUtilTest(test.TestCase): "fruit_weights", initializer=[[0.], [0.], [0.], [0.]]) ws_util._warm_start_var(fruit_weights, self.get_temp_dir()) sess.run(variables.global_variables_initializer()) - self.assertAllEqual(prev_val, fruit_weights.eval(sess)) + self.assertAllClose(prev_val, fruit_weights.eval(sess)) def testWarmStartVarCurrentVarPartitioned(self): _, prev_val = self._create_prev_run_var( @@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase): fruit_weights = fruit_weights._get_variable_list() new_val = np.concatenate( [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0) - self.assertAllEqual(prev_val, new_val) + self.assertAllClose(prev_val, new_val) def testWarmStartVarBothVarsPartitioned(self): _, weights = self._create_prev_run_var( @@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase): fruit_weights = fruit_weights._get_variable_list() new_val = np.concatenate( [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0) - self.assertAllEqual(prev_val, new_val) + self.assertAllClose(prev_val, new_val) def testWarmStartVarWithVocab(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], @@ -189,9 +189,34 @@ class WarmStartingUtilTest(test.TestCase): ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5, self.get_temp_dir(), prev_vocab_path) sess.run(variables.global_variables_initializer()) - self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]], + self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]], fruit_weights.eval(sess)) + def testWarmStartVarWithColumnVocab(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.], + [2.3, 2., 0.]], fruit_output_layer.eval(sess)) + def testWarmStartVarWithVocabConstrainedOldVocabSize(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase): previous_vocab_size=2) sess.run(variables.global_variables_initializer()) # Old vocabulary limited to ['apple', 'banana']. - self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]], + self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]], fruit_weights.eval(sess)) def testWarmStartVarWithVocabPrevVarPartitioned(self): @@ -238,9 +263,36 @@ class WarmStartingUtilTest(test.TestCase): ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5, self.get_temp_dir(), prev_vocab_path) sess.run(variables.global_variables_initializer()) - self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]], + self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]], fruit_weights.eval(sess)) + def testWarmStartVarWithColumnVocabPrevVarPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + shape=[4, 2], + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]], + partitioner=lambda shape, dtype: [2, 1]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.], + [2.3, 2., 0.]], fruit_output_layer.eval(sess)) + def testWarmStartVarWithVocabCurrentVarPartitioned(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase): self.assertTrue( isinstance(fruit_weights, variables.PartitionedVariable)) fruit_weights_vars = fruit_weights._get_variable_list() - self.assertAllEqual([[2.], [1.5], [1.]], + self.assertAllClose([[2.], [1.5], [1.]], fruit_weights_vars[0].eval(sess)) - self.assertAllEqual([[0.5], [0.], [0.]], + self.assertAllClose([[0.5], [0.], [0.]], fruit_weights_vars[1].eval(sess)) + def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + shape=[4, 3], + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]], + partitioner=lambda shape, dtype: [2, 1]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertTrue( + isinstance(fruit_output_layer, variables.PartitionedVariable)) + fruit_output_layer_vars = fruit_output_layer._get_variable_list() + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]], + fruit_output_layer_vars[0].eval(sess)) + self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]], + fruit_output_layer_vars[1].eval(sess)) + def testWarmStartVarWithVocabBothVarsPartitioned(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase): self.assertTrue( isinstance(fruit_weights, variables.PartitionedVariable)) fruit_weights_vars = fruit_weights._get_variable_list() - self.assertAllEqual([[2.], [1.5], [1.]], + self.assertAllClose([[2.], [1.5], [1.]], fruit_weights_vars[0].eval(sess)) - self.assertAllEqual([[0.5], [0.], [0.]], + self.assertAllClose([[0.5], [0.], [0.]], fruit_weights_vars[1].eval(sess)) + def testWarmStartVarWithColumnVocabBothVarsPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + shape=[4, 2], + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]], + partitioner=lambda shape, dtype: [2, 1]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + shape=[4, 3], + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]], + partitioner=lambda shape, dtype: [2, 1]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertTrue( + isinstance(fruit_output_layer, variables.PartitionedVariable)) + fruit_output_layer_vars = fruit_output_layer._get_variable_list() + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]], + fruit_output_layer_vars[0].eval(sess)) + self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]], + fruit_output_layer_vars[1].eval(sess)) + def testWarmStart_ListOfVariables(self): # Save checkpoint from which to warm-start. _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1], |