diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-07 12:05:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 12:09:09 -0700 |
commit | 0a375d94b6fd4c3cd0bd5d0a301b3acc65b96d78 (patch) | |
tree | 8b7104cf254a5d2646363ae3087f6362c0aadf6c /tensorflow/contrib/tensor_forest | |
parent | e258e52d2c4060fc26fda43e4ce068d5ba2ab1ff (diff) |
Switching default loss reduction for core tensorforest to be the same as in old version.
PiperOrigin-RevId: 212014026
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index db970deff5..0042d37acd 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None): weight_column=weights_name, label_dimension=params.num_outputs, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: if params.num_classes == 2: return core_head_lib.binary_classification_head( weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: return core_head_lib.multi_class_head( n_classes=params.num_classes, weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) def get_model_fn(params, graph_builder_class, |