aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-21 16:10:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-21 16:45:48 -0800
commit0672d414381b66dd63cc58ea860f8dd93af0083c (patch)
tree1237380a874e1720df5738bfc43ac558a89bac24
parent0547f9654b6972aa3b25deaa949e882dd248983f (diff)
No longer expose the as_ref argument from convert_to_tensor and friends.
Change: 139844559
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py8
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core.py2
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py36
-rw-r--r--tensorflow/python/framework/op_def_library.py10
-rw-r--r--tensorflow/python/framework/ops.py164
-rw-r--r--tensorflow/python/framework/sparse_tensor.py3
-rw-r--r--tensorflow/python/ops/control_flow_ops.py14
-rw-r--r--tensorflow/python/training/saver.py4
8 files changed, 186 insertions, 55 deletions
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 7176196627..34420fc87f 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -356,8 +356,7 @@ def with_shape(expected_shape, tensor):
return tensor
-def convert_to_tensor_or_sparse_tensor(
- value, dtype=None, name=None, as_ref=False):
+def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
"""Converts value to a `SparseTensor` or `Tensor`.
Args:
@@ -366,8 +365,6 @@ def convert_to_tensor_or_sparse_tensor(
dtype: Optional element type for the returned tensor. If missing, the
type is inferred from the type of `value`.
name: Optional name to use if a new `Tensor` is created.
- as_ref: True if we want the result as a ref tensor. Only used if a new
- `Tensor` is created.
Returns:
A `SparseTensor` or `Tensor` based on `value`.
@@ -385,4 +382,5 @@ def convert_to_tensor_or_sparse_tensor(
'Sparse dtype: requested = %s, actual = %s' % (
dtype.name, value.dtype.name))
return value
- return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
+ return ops.internal_convert_to_tensor(
+ value, dtype=dtype, name=name)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py
index b35745861a..cd590f9ffd 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py
@@ -501,7 +501,7 @@ tc.register_type_abbreviation(LabeledTensor, 'labeled_tensor.LabeledTensor')
@tc.accepts(LabeledTensor)
def _convert_labeled_tensor_to_tensor(value, *args, **kwargs):
# call ops.convert_to_tensor to handle optional arguments appropriately
- return ops.convert_to_tensor(value.tensor, *args, **kwargs)
+ return ops.internal_convert_to_tensor(value.tensor, *args, **kwargs)
ops.register_tensor_conversion_function(
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 5200ef1d88..7143520e3f 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -27,7 +27,7 @@ from tensorflow.python import summary
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework.ops import convert_to_tensor
+from tensorflow.python.framework.ops import internal_convert_to_tensor
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -228,18 +228,18 @@ class SparseFeatureColumn(object):
"""
with name_scope(None, 'SparseFeatureColumn',
[example_indices, feature_indices]):
- self._example_indices = convert_to_tensor(example_indices,
- name='example_indices',
- dtype=dtypes.int64)
- self._feature_indices = convert_to_tensor(feature_indices,
- name='feature_indices',
- dtype=dtypes.int64)
+ self._example_indices = internal_convert_to_tensor(example_indices,
+ name='example_indices',
+ dtype=dtypes.int64)
+ self._feature_indices = internal_convert_to_tensor(feature_indices,
+ name='feature_indices',
+ dtype=dtypes.int64)
self._feature_values = None
if feature_values is not None:
with name_scope(None, 'SparseFeatureColumn', [feature_values]):
- self._feature_values = convert_to_tensor(feature_values,
- name='feature_values',
- dtype=dtypes.float32)
+ self._feature_values = internal_convert_to_tensor(feature_values,
+ name='feature_values',
+ dtype=dtypes.float32)
@property
def example_indices(self):
@@ -459,7 +459,7 @@ class SdcaModel(object):
def _convert_n_to_tensor(self, input_list, as_ref=False):
"""Converts input list to a set of tensors."""
- return [convert_to_tensor(x, as_ref=as_ref) for x in input_list]
+ return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list]
def _linear_predictions(self, examples):
"""Returns predictions of the form w*x."""
@@ -536,7 +536,7 @@ class SdcaModel(object):
# pylint: disable=protected-access
example_ids_hashed = gen_sdca_ops._sdca_fprint(
- convert_to_tensor(self._examples['example_ids']))
+ internal_convert_to_tensor(self._examples['example_ids']))
# pylint: enable=protected-access
example_state_data = self._hashtable.lookup(example_ids_hashed)
# Solver returns example_state_update, new delta sparse_feature_weights
@@ -561,8 +561,8 @@ class SdcaModel(object):
sparse_feature_indices,
sparse_features_values,
self._convert_n_to_tensor(self._examples['dense_features']),
- convert_to_tensor(self._examples['example_weights']),
- convert_to_tensor(self._examples['example_labels']),
+ internal_convert_to_tensor(self._examples['example_weights']),
+ internal_convert_to_tensor(self._examples['example_labels']),
sparse_indices,
sparse_weights,
self._convert_n_to_tensor(self._slots[
@@ -676,9 +676,11 @@ class SdcaModel(object):
predictions = math_ops.cast(
self._linear_predictions(examples), dtypes.float64)
labels = math_ops.cast(
- convert_to_tensor(examples['example_labels']), dtypes.float64)
+ internal_convert_to_tensor(
+ examples['example_labels']), dtypes.float64)
weights = math_ops.cast(
- convert_to_tensor(examples['example_weights']), dtypes.float64)
+ internal_convert_to_tensor(
+ examples['example_weights']), dtypes.float64)
if self._options['loss_type'] == 'logistic_loss':
return math_ops.reduce_sum(math_ops.mul(
@@ -722,7 +724,7 @@ class SdcaModel(object):
'sparse_features', 'dense_features'], examples)
self._assertList(['sparse_features', 'dense_features'], examples)
with name_scope('sdca/regularized_loss'):
- weights = convert_to_tensor(examples['example_weights'])
+ weights = internal_convert_to_tensor(examples['example_weights'])
return ((
self._l1_loss() +
# Note that here we are using the raw regularization
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index 2c27cef3a9..27a7d07296 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -427,7 +427,7 @@ class OpDefLibrary(object):
try:
if not input_arg.is_ref and dtype:
dtype = dtypes.as_dtype(dtype).base_dtype
- values = ops.convert_n_to_tensor(
+ values = ops.internal_convert_n_to_tensor(
values,
name=input_arg.name,
dtype=dtype if dtype else None,
@@ -441,7 +441,7 @@ class OpDefLibrary(object):
observed_types = []
for value in values:
try:
- converted_value = ops.convert_to_tensor(
+ converted_value = ops.internal_convert_to_tensor(
value, as_ref=input_arg.is_ref)
observed_types.append(converted_value.dtype.base_dtype.name)
except (TypeError, ValueError):
@@ -482,7 +482,7 @@ class OpDefLibrary(object):
default_dtype = default_type_attr_map[input_arg.type_attr]
try:
- values = ops.convert_to_tensor(
+ values = ops.internal_convert_to_tensor(
values,
name=input_arg.name,
dtype=dtype,
@@ -499,8 +499,8 @@ class OpDefLibrary(object):
repr(values), type(values).__name__))
except ValueError:
# What type does convert_to_tensor think it has?
- observed = ops.convert_to_tensor(values,
- as_ref=input_arg.is_ref).dtype.name
+ observed = ops.internal_convert_to_tensor(
+ values, as_ref=input_arg.is_ref).dtype.name
prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
(input_name, op_type_name, observed))
if input_arg.type != types_pb2.DT_INVALID:
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 68ddb7af84..c4c64f1577 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -592,7 +592,6 @@ register_dense_tensor_like_type(Tensor)
def convert_to_tensor(value,
dtype=None,
name=None,
- as_ref=False,
preferred_dtype=None):
"""Converts the given `value` to a `Tensor`.
@@ -624,8 +623,50 @@ def convert_to_tensor(value,
dtype: Optional element type for the returned tensor. If missing, the
type is inferred from the type of `value`.
name: Optional name to use if a new `Tensor` is created.
- as_ref: True if we want the result as a ref tensor. Only used if a new
- `Tensor` is created.
+ preferred_dtype: Optional element type for the returned tensor,
+ used when dtype is None. In some cases, a caller may not have a
+ dtype in mind when converting to a tensor, so preferred_dtype
+ can be used as a soft preference. If the conversion to
+ `preferred_dtype` is not possible, this argument has no effect.
+
+ Returns:
+ An `Output` based on `value`.
+
+ Raises:
+ TypeError: If no conversion function is registered for `value`.
+ RuntimeError: If a registered conversion function returns an invalid value.
+
+ """
+ return internal_convert_to_tensor(
+ value=value,
+ dtype=dtype,
+ name=name,
+ preferred_dtype=preferred_dtype,
+ as_ref=False)
+
+
+def internal_convert_to_tensor(value,
+ dtype=None,
+ name=None,
+ as_ref=False,
+ preferred_dtype=None):
+ """Converts the given `value` to an `Tensor`.
+
+ This function converts Python objects of various types to `Tensor`
+ objects. It accepts `Tensor` objects, numpy arrays, Python lists,
+ and Python scalars. For example:
+
+ This function can be useful when composing a new operation in Python
+ All standard Python op constructors apply this function to each of their
+ Tensor-valued inputs, which allows those ops to accept numpy arrays, Python
+ lists, and scalars in addition to `Tensor` objects.
+
+ Args:
+ value: An object whose type has a registered `Tensor` conversion function.
+ dtype: Optional element type for the returned tensor. If missing, the
+ type is inferred from the type of `value`.
+ name: Optional name to use if a new `Tensor` is created.
+ as_ref: True if we want the mutable view of Variables, if applicable.
preferred_dtype: Optional element type for the returned tensor,
used when dtype is None. In some cases, a caller may not have a
dtype in mind when converting to a tensor, so preferred_dtype
@@ -687,11 +728,11 @@ def convert_to_tensor(value,
% (error_prefix, value, type(value)))
-def convert_n_to_tensor(values,
- dtype=None,
- name=None,
- as_ref=False,
- preferred_dtype=None):
+def internal_convert_n_to_tensor(values,
+ dtype=None,
+ name=None,
+ as_ref=False,
+ preferred_dtype=None):
"""Converts `values` to a list of `Tensor` objects.
Args:
@@ -722,7 +763,7 @@ def convert_n_to_tensor(values,
for i, value in enumerate(values):
n = None if name is None else "%s_%d" % (name, i)
ret.append(
- convert_to_tensor(
+ internal_convert_to_tensor(
value,
dtype=dtype,
name=n,
@@ -731,8 +772,41 @@ def convert_n_to_tensor(values,
return ret
-def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
- as_ref=False):
+def convert_n_to_tensor(values,
+ dtype=None,
+ name=None,
+ preferred_dtype=None):
+ """Converts `values` to a list of `Tensor` objects.
+
+ Args:
+ values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
+ dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
+ name: (Optional.) A name prefix to used when a new `Tensor` is
+ created, in which case element `i` will be given the name `name
+ + '_' + i`.
+ preferred_dtype: Optional element type for the returned tensors,
+ used when dtype is None. In some cases, a caller may not have a
+ dtype in mind when converting to a tensor, so preferred_dtype
+ can be used as a soft preference. If the conversion to
+ `preferred_dtype` is not possible, this argument has no effect.
+
+ Returns:
+ A list of `Tensor` and/or `IndexedSlices` objects.
+
+ Raises:
+ TypeError: If no conversion function is registered for an element in
+ `values`.
+ RuntimeError: If a registered conversion function returns an invalid
+ value.
+ """
+ return internal_convert_n_to_tensor(values=values,
+ dtype=dtype,
+ name=name,
+ preferred_dtype=preferred_dtype,
+ as_ref=False)
+
+
+def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
"""Converts the given object to a `Tensor` or an `IndexedSlices`.
If `value` is an `IndexedSlices` or `SparseTensor` it is returned
@@ -745,6 +819,31 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
dtype: (Optional.) The required `DType` of the returned `Tensor` or
`IndexedSlices`.
name: (Optional.) A name to use if a new `Tensor` is created.
+
+ Returns:
+ An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
+
+ Raises:
+ ValueError: If `dtype` does not match the element type of `value`.
+ """
+ return internal_convert_to_tensor_or_indexed_slices(
+ value=value, dtype=dtype, name=name, as_ref=False)
+
+
+def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
+ as_ref=False):
+ """Converts the given object to an `Tensor` or an `IndexedSlices`.
+
+ If `value` is an `IndexedSlices` or `SparseTensor` it is returned
+ unmodified. Otherwise, it is converted to a `Tensor` using
+ `convert_to_tensor()`.
+
+ Args:
+ value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
+ by `convert_to_tensor()`.
+ dtype: (Optional.) The required `DType` of the returned `Tensor` or
+ `IndexedSlices`.
+ name: (Optional.) A name to use if a new `Tensor` is created.
as_ref: True if the caller wants the results as ref tensors.
Returns:
@@ -760,11 +859,14 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
% (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
return value
else:
- return convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
+ return internal_convert_to_tensor(value,
+ dtype=dtype,
+ name=name,
+ as_ref=as_ref)
-def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
- as_ref=False):
+def internal_convert_n_to_tensor_or_indexed_slices(values, dtype=None,
+ name=None, as_ref=False):
"""Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
@@ -798,11 +900,39 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
else:
n = None if name is None else "%s_%d" % (name, i)
ret.append(
- convert_to_tensor_or_indexed_slices(value, dtype=dtype, name=n,
- as_ref=as_ref))
+ internal_convert_to_tensor_or_indexed_slices(
+ value, dtype=dtype, name=n, as_ref=as_ref))
return ret
+def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
+ """Converts `values` to a list of `Output` or `IndexedSlices` objects.
+
+ Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
+ unmodified.
+
+ Args:
+ values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
+ can be consumed by `convert_to_tensor()`.
+ dtype: (Optional.) The required `DType` of the returned `Tensor`
+ `IndexedSlices`.
+ name: (Optional.) A name prefix to used when a new `Tensor` is
+ created, in which case element `i` will be given the name `name
+ + '_' + i`.
+
+ Returns:
+ A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
+
+ Raises:
+ TypeError: If no conversion function is registered for an element in
+ `values`.
+ RuntimeError: If a registered conversion function returns an invalid
+ value.
+ """
+ return internal_convert_n_to_tensor_or_indexed_slices(
+ values=values, dtype=dtype, name=name, as_ref=False)
+
+
def register_tensor_conversion_function(base_type, conversion_func,
priority=100):
"""Registers a function for converting objects of `base_type` to `Tensor`.
@@ -2848,7 +2978,7 @@ class Graph(object):
if not isinstance(op, Operation):
# We always want to colocate with the reference op.
- op = convert_to_tensor_or_indexed_slices(op, as_ref=True).op
+ op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
# By default, colocate_with resets the device function stack,
# since colocate_with is typically used in specific internal
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index 4ce92c4225..8b7bb130de 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -131,7 +131,8 @@ class SparseTensor(_TensorLike):
# values later if it is a VariableOp.
# TODO(touts): Consider adding mutable_values() when 'values'
# is a VariableOp and updating users of SparseTensor.
- values = ops.convert_to_tensor(values, name="values", as_ref=True)
+ values = ops.internal_convert_to_tensor(
+ values, name="values", as_ref=True)
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int64)
self._indices = indices
self._values = values
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 25a6c44869..42686f26c2 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -158,7 +158,7 @@ def _Identity(data, name=None):
Returns:
A Tensor with the same type and value as the input Tensor.
"""
- data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return gen_array_ops._ref_identity(data, name=name)
@@ -180,7 +180,7 @@ def _Identity(data, name=None):
def _NextIteration(data, name=None):
- data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return ref_next_iteration(data, name=name)
@@ -221,7 +221,7 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
Returns:
The same tensor as `data`.
"""
- data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
result = ref_enter(data, frame_name, is_constant, parallel_iterations,
@@ -270,7 +270,7 @@ def exit(data, name=None):
Returns:
The same tensor as `data`.
"""
- data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return gen_control_flow_ops._ref_exit(data, name)
@@ -311,8 +311,8 @@ def switch(data, pred, dtype=None, name=None):
to `output_true`, otherwise it goes to `output_false`.
"""
with ops.name_scope(name, "Switch", [data, pred]) as name:
- data = ops.convert_to_tensor_or_indexed_slices(data, dtype=dtype,
- name="data", as_ref=True)
+ data = ops.internal_convert_to_tensor_or_indexed_slices(
+ data, dtype=dtype, name="data", as_ref=True)
pred = ops.convert_to_tensor(pred, name="pred")
if isinstance(data, ops.Tensor):
return gen_control_flow_ops._switch(data, pred, name=name)
@@ -411,7 +411,7 @@ def merge(inputs, name=None):
if any([inp is None for inp in inputs]):
raise ValueError("At least one of the merge inputs is None: %s" % inputs)
with ops.name_scope(name, "Merge", inputs) as name:
- inputs = [ops.convert_to_tensor_or_indexed_slices(inp, as_ref=True)
+ inputs = [ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
for inp in inputs]
if all([isinstance(v, ops.Tensor) for v in inputs]):
if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index ea4c163ac9..f020c8cd0f 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -471,7 +471,7 @@ class BaseSaverBuilder(object):
else:
names_to_saveables[name] = [var]
else:
- var = ops.convert_to_tensor(var, as_ref=True)
+ var = ops.internal_convert_to_tensor(var, as_ref=True)
if not BaseSaverBuilder._IsVariable(var):
raise TypeError("Variable to save is not a Variable: %s" % var)
name = var.op.name
@@ -532,7 +532,7 @@ class BaseSaverBuilder(object):
# pylint: enable=protected-access
else:
# A variable or tensor.
- variable = ops.convert_to_tensor(op, as_ref=True)
+ variable = ops.internal_convert_to_tensor(op, as_ref=True)
if not BaseSaverBuilder._IsVariable(variable):
raise TypeError("names_to_saveables must be a dict mapping string "
"names to Tensors/Variables. Not a variable: %s" %