aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_arrays.py
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-06 17:25:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 17:29:48 -0700
commit25f93ba1f880e8b092be611d9a343b18136a267b (patch)
tree14c1282420d20ac814f831a4264b600768ae22c8 /tensorflow/python/keras/engine/training_arrays.py
parentd57cac9d95c8a10650e98f38ca9572c7bd6c6548 (diff)
Adding support for FeatureColumn input in Keras models. Modifies the Model.fit() function to support taking in dictionaries of features in.
Support for functional models coming in a subsequent change. PiperOrigin-RevId: 211897153
Diffstat (limited to 'tensorflow/python/keras/engine/training_arrays.py')
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index e2c458c65f..95b864bef0 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -55,7 +55,7 @@ def fit_loop(model,
Arguments:
model: Keras Model instance.
- inputs: List of input arrays.
+ inputs: Either a list of arrays or a dictionary.
targets: List of target arrays.
sample_weights: Optional list of sample weight arrays.
batch_size: Integer batch size or None if unknown.
@@ -88,6 +88,7 @@ def fit_loop(model,
sample_weights = sample_weights or []
val_sample_weights = val_sample_weights or []
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + targets + sample_weights + [1]
else:
@@ -262,6 +263,7 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None):
model._make_predict_function()
f = model.predict_function
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + [0]
else:
@@ -368,6 +370,7 @@ def test_loop(model,
f = model.test_function
sample_weights = sample_weights or []
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + targets + sample_weights + [0]
else: