diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-02 15:21:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-02 15:24:07 -0700 |
commit | 4704ae7af1918755d72f159f49d98d35da6eb6fa (patch) | |
tree | ba461e70d073bf60260ce9c8b65a4ee2e04e6781 /tensorflow/python/feature_column | |
parent | 9180cc254dff42368af126aa68eb82823ef67736 (diff) |
Optimize LogicalOr and LogicalAnd with all true or false inputs:
LogicalOr(x, true) = true
LogicalOr(x, false) = x
LogicalAnd(x, true) = x
LogicalAnd(x, false) = false
and similar if the first argument is constant.
PiperOrigin-RevId: 195161140
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index d963dd9b55..b06540489f 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -25,6 +25,8 @@ import numpy as np from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -54,8 +56,8 @@ from tensorflow.python.training import coordinator from tensorflow.python.training import queue_runner_impl -def _initialized_session(): - sess = session.Session() +def _initialized_session(config=None): + sess = session.Session(config=config) sess.run(variables_lib.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) return sess @@ -6191,7 +6193,12 @@ class WeightedCategoricalColumnTest(test.TestCase): 'values': ((.5,), (1.,)) }, (column,), sparse_combiner='mean') - with _initialized_session(): + # Disabling the constant folding optimizer here since it changes the + # error message differently on CPU and GPU. + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with _initialized_session(config): with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): predictions.eval() @@ -6284,7 +6291,12 @@ class WeightedCategoricalColumnTest(test.TestCase): 'values': ((.5,), (1.,)) }, (column,), sparse_combiner='mean') - with _initialized_session(): + # Disabling the constant folding optimizer here since it changes the + # error message differently on CPU and GPU. + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with _initialized_session(config): with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): predictions.eval() |