aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/warm_starting_util_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/warm_starting_util_test.py')
-rw-r--r--tensorflow/python/training/warm_starting_util_test.py140
1 files changed, 129 insertions, 11 deletions
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 70a84bc3f6..3ee0f6aaa2 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -107,7 +107,7 @@ class WarmStartingUtilTest(test.TestCase):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
- self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+ self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarPrevVarPartitioned(self):
_, weights = self._create_prev_run_var(
@@ -123,7 +123,7 @@ class WarmStartingUtilTest(test.TestCase):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
- self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+ self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarCurrentVarPartitioned(self):
_, prev_val = self._create_prev_run_var(
@@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
- self.assertAllEqual(prev_val, new_val)
+ self.assertAllClose(prev_val, new_val)
def testWarmStartVarBothVarsPartitioned(self):
_, weights = self._create_prev_run_var(
@@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
- self.assertAllEqual(prev_val, new_val)
+ self.assertAllClose(prev_val, new_val)
def testWarmStartVarWithVocab(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
@@ -189,9 +189,34 @@ class WarmStartingUtilTest(test.TestCase):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
- self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+ self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
+ def testWarmStartVarWithColumnVocab(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+ [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase):
previous_vocab_size=2)
sess.run(variables.global_variables_initializer())
# Old vocabulary limited to ['apple', 'banana'].
- self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
+ self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
def testWarmStartVarWithVocabPrevVarPartitioned(self):
@@ -238,9 +263,36 @@ class WarmStartingUtilTest(test.TestCase):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
- self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+ self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
+ def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ shape=[4, 2],
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+ partitioner=lambda shape, dtype: [2, 1])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+ [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
def testWarmStartVarWithVocabCurrentVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
- self.assertAllEqual([[2.], [1.5], [1.]],
+ self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
- self.assertAllEqual([[0.5], [0.], [0.]],
+ self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
+ def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ shape=[4, 3],
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]],
+ partitioner=lambda shape, dtype: [2, 1])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(
+ isinstance(fruit_output_layer, variables.PartitionedVariable))
+ fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+ fruit_output_layer_vars[0].eval(sess))
+ self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+ fruit_output_layer_vars[1].eval(sess))
+
def testWarmStartVarWithVocabBothVarsPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
- self.assertAllEqual([[2.], [1.5], [1.]],
+ self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
- self.assertAllEqual([[0.5], [0.], [0.]],
+ self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
+ def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ shape=[4, 2],
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+ partitioner=lambda shape, dtype: [2, 1])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ shape=[4, 3],
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]],
+ partitioner=lambda shape, dtype: [2, 1])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(
+ isinstance(fruit_output_layer, variables.PartitionedVariable))
+ fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+ fruit_output_layer_vars[0].eval(sess))
+ self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+ fruit_output_layer_vars[1].eval(sess))
+
def testWarmStart_ListOfVariables(self):
# Save checkpoint from which to warm-start.
_, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],