aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Zakaria Haque <zakaria@google.com>2016-10-14 11:00:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 12:08:44 -0700
commit4e5e8119077db67e291461c7c0cb9b33b5deebb9 (patch)
tree3650431e4f78858c91c5ae6a29f8aa7f58528ab3
parent8b203064f8e1cd1035d31f1bd2f60fa3461d8b94 (diff)
In head/target column's softmax loss, changes check for target's dtype to is_integer.
Change: 136182028
-rw-r--r--tensorflow/contrib/layers/python/layers/target_column.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py8
2 files changed, 6 insertions, 10 deletions
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py
index 0667fee32a..dbd55e5a01 100644
--- a/tensorflow/contrib/layers/python/layers/target_column.py
+++ b/tensorflow/contrib/layers/python/layers/target_column.py
@@ -23,7 +23,6 @@ import six
from tensorflow.contrib import losses
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework import deprecated
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -412,10 +411,9 @@ def _log_loss_with_two_classes(logits, target):
def _softmax_cross_entropy_loss(logits, target):
- # Check that we got int32/int64 for classification.
- if (not target.dtype.is_compatible_with(dtypes.int64) and
- not target.dtype.is_compatible_with(dtypes.int32)):
- raise ValueError("Target's dtype should be int32, int64 or compatible. "
+ # Check that we got integer for classification.
+ if not target.dtype.is_integer:
+ raise ValueError("Target's dtype should be integer "
"Instead got %s." % target.dtype)
# sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
if len(target.get_shape()) == 2:
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 4b19f84a7a..7d93b023ef 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -25,7 +25,6 @@ from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.session_bundle import exporter
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -758,10 +757,9 @@ def _log_loss_with_two_classes(logits, target):
def _softmax_cross_entropy_loss(logits, target):
- # Check that we got int32/int64 for classification.
- if (not target.dtype.is_compatible_with(dtypes.int64) and
- not target.dtype.is_compatible_with(dtypes.int32)):
- raise ValueError("Target's dtype should be int32, int64 or compatible. "
+ # Check that we got integer for classification.
+ if not target.dtype.is_integer:
+ raise ValueError("Target's dtype should be integer "
"Instead got %s." % target.dtype)
# sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
if len(target.get_shape()) == 2: