aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training/python/training/hparam.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/training/python/training/hparam.py')
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py58
1 files changed, 4 insertions, 54 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index 7db625cdd5..391899b34f 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
import json
-import numbers
import re
import six
@@ -77,7 +76,7 @@ def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
function.
Raises:
- ValueError: If the name has already been used.
+ ValueError: If the name has already been sued.
"""
try:
parsed_value = parse_fn(m_dict['val'])
@@ -139,54 +138,6 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values,
_parse_fail(name, var_type, m_dict['vals'], values)
-def _cast_to_type_if_compatible(name, param_type, value):
- """Cast hparam to the provided type, if compatible.
-
- Args:
- name: Name of the hparam to be cast.
- param_type: The type of the hparam.
- value: The value to be cast, if compatible.
-
- Returns:
- The result of casting `value` to `param_type`.
-
- Raises:
- ValueError: If the type of `value` is not compatible with param_type.
- * If `param_type` is a string type, but `value` is not.
- * If `param_type` is a boolean, but `value` is not, or vice versa.
- * If `param_type` is an integer type, but `value` is not.
- * If `param_type` is a float type, but `value` is not a numeric type.
- """
- fail_msg = (
- "Could not cast hparam '%s' of type '%s' from value %r" %
- (name, param_type, value))
-
- # Some callers use None, for which we can't do any casting/checking. :(
- if issubclass(param_type, type(None)):
- return value
-
- # Avoid converting a non-string type to a string.
- if (issubclass(param_type, (six.string_types, six.binary_type)) and
- not isinstance(value, (six.string_types, six.binary_type))):
- raise ValueError(fail_msg)
-
- # Avoid converting a number or string type to a boolean or vice versa.
- if issubclass(param_type, bool) != isinstance(value, bool):
- raise ValueError(fail_msg)
-
- # Avoid converting float to an integer (the reverse is fine).
- if (issubclass(param_type, numbers.Integral) and
- not isinstance(value, numbers.Integral)):
- raise ValueError(fail_msg)
-
- # Avoid converting a non-numeric type to a numeric type.
- if (issubclass(param_type, numbers.Number) and
- not isinstance(value, numbers.Number)):
- raise ValueError(fail_msg)
-
- return param_type(value)
-
-
def parse_values(values, type_map):
"""Parses hyperparameter values from a string into a python map.
@@ -487,18 +438,17 @@ class HParams(object):
Raises:
ValueError: If there is a type mismatch.
"""
- param_type, is_list = self._hparam_types[name]
+ _, is_list = self._hparam_types[name]
if isinstance(value, list):
if not is_list:
raise ValueError(
'Must not pass a list for single-valued parameter: %s' % name)
- setattr(self, name, [
- _cast_to_type_if_compatible(name, param_type, v) for v in value])
+ setattr(self, name, value)
else:
if is_list:
raise ValueError(
'Must pass a list for multi-valued parameter: %s.' % name)
- setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+ setattr(self, name, value)
def parse(self, values):
"""Override hyperparameter values, parsing new values from a string.