diff options
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 67 |
1 files changed, 43 insertions, 24 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 74390bd6a3..de03c6ac7f 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -219,6 +219,18 @@ def _NotNone(v): return v +def _FilterTuple(v): + if not isinstance(v, (list, tuple)): + return v + if isinstance(v, tuple): + if not any(isinstance(x, (list, tuple)) for x in v): + return None + if isinstance(v, list): + if not any(isinstance(x, (list, tuple)) for x in v): + return _FirstNotNone([None if isinstance(x, (list, tuple)) else x for x in v]) + return _FirstNotNone([_FilterTuple(x) for x in v]) + + def _FilterInt(v): if isinstance(v, (list, tuple)): return _FirstNotNone([_FilterInt(x) for x in v]) @@ -259,29 +271,29 @@ def _FilterNotTensor(v): _TF_TO_IS_OK = { - dtypes.bool: _FilterBool, - dtypes.complex128: _FilterComplex, - dtypes.complex64: _FilterComplex, - dtypes.float32: _FilterFloat, - dtypes.float64: _FilterFloat, - dtypes.int16: _FilterInt, - dtypes.int32: _FilterInt, - dtypes.int64: _FilterInt, - dtypes.int8: _FilterInt, - dtypes.qint16: _FilterInt, - dtypes.qint32: _FilterInt, - dtypes.qint8: _FilterInt, - dtypes.quint16: _FilterInt, - dtypes.quint8: _FilterInt, - dtypes.string: _FilterStr, - dtypes.uint16: _FilterInt, - dtypes.uint8: _FilterInt, + dtypes.bool: [_FilterBool], + dtypes.complex128: [_FilterComplex], + dtypes.complex64: [_FilterComplex], + dtypes.float32: [_FilterFloat], + dtypes.float64: [_FilterFloat], + dtypes.int16: [_FilterInt], + dtypes.int32: [_FilterInt], + dtypes.int64: [_FilterInt], + dtypes.int8: [_FilterInt], + dtypes.qint16: [_FilterInt, _FilterTuple], + dtypes.qint32: [_FilterInt, _FilterTuple], + dtypes.qint8: [_FilterInt, _FilterTuple], + dtypes.quint16: [_FilterInt, _FilterTuple], + dtypes.quint8: [_FilterInt, _FilterTuple], + dtypes.string: [_FilterStr], + dtypes.uint16: [_FilterInt], + dtypes.uint8: [_FilterInt], } def _AssertCompatible(values, dtype): - fn = _TF_TO_IS_OK.get(dtype, _FilterNotTensor) - mismatch = fn(values) + fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor]) + mismatch = _FirstNotNone([fn(values) for fn in fn_list]) if mismatch is not None: if dtype is None: raise TypeError("List of Tensors when single Tensor expected") @@ -290,13 +302,14 @@ def _AssertCompatible(values, dtype): (dtype.name, repr(mismatch), type(mismatch).__name__)) -def make_tensor_proto(values, dtype=None, shape=None): +def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): """Create a TensorProto. Args: - values: Values to put in the TensorProto. - dtype: Optional tensor_pb2 DataType value. - shape: List of integers representing the dimensions of tensor. + values: Values to put in the TensorProto. + dtype: Optional tensor_pb2 DataType value. + shape: List of integers representing the dimensions of tensor. + verify_shape: Boolean that enables verification of a shape of values. Returns: A TensorProto. Depending on the type, it may contain data in the @@ -306,7 +319,8 @@ def make_tensor_proto(values, dtype=None, shape=None): Raises: TypeError: if unsupported types are provided. - ValueError: if arguments have inappropriate values. + ValueError: if arguments have inappropriate values or if verify_shape is + True and shape of values is not equals to a shape from the argument. make_tensor_proto accepts "values" of a python scalar, a python list, a numpy ndarray, or a numpy scalar. @@ -396,6 +410,11 @@ def make_tensor_proto(values, dtype=None, shape=None): shape_size = np.prod(shape) is_same_size = shape_size == nparray.size + if verify_shape: + if not nparray.shape == tuple(shape): + raise TypeError("Expected Tensor's shape: %s, got %s." % + (tuple(shape), nparray.shape)) + if nparray.size > shape_size: raise ValueError( "Too many elements provided. Needed at most %d, but received %d" % |