diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 14:23:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 14:23:23 -0700 |
commit | ad872f220df6808e8a5fcb926480f87cb2371dfd (patch) | |
tree | 8996f7bf76a0503c5002df44109bab59fbc6e2bc | |
parent | bb9c72ae54f3a4a16b851a811a20f93740f5f1d3 (diff) | |
parent | 4a1fdff581db18e3262daebbc1f9543936bf47d1 (diff) |
Merge pull request #21073 from yanboliang:to-json
PiperOrigin-RevId: 212332139
-rw-r--r-- | tensorflow/python/keras/engine/network.py | 5 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/topology_test.py | 17 |
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 10dd70cf23..5ef8d13487 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1576,7 +1576,10 @@ class Network(base_layer.Layer): def get_json_type(obj): # If obj is any numpy type if type(obj).__module__ == np.__name__: - return obj.item() + if isinstance(obj, np.ndarray): + return obj.tolist() + else: + return obj.item() # If obj is a python 'type' if type(obj).__name__ == type.__name__: diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 079c8dae71..1fcd77d7f6 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -912,6 +912,23 @@ class TopologyConstructionTest(test.TestCase): assert out.shape == (4, 3, 2, 1) self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4) + def test_constant_initializer_with_numpy(self): + + with self.test_session(): + initializer = keras.initializers.Constant(np.ones((3, 2))) + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,), + kernel_initializer=initializer)) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + json_str = model.to_json() + keras.models.model_from_json(json_str) + + if yaml is not None: + yaml_str = model.to_yaml() + keras.models.model_from_yaml(yaml_str) + class DeferredModeTest(test.TestCase): |