diff options
Diffstat (limited to 'tensorflow/contrib/training/python/training/hparam.py')
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam.py | 58 |
1 files changed, 4 insertions, 54 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 7db625cdd5..391899b34f 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function import json -import numbers import re import six @@ -77,7 +76,7 @@ def _process_scalar_value(name, parse_fn, var_type, m_dict, values, function. Raises: - ValueError: If the name has already been used. + ValueError: If the name has already been sued. """ try: parsed_value = parse_fn(m_dict['val']) @@ -139,54 +138,6 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values, _parse_fail(name, var_type, m_dict['vals'], values) -def _cast_to_type_if_compatible(name, param_type, value): - """Cast hparam to the provided type, if compatible. - - Args: - name: Name of the hparam to be cast. - param_type: The type of the hparam. - value: The value to be cast, if compatible. - - Returns: - The result of casting `value` to `param_type`. - - Raises: - ValueError: If the type of `value` is not compatible with param_type. - * If `param_type` is a string type, but `value` is not. - * If `param_type` is a boolean, but `value` is not, or vice versa. - * If `param_type` is an integer type, but `value` is not. - * If `param_type` is a float type, but `value` is not a numeric type. - """ - fail_msg = ( - "Could not cast hparam '%s' of type '%s' from value %r" % - (name, param_type, value)) - - # Some callers use None, for which we can't do any casting/checking. :( - if issubclass(param_type, type(None)): - return value - - # Avoid converting a non-string type to a string. - if (issubclass(param_type, (six.string_types, six.binary_type)) and - not isinstance(value, (six.string_types, six.binary_type))): - raise ValueError(fail_msg) - - # Avoid converting a number or string type to a boolean or vice versa. - if issubclass(param_type, bool) != isinstance(value, bool): - raise ValueError(fail_msg) - - # Avoid converting float to an integer (the reverse is fine). - if (issubclass(param_type, numbers.Integral) and - not isinstance(value, numbers.Integral)): - raise ValueError(fail_msg) - - # Avoid converting a non-numeric type to a numeric type. - if (issubclass(param_type, numbers.Number) and - not isinstance(value, numbers.Number)): - raise ValueError(fail_msg) - - return param_type(value) - - def parse_values(values, type_map): """Parses hyperparameter values from a string into a python map. @@ -487,18 +438,17 @@ class HParams(object): Raises: ValueError: If there is a type mismatch. """ - param_type, is_list = self._hparam_types[name] + _, is_list = self._hparam_types[name] if isinstance(value, list): if not is_list: raise ValueError( 'Must not pass a list for single-valued parameter: %s' % name) - setattr(self, name, [ - _cast_to_type_if_compatible(name, param_type, v) for v in value]) + setattr(self, name, value) else: if is_list: raise ValueError( 'Must pass a list for multi-valued parameter: %s.' % name) - setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + setattr(self, name, value) def parse(self, values): """Override hyperparameter values, parsing new values from a string. |