diff options
author | 2018-09-06 17:25:10 -0700 | |
---|---|---|
committer | 2018-09-06 17:29:48 -0700 | |
commit | 25f93ba1f880e8b092be611d9a343b18136a267b (patch) | |
tree | 14c1282420d20ac814f831a4264b600768ae22c8 /tensorflow/python/keras/engine/training_arrays.py | |
parent | d57cac9d95c8a10650e98f38ca9572c7bd6c6548 (diff) |
Adding support for FeatureColumn input in Keras models. Modifies the Model.fit() function to support taking in dictionaries of features in.
Support for functional models coming in a subsequent change.
PiperOrigin-RevId: 211897153
Diffstat (limited to 'tensorflow/python/keras/engine/training_arrays.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_arrays.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index e2c458c65f..95b864bef0 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -55,7 +55,7 @@ def fit_loop(model, Arguments: model: Keras Model instance. - inputs: List of input arrays. + inputs: Either a list of arrays or a dictionary. targets: List of target arrays. sample_weights: Optional list of sample weight arrays. batch_size: Integer batch size or None if unknown. @@ -88,6 +88,7 @@ def fit_loop(model, sample_weights = sample_weights or [] val_sample_weights = val_sample_weights or [] + inputs = training_utils.ModelInputs(inputs).as_list() if model.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [1] else: @@ -262,6 +263,7 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): model._make_predict_function() f = model.predict_function + inputs = training_utils.ModelInputs(inputs).as_list() if model.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = inputs + [0] else: @@ -368,6 +370,7 @@ def test_loop(model, f = model.test_function sample_weights = sample_weights or [] + inputs = training_utils.ModelInputs(inputs).as_list() if model.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [0] else: |