aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/keras/_impl/keras/estimator.py22
-rw-r--r--tensorflow/python/keras/_impl/keras/estimator_test.py24
2 files changed, 36 insertions, 10 deletions
diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py
index b922a6c683..c3c3fceb45 100644
--- a/tensorflow/python/keras/_impl/keras/estimator.py
+++ b/tensorflow/python/keras/_impl/keras/estimator.py
@@ -29,12 +29,14 @@ from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import tensor_util
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import models
from tensorflow.python.keras._impl.keras import optimizers
from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
from tensorflow.python.keras._impl.keras.engine.network import Network
from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import variables as variables_module
@@ -55,6 +57,17 @@ def _cast_tensor_to_floatx(x):
return math_ops.cast(x, K.floatx())
+def _convert_tensor(x):
+ """Create or cast tensor if needed."""
+ if not tensor_util.is_tensor(x):
+ # x is a numpy array
+ x = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(x)
+ if check_ops.is_numeric_tensor(x):
+ # is_numeric_tensor returns False if provided with a numpy array
+ x = _cast_tensor_to_floatx(x)
+ return x
+
+
def _any_variable_initalized():
"""Check if any variable has been initialized in the Keras model.
@@ -86,7 +99,7 @@ def _create_ordered_io(keras_model, estimator_io, is_input=True):
if isinstance(estimator_io, (list, tuple)):
# Case currently not supported by most built-in input_fn,
# but it's good to have for sanity
- return [_cast_tensor_to_floatx(x) for x in estimator_io]
+ return [_convert_tensor(x) for x in estimator_io]
elif isinstance(estimator_io, dict):
if is_input:
if keras_model._is_graph_network:
@@ -108,12 +121,12 @@ def _create_ordered_io(keras_model, estimator_io, is_input=True):
'It needs to match one '
'of the following: %s' % ('input' if is_input else 'output', key,
', '.join(keras_io_names)))
- tensors = [_cast_tensor_to_floatx(estimator_io[io_name])
+ tensors = [_convert_tensor(estimator_io[io_name])
for io_name in keras_io_names]
return tensors
else:
# Plain array.
- return _cast_tensor_to_floatx(estimator_io)
+ return _convert_tensor(estimator_io)
def _in_place_subclassed_model_reset(model):
@@ -274,8 +287,7 @@ def _clone_and_build_model(mode,
is_input=False)
else:
target_tensors = [
- _cast_tensor_to_floatx(
- sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
+ _convert_tensor(labels)
]
if keras_model._is_graph_network:
diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py
index 653cdc01e2..80fa87d041 100644
--- a/tensorflow/python/keras/_impl/keras/estimator_test.py
+++ b/tensorflow/python/keras/_impl/keras/estimator_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.keras._impl.keras.applications import mobilenet
from tensorflow.python.keras._impl.keras.optimizers import SGD
@@ -142,16 +143,20 @@ def randomize_io_type(array, name):
def multi_inputs_multi_outputs_model():
- # test multi-input layer
a = keras.layers.Input(shape=(16,), name='input_a')
b = keras.layers.Input(shape=(16,), name='input_b')
+ m = keras.layers.Input(shape=(8,), dtype='bool', name='input_m')
dense = keras.layers.Dense(8, name='dense_1')
+
a_2 = dense(a)
+ # Apply a mask
+ s_2 = keras.layers.Lambda(lambda k:
+ K.switch(k[0], k[1], K.zeros_like(k[1])))([m, a_2])
b_2 = dense(b)
- merged = keras.layers.concatenate([a_2, b_2], name='merge')
+ merged = keras.layers.concatenate([s_2, b_2], name='merge')
c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
- model = keras.models.Model(inputs=[a, b], outputs=[c, d])
+ model = keras.models.Model(inputs=[a, b, m], outputs=[c, d])
model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -352,18 +357,27 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
test_samples=50,
input_shape=(16,),
num_classes=2)
+ np.random.seed(_RANDOM_SEED)
+ (input_m_train, _), (input_m_test, _) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(8,),
+ num_classes=2)
+
c_train = keras.utils.to_categorical(c_train)
c_test = keras.utils.to_categorical(c_test)
d_train = keras.utils.to_categorical(d_train)
d_test = keras.utils.to_categorical(d_test)
def train_input_fn():
- input_dict = {'input_a': a_train, 'input_b': b_train}
+ input_dict = {'input_a': a_train, 'input_b': b_train,
+ 'input_m': input_m_train > 0}
output_dict = {'dense_2': c_train, 'dense_3': d_train}
return input_dict, output_dict
def eval_input_fn():
- input_dict = {'input_a': a_test, 'input_b': b_test}
+ input_dict = {'input_a': a_test, 'input_b': b_test,
+ 'input_m': input_m_test > 0}
output_dict = {'dense_2': c_test, 'dense_3': d_test}
return input_dict, output_dict