aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/tensor_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r--tensorflow/python/framework/tensor_util.py67
1 files changed, 43 insertions, 24 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 74390bd6a3..de03c6ac7f 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -219,6 +219,18 @@ def _NotNone(v):
return v
+def _FilterTuple(v):
+ if not isinstance(v, (list, tuple)):
+ return v
+ if isinstance(v, tuple):
+ if not any(isinstance(x, (list, tuple)) for x in v):
+ return None
+ if isinstance(v, list):
+ if not any(isinstance(x, (list, tuple)) for x in v):
+ return _FirstNotNone([None if isinstance(x, (list, tuple)) else x for x in v])
+ return _FirstNotNone([_FilterTuple(x) for x in v])
+
+
def _FilterInt(v):
if isinstance(v, (list, tuple)):
return _FirstNotNone([_FilterInt(x) for x in v])
@@ -259,29 +271,29 @@ def _FilterNotTensor(v):
_TF_TO_IS_OK = {
- dtypes.bool: _FilterBool,
- dtypes.complex128: _FilterComplex,
- dtypes.complex64: _FilterComplex,
- dtypes.float32: _FilterFloat,
- dtypes.float64: _FilterFloat,
- dtypes.int16: _FilterInt,
- dtypes.int32: _FilterInt,
- dtypes.int64: _FilterInt,
- dtypes.int8: _FilterInt,
- dtypes.qint16: _FilterInt,
- dtypes.qint32: _FilterInt,
- dtypes.qint8: _FilterInt,
- dtypes.quint16: _FilterInt,
- dtypes.quint8: _FilterInt,
- dtypes.string: _FilterStr,
- dtypes.uint16: _FilterInt,
- dtypes.uint8: _FilterInt,
+ dtypes.bool: [_FilterBool],
+ dtypes.complex128: [_FilterComplex],
+ dtypes.complex64: [_FilterComplex],
+ dtypes.float32: [_FilterFloat],
+ dtypes.float64: [_FilterFloat],
+ dtypes.int16: [_FilterInt],
+ dtypes.int32: [_FilterInt],
+ dtypes.int64: [_FilterInt],
+ dtypes.int8: [_FilterInt],
+ dtypes.qint16: [_FilterInt, _FilterTuple],
+ dtypes.qint32: [_FilterInt, _FilterTuple],
+ dtypes.qint8: [_FilterInt, _FilterTuple],
+ dtypes.quint16: [_FilterInt, _FilterTuple],
+ dtypes.quint8: [_FilterInt, _FilterTuple],
+ dtypes.string: [_FilterStr],
+ dtypes.uint16: [_FilterInt],
+ dtypes.uint8: [_FilterInt],
}
def _AssertCompatible(values, dtype):
- fn = _TF_TO_IS_OK.get(dtype, _FilterNotTensor)
- mismatch = fn(values)
+ fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor])
+ mismatch = _FirstNotNone([fn(values) for fn in fn_list])
if mismatch is not None:
if dtype is None:
raise TypeError("List of Tensors when single Tensor expected")
@@ -290,13 +302,14 @@ def _AssertCompatible(values, dtype):
(dtype.name, repr(mismatch), type(mismatch).__name__))
-def make_tensor_proto(values, dtype=None, shape=None):
+def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
"""Create a TensorProto.
Args:
- values: Values to put in the TensorProto.
- dtype: Optional tensor_pb2 DataType value.
- shape: List of integers representing the dimensions of tensor.
+ values: Values to put in the TensorProto.
+ dtype: Optional tensor_pb2 DataType value.
+ shape: List of integers representing the dimensions of tensor.
+ verify_shape: Boolean that enables verification of a shape of values.
Returns:
A TensorProto. Depending on the type, it may contain data in the
@@ -306,7 +319,8 @@ def make_tensor_proto(values, dtype=None, shape=None):
Raises:
TypeError: if unsupported types are provided.
- ValueError: if arguments have inappropriate values.
+ ValueError: if arguments have inappropriate values or if verify_shape is
+ True and shape of values is not equals to a shape from the argument.
make_tensor_proto accepts "values" of a python scalar, a python list, a
numpy ndarray, or a numpy scalar.
@@ -396,6 +410,11 @@ def make_tensor_proto(values, dtype=None, shape=None):
shape_size = np.prod(shape)
is_same_size = shape_size == nparray.size
+ if verify_shape:
+ if not nparray.shape == tuple(shape):
+ raise TypeError("Expected Tensor's shape: %s, got %s." %
+ (tuple(shape), nparray.shape))
+
if nparray.size > shape_size:
raise ValueError(
"Too many elements provided. Needed at most %d, but received %d" %