aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/client/random_forest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client/random_forest.py')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index db970deff5..0042d37acd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None):
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
if params.num_classes == 2:
return core_head_lib.binary_classification_head(
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
return core_head_lib.multi_class_head(
n_classes=params.num_classes,
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
def get_model_fn(params,
graph_builder_class,