aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-01-25 17:20:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 17:23:28 -0800
commitfbd3e8a2c01d83a6aa6cca044fe5678d20035451 (patch)
treecbd63281a4599f8273278f9554ded595b216f65a
parentc89c452cd7a1675a6e2332d09379469320197a8c (diff)
Make kernel_tests/scalar_test.py work with the C API enabled.
This also moves the set_producer_version function from a specific test file to test_util.py, since it's needed in two test files now. PiperOrigin-RevId: 183316990
-rw-r--r--tensorflow/python/framework/test_util.py12
-rw-r--r--tensorflow/python/kernel_tests/scalar_test.py4
-rw-r--r--tensorflow/python/ops/nn_batchnorm_test.py15
3 files changed, 17 insertions, 14 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 0133318456..6a7e1d0c89 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -53,6 +53,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import versions
@@ -1460,3 +1461,14 @@ def get_node_def_from_graph(node_name, graph_def):
if node_def.name == node_name:
return node_def
return None
+
+
+def set_producer_version(graph, producer_version):
+ """Sets graph.graph_def_versions.producer to `producer_version`."""
+ # The C API doesn't expose altering GraphDefVersions. We can indirectly set
+ # it via import_graph_def though.
+ graph_def = graph_pb2.GraphDef()
+ graph_def.versions.producer = producer_version
+ with graph.as_default():
+ importer.import_graph_def(graph_def)
+ assert graph.graph_def_versions.producer, producer_version
diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py
index b34426cc21..e65241981e 100644
--- a/tensorflow/python/kernel_tests/scalar_test.py
+++ b/tensorflow/python/kernel_tests/scalar_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
@@ -30,6 +31,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
+@test_util.with_c_api
class ScalarTest(test.TestCase):
def check(self, op, args, error, correct=None):
@@ -51,7 +53,7 @@ class ScalarTest(test.TestCase):
# Test various GraphDef versions
for version in strict + lenient:
with ops.Graph().as_default() as g:
- g.graph_def_versions.producer = version
+ test_util.set_producer_version(g, version)
with self.test_session(graph=g) as sess:
feed = {}
xs = placeholders(args, feed)
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index fc013b565b..eebfb17085 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -21,10 +21,8 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -40,15 +38,6 @@ from tensorflow.python.platform import test
@test_util.with_c_api
class BatchNormalizationTest(test.TestCase):
- def SetProducerVersion(self, graph, producer_version):
- # The C API doesn't expose altering GraphDefVersions. We can indirectly set
- # it via import_graph_def though.
- graph_def = graph_pb2.GraphDef()
- graph_def.versions.producer = producer_version
- with graph.as_default():
- importer.import_graph_def(graph_def)
- assert graph.graph_def_versions.producer, producer_version
-
def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization, shift_after_normalization):
y = (x - m) / np.sqrt(v + epsilon)
@@ -65,7 +54,7 @@ class BatchNormalizationTest(test.TestCase):
def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization):
"""Original implementation."""
- self.SetProducerVersion(ops.get_default_graph(), 8)
+ test_util.set_producer_version(ops.get_default_graph(), 8)
return gen_nn_ops._batch_norm_with_global_normalization(
x, m, v, beta, gamma, epsilon, scale_after_normalization)
# pylint: enable=protected-access
@@ -233,7 +222,7 @@ class BatchNormalizationTest(test.TestCase):
epsilon = 0.001
for scale_after_normalization in [True, False]:
# _batch_norm_with_global_normalization_grad is deprecated in v9
- self.SetProducerVersion(ops.get_default_graph(), 8)
+ test_util.set_producer_version(ops.get_default_graph(), 8)
grad = gen_nn_ops._batch_norm_with_global_normalization_grad(
x, m, v, gamma, backprop, epsilon, scale_after_normalization)
dx, dm, dv, db, dg = grad