aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/dtypes.py14
-rw-r--r--tensorflow/python/framework/graph_util_impl.py2
-rw-r--r--tensorflow/python/framework/graph_util_test.py2
-rw-r--r--tensorflow/python/framework/load_library.py2
-rw-r--r--tensorflow/python/framework/python_op_gen.i8
-rw-r--r--tensorflow/python/framework/test_util.py2
6 files changed, 17 insertions, 13 deletions
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 807582bd7e..7f9ef53457 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -700,11 +700,13 @@ def as_dtype(type_value):
if type_value.type == np.string_ or type_value.type == np.unicode_:
return string
- for key, val in _NP_TO_TF:
- try:
- if key == type_value:
- return val
- except TypeError as e:
- raise TypeError("Cannot convert {} to a dtype. {}".format(type_value, e))
+ if isinstance(type_value, (type, np.dtype)):
+ for key, val in _NP_TO_TF:
+ try:
+ if key == type_value:
+ return val
+ except TypeError as e:
+ raise TypeError("Cannot convert {} to a dtype. {}".format(
+ type_value, e))
raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index 910364364c..394fac6c85 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -285,7 +285,7 @@ def convert_variables_to_constants(sess,
output_graph_def.node.extend([output_node])
output_graph_def.library.CopyFrom(inference_graph.library)
- print("Converted %d variables to const ops." % how_many_converted)
+ logging.info("Converted %d variables to const ops.", how_many_converted)
return output_graph_def
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index b618152b02..2dafb94ba7 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -209,7 +209,7 @@ class DeviceFunctionsTest(test.TestCase):
defun_node, 2.0, name="output_node")
with session.Session() as sess:
- init = variables.initialize_variables([variable_node])
+ init = variables.variables_initializer([variable_node])
sess.run(init)
output = sess.run(output_node)
self.assertNear(4.0, output, 0.00001)
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 535c6017f5..9a8477debb 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -58,7 +58,7 @@ def load_op_library(library_filename):
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
op_list.ParseFromString(compat.as_bytes(op_list_str))
- wrappers = py_tf.GetPythonWrappers(op_list_str)
+ wrappers = py_tf.GetEagerPythonWrappers(op_list_str)
# Delete the library handle to release any memory held in C
# that are no longer needed.
diff --git a/tensorflow/python/framework/python_op_gen.i b/tensorflow/python/framework/python_op_gen.i
index 26ec4e8e66..efcce2f209 100644
--- a/tensorflow/python/framework/python_op_gen.i
+++ b/tensorflow/python/framework/python_op_gen.i
@@ -16,10 +16,10 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
%{
-#include "tensorflow/python/framework/python_op_gen.h"
+#include "tensorflow/python/eager/python_eager_op_gen.h"
%}
-// Input typemap for GetPythonWrappers.
+// Input typemap for GetEagerPythonWrappers.
// Accepts a python object of 'bytes' type, and converts it to
// a const char* pointer and size_t length. The default typemap
// going from python bytes to const char* tries to decode the
@@ -37,5 +37,5 @@ limitations under the License.
%ignoreall;
-%unignore tensorflow::GetPythonWrappers;
-%include "tensorflow/python/framework/python_op_gen.h"
+%unignore tensorflow::GetEagerPythonWrappers;
+%include "tensorflow/python/eager/python_eager_op_gen.h"
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index f954b9d6c7..5a8bc43727 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1014,6 +1014,8 @@ class TensorFlowTestCase(googletest.TestCase):
config.graph_options.optimizer_options.opt_level = -1
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
return config
if graph is None: