aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Brian Patton <bjp@google.com>2018-09-24 19:00:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 19:04:34 -0700
commitdee007d9bab96fcbf7673cb7ed3d5235b122f12a (patch)
treef405d45eebc42420d972598b8069b6c7bab78b08 /tensorflow/python/ops
parentec2cc9122cca5fdec52d6c1ec42b771b8082d298 (diff)
Allow callers to specify a preferred dtype when calling convert_to_tensor.
PiperOrigin-RevId: 214370113
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/distributions/util.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index c61efebca0..ad848dfee6 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -155,7 +155,8 @@ def get_logits_and_probs(logits=None,
probs=None,
multidimensional=False,
validate_args=False,
- name="get_logits_and_probs"):
+ name="get_logits_and_probs",
+ dtype=None):
"""Converts logit to probabilities (or vice-versa), and returns both.
Args:
@@ -169,6 +170,7 @@ def get_logits_and_probs(logits=None,
`0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
of `probs` sums to one.
name: A name for this operation (optional).
+ dtype: `tf.DType` to prefer when converting args to `Tensor`s.
Returns:
logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
@@ -183,7 +185,7 @@ def get_logits_and_probs(logits=None,
raise ValueError("Must pass probs or logits, but not both.")
if probs is None:
- logits = ops.convert_to_tensor(logits, name="logits")
+ logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
if not logits.dtype.is_floating:
raise TypeError("logits must having floating type.")
# We can early return since we constructed probs and therefore know
@@ -194,7 +196,7 @@ def get_logits_and_probs(logits=None,
return logits, nn.softmax(logits, name="probs")
return logits, math_ops.sigmoid(logits, name="probs")
- probs = ops.convert_to_tensor(probs, name="probs")
+ probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
if not probs.dtype.is_floating:
raise TypeError("probs must having floating type.")