diff options
author | 2016-10-14 11:00:12 -0800 | |
---|---|---|
committer | 2016-10-14 12:08:44 -0700 | |
commit | 4e5e8119077db67e291461c7c0cb9b33b5deebb9 (patch) | |
tree | 3650431e4f78858c91c5ae6a29f8aa7f58528ab3 | |
parent | 8b203064f8e1cd1035d31f1bd2f60fa3461d8b94 (diff) |
In head/target column's softmax loss, changes check for target's dtype to is_integer.
Change: 136182028
-rw-r--r-- | tensorflow/contrib/layers/python/layers/target_column.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 8 |
2 files changed, 6 insertions, 10 deletions
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 0667fee32a..dbd55e5a01 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -23,7 +23,6 @@ import six from tensorflow.contrib import losses from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -412,10 +411,9 @@ def _log_loss_with_two_classes(logits, target): def _softmax_cross_entropy_loss(logits, target): - # Check that we got int32/int64 for classification. - if (not target.dtype.is_compatible_with(dtypes.int64) and - not target.dtype.is_compatible_with(dtypes.int32)): - raise ValueError("Target's dtype should be int32, int64 or compatible. " + # Check that we got integer for classification. + if not target.dtype.is_integer: + raise ValueError("Target's dtype should be integer " "Instead got %s." % target.dtype) # sparse_softmax_cross_entropy_with_logits requires [batch_size] target. if len(target.get_shape()) == 2: diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 4b19f84a7a..7d93b023ef 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -25,7 +25,6 @@ from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.session_bundle import exporter -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -758,10 +757,9 @@ def _log_loss_with_two_classes(logits, target): def _softmax_cross_entropy_loss(logits, target): - # Check that we got int32/int64 for classification. - if (not target.dtype.is_compatible_with(dtypes.int64) and - not target.dtype.is_compatible_with(dtypes.int32)): - raise ValueError("Target's dtype should be int32, int64 or compatible. " + # Check that we got integer for classification. + if not target.dtype.is_integer: + raise ValueError("Target's dtype should be integer " "Instead got %s." % target.dtype) # sparse_softmax_cross_entropy_with_logits requires [batch_size] target. if len(target.get_shape()) == 2: |