aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-18 09:55:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 10:00:43 -0700
commit748ef1c2eb58b0e031d796ea0211a8c6d74531ff (patch)
treeae23b50092fb743d221a194f43c270f4d2489914 /tensorflow/examples
parentb78e0da14d29d7e9d3b51a7b633c960baacd710c (diff)
Add back in stddev parameter inadvertently dropped from conv model.
Use create_eval_graph() for graph to be frozen. PiperOrigin-RevId: 205093459
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/speech_commands/freeze.py2
-rw-r--r--tensorflow/examples/speech_commands/models.py2
2 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py
index 7657b23c60..89e790d4e4 100644
--- a/tensorflow/examples/speech_commands/freeze.py
+++ b/tensorflow/examples/speech_commands/freeze.py
@@ -130,7 +130,7 @@ def main(_):
FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
if FLAGS.quantize:
- tf.contrib.quantize.create_training_graph(quant_delay=0)
+ tf.contrib.quantize.create_eval_graph()
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
# Turn all the variables into inline constants inside the graph and save it.
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index 65ae3b1511..4d1454be0d 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -302,7 +302,7 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
label_count = model_settings['label_count']
final_fc_weights = tf.get_variable(
name='final_fc_weights',
- initializer=tf.truncated_normal_initializer,
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[second_conv_element_count, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',