aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-02 15:21:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 15:24:07 -0700
commit4704ae7af1918755d72f159f49d98d35da6eb6fa (patch)
treeba461e70d073bf60260ce9c8b65a4ee2e04e6781 /tensorflow/python/feature_column
parent9180cc254dff42368af126aa68eb82823ef67736 (diff)
Optimize LogicalOr and LogicalAnd with all true or false inputs:
LogicalOr(x, true) = true LogicalOr(x, false) = x LogicalAnd(x, true) = x LogicalAnd(x, false) = false and similar if the first argument is constant. PiperOrigin-RevId: 195161140
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py20
1 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index d963dd9b55..b06540489f 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -25,6 +25,8 @@ import numpy as np
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -54,8 +56,8 @@ from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl
-def _initialized_session():
- sess = session.Session()
+def _initialized_session(config=None):
+ sess = session.Session(config=config)
sess.run(variables_lib.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
return sess
@@ -6191,7 +6193,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
'values': ((.5,), (1.,))
}, (column,),
sparse_combiner='mean')
- with _initialized_session():
+ # Disabling the constant folding optimizer here since it changes the
+ # error message differently on CPU and GPU.
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ with _initialized_session(config):
with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
predictions.eval()
@@ -6284,7 +6291,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
'values': ((.5,), (1.,))
}, (column,),
sparse_combiner='mean')
- with _initialized_session():
+ # Disabling the constant folding optimizer here since it changes the
+ # error message differently on CPU and GPU.
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ with _initialized_session(config):
with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
predictions.eval()