diff options
author | Pavithra Vijay <psv@google.com> | 2018-07-30 11:52:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 11:56:57 -0700 |
commit | 4027262f588466ee4a419c28521a5b53aad12e5a (patch) | |
tree | 8445d665a13104ec9bbe1eed1dda1ba8ff2f36e4 /tensorflow/python/keras/engine/training_arrays.py | |
parent | 3328243cdca5d08f56fc64c582ce2f3b80630259 (diff) |
De-dup few eager mode tests, remove some unused functions and params.
PiperOrigin-RevId: 206621105
Diffstat (limited to 'tensorflow/python/keras/engine/training_arrays.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_arrays.py | 18 |
1 files changed, 6 insertions, 12 deletions
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index adefffab11..6572e2c344 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -50,7 +50,6 @@ def fit_loop(model, val_targets=None, val_sample_weights=None, shuffle=True, - callback_metrics=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None): @@ -69,8 +68,6 @@ def fit_loop(model, val_targets: List of target arrays. val_sample_weights: Optional list of sample weight arrays. shuffle: Whether to shuffle the data at the beginning of each epoch - callback_metrics: List of strings, the display names of the metrics - passed to the callbacks. They should be the concatenation of list the display names of the outputs of `f` and the list of display names of the outputs of `f_val`. initial_epoch: Epoch at which to start training @@ -121,9 +118,7 @@ def fit_loop(model, out_labels = model.metrics_names if do_validation: - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] + callback_metrics = copy.copy(out_labels) + ['val_' + n for n in out_labels] # need to create the test_function before start of the first epoch # because TensorBoard callback on_epoch_begin adds summary to the # list of fetches of the test_function @@ -197,9 +192,7 @@ def fit_loop(model, if steps_per_epoch is not None: # Step-wise fit loop. for step_index in range(steps_per_epoch): - batch_logs = {} - batch_logs['batch'] = step_index - batch_logs['size'] = 1 + batch_logs = {'batch': step_index, 'size': 1} callbacks.on_batch_begin(step_index, batch_logs) try: outs = f(ins) @@ -388,7 +381,9 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): return outs -def test_loop(model, inputs, targets, +def test_loop(model, + inputs, + targets, sample_weights=None, batch_size=None, verbose=0, @@ -485,8 +480,7 @@ def test_loop(model, inputs, targets, if isinstance(batch_outs, list): if batch_index == 0: - for batch_out in enumerate(batch_outs): - outs.append(0.) + outs.extend([0.] * len(batch_outs)) for i, batch_out in enumerate(batch_outs): if i in stateful_metric_indices: outs[i] = batch_out |