aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-11-13 19:08:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-13 19:11:48 -0800
commit7f02a3cddf08fa63a279a89ada2600d18399c383 (patch)
tree4c10b631e5d82dcad7132e409d3e96f05a16355f
parentf9e3e8d8731daf338b6dc743aef84c35740ca037 (diff)
Re-order arguments on the replicated model_fn.
This supports the use cases that call Estimator's model_fn via positional arguments. The right order is defined by Estimator as follows: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/estimator/estimator.py#L102. PiperOrigin-RevId: 175624067
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py32
2 files changed, 17 insertions, 17 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index 0848c5f62f..dcc48d1fd9 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -145,7 +145,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None):
'server device is going to be {}.'.format(
devices, local_ps_device))
- def replicated_model_fn(mode, features, labels, params=None, config=None):
+ def replicated_model_fn(features, labels, mode, params=None, config=None):
"""Replicated version of `model_fn` to be used instead."""
feature_shards, label_shards = _split_batch(
features, labels, len(devices), device=local_ps_device)
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 7fb1065ac0..5a1982f5eb 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -189,8 +189,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
- features, labels, self.params)
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
session.run(variables.global_variables_initializer())
# loss = feature * c - label
@@ -219,8 +219,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
devices=['/gpu:0', '/gpu:1'])
# This call is going to fail if `replicated_model_fn` is still passing
# `params` inside `optimizer_fn`, even though the latter doesn't take any:
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
- features, labels, self.params)
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
del estimator_spec
def test_eval(self):
@@ -230,8 +230,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
- labels, self.params)
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
session.run(variables.local_variables_initializer())
session.run(variables.global_variables_initializer())
@@ -259,8 +259,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
- features, labels, self.params)
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
session.run(variables.global_variables_initializer())
self.assertAllClose({
@@ -274,8 +274,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, self.optimizer_fn)
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
- features, labels, self.params)
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
session.run(variables.global_variables_initializer())
# loss = feature * c - label
@@ -296,8 +296,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
- labels, self.params)
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
session.run(variables.local_variables_initializer())
session.run(variables.global_variables_initializer())
@@ -324,8 +324,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
- features, labels, self.params)
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
session.run(variables.global_variables_initializer())
self.assertAllClose({
@@ -778,8 +778,8 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
- estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
- features, labels, {})
+ estimator_spec = replicated_model_fn(features, labels,
+ model_fn_lib.ModeKeys.PREDICT, {})
session.run(variables.global_variables_initializer())
return estimator_spec