aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 14:23:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 14:23:23 -0700
commitad872f220df6808e8a5fcb926480f87cb2371dfd (patch)
tree8996f7bf76a0503c5002df44109bab59fbc6e2bc
parentbb9c72ae54f3a4a16b851a811a20f93740f5f1d3 (diff)
parent4a1fdff581db18e3262daebbc1f9543936bf47d1 (diff)
Merge pull request #21073 from yanboliang:to-json
PiperOrigin-RevId: 212332139
-rw-r--r--tensorflow/python/keras/engine/network.py5
-rw-r--r--tensorflow/python/keras/engine/topology_test.py17
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 10dd70cf23..5ef8d13487 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1576,7 +1576,10 @@ class Network(base_layer.Layer):
def get_json_type(obj):
# If obj is any numpy type
if type(obj).__module__ == np.__name__:
- return obj.item()
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return obj.item()
# If obj is a python 'type'
if type(obj).__name__ == type.__name__:
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 079c8dae71..1fcd77d7f6 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -912,6 +912,23 @@ class TopologyConstructionTest(test.TestCase):
assert out.shape == (4, 3, 2, 1)
self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4)
+ def test_constant_initializer_with_numpy(self):
+
+ with self.test_session():
+ initializer = keras.initializers.Constant(np.ones((3, 2)))
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,),
+ kernel_initializer=initializer))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+ json_str = model.to_json()
+ keras.models.model_from_json(json_str)
+
+ if yaml is not None:
+ yaml_str = model.to_yaml()
+ keras.models.model_from_yaml(yaml_str)
+
class DeferredModeTest(test.TestCase):