diff options
author | 2017-11-13 19:08:03 -0800 | |
---|---|---|
committer | 2017-11-13 19:11:48 -0800 | |
commit | 7f02a3cddf08fa63a279a89ada2600d18399c383 (patch) | |
tree | 4c10b631e5d82dcad7132e409d3e96f05a16355f | |
parent | f9e3e8d8731daf338b6dc743aef84c35740ca037 (diff) |
Re-order arguments on the replicated model_fn.
This supports the use cases that call Estimator's model_fn via positional arguments.
The right order is defined by Estimator as follows: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/estimator/estimator.py#L102.
PiperOrigin-RevId: 175624067
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py | 32 |
2 files changed, 17 insertions, 17 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index 0848c5f62f..dcc48d1fd9 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -145,7 +145,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): 'server device is going to be {}.'.format( devices, local_ps_device)) - def replicated_model_fn(mode, features, labels, params=None, config=None): + def replicated_model_fn(features, labels, mode, params=None, config=None): """Replicated version of `model_fn` to be used instead.""" feature_shards, label_shards = _split_batch( features, labels, len(devices), device=local_ps_device) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 7fb1065ac0..5a1982f5eb 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -189,8 +189,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) # loss = feature * c - label @@ -219,8 +219,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): devices=['/gpu:0', '/gpu:1']) # This call is going to fail if `replicated_model_fn` is still passing # `params` inside `optimizer_fn`, even though the latter doesn't take any: - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) del estimator_spec def test_eval(self): @@ -230,8 +230,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, - labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) session.run(variables.global_variables_initializer()) @@ -259,8 +259,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) self.assertAllClose({ @@ -274,8 +274,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) # loss = feature * c - label @@ -296,8 +296,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, - labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) session.run(variables.global_variables_initializer()) @@ -324,8 +324,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) self.assertAllClose({ @@ -778,8 +778,8 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase): replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, {}) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.PREDICT, {}) session.run(variables.global_variables_initializer()) return estimator_spec |