diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client/random_forest.py')
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index db970deff5..0042d37acd 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None): weight_column=weights_name, label_dimension=params.num_outputs, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: if params.num_classes == 2: return core_head_lib.binary_classification_head( weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: return core_head_lib.multi_class_head( n_classes=params.num_classes, weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) def get_model_fn(params, graph_builder_class, |