From dee007d9bab96fcbf7673cb7ed3d5235b122f12a Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Mon, 24 Sep 2018 19:00:02 -0700 Subject: Allow callers to specify a preferred dtype when calling convert_to_tensor. PiperOrigin-RevId: 214370113 --- tensorflow/python/ops/distributions/util.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index c61efebca0..ad848dfee6 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -155,7 +155,8 @@ def get_logits_and_probs(logits=None, probs=None, multidimensional=False, validate_args=False, - name="get_logits_and_probs"): + name="get_logits_and_probs", + dtype=None): """Converts logit to probabilities (or vice-versa), and returns both. Args: @@ -169,6 +170,7 @@ def get_logits_and_probs(logits=None, `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension of `probs` sums to one. name: A name for this operation (optional). + dtype: `tf.DType` to prefer when converting args to `Tensor`s. Returns: logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or @@ -183,7 +185,7 @@ def get_logits_and_probs(logits=None, raise ValueError("Must pass probs or logits, but not both.") if probs is None: - logits = ops.convert_to_tensor(logits, name="logits") + logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype) if not logits.dtype.is_floating: raise TypeError("logits must having floating type.") # We can early return since we constructed probs and therefore know @@ -194,7 +196,7 @@ def get_logits_and_probs(logits=None, return logits, nn.softmax(logits, name="probs") return logits, math_ops.sigmoid(logits, name="probs") - probs = ops.convert_to_tensor(probs, name="probs") + probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype) if not probs.dtype.is_floating: raise TypeError("probs must having floating type.") -- cgit v1.2.3