diff options
-rw-r--r-- | tensorflow/c/eager/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 57 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 14 | ||||
-rw-r--r-- | tensorflow/contrib/rpc/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/dynamic_stitch_op.cc | 1 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 18 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 3 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util.py | 24 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 24 | ||||
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 2 |
11 files changed, 145 insertions, 6 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index a2d96357ac..3e14c10727 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -41,6 +41,8 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + # TODO(b/74620627): move this here + "//tensorflow/python:cpp_shape_inference_proto_cc", ], }) + select({ "//tensorflow:with_xla_support": [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 393851d13c..369342b142 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/python/framework/cpp_shape_inference.pb.h" using tensorflow::int64; using tensorflow::string; @@ -1015,6 +1016,62 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, ctx->context.RunMetadataProto()->Clear(); } +void TFE_GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + TF_Buffer* output_proto, + TF_Status* status) { + tensorflow::Node* node = &output.oper->node; + tensorflow::CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) { + output_proto->data = nullptr; + output_proto->length = 0; + output_proto->data_deallocator = nullptr; + return; + } + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + } + } + status->status = MessageToBuffer(handle_data, output_proto); +} + +void TFE_SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::CppShapeInferenceResult::HandleData handle_data; + if (!handle_data.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Couldn't deserialize HandleData proto"); + return; + } + DCHECK(handle_data.is_set()); + + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&output.oper->node); + + std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types; + for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { + tensorflow::shape_inference::ShapeHandle shape; + status->status = + ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); + if (status->status.ok()) return; + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + } + ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); +} + namespace { TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, TF_Status* status) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 3926c22ce1..15ac0f376c 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -329,6 +329,20 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource tensor, or otherwise returns an empty buffer. +TF_CAPI_EXPORT extern void TFE_GetResourceHandleShapeAndType( + TF_Graph* graph, TF_Output output, TF_Buffer* output_proto, + TF_Status* status); + +// Sets `output` based on `proto`, which should be a serialized +// CppShapeInferenceResult::HandleData proto. +TF_CAPI_EXPORT extern void TFE_SetResourceHandleShapeAndType(TF_Graph* graph, + TF_Output output, + const void* proto, + size_t proto_len, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD index 2311c15a68..f3e6731213 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD +++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD @@ -28,6 +28,7 @@ py_library( py_library( name = "rpc_op_test_base", srcs = ["rpc_op_test_base.py"], + tags = ["notsan"], deps = [ ":test_example_proto_py", "//tensorflow/contrib/proto", diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc index f018499f6c..b01db91720 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -326,6 +326,7 @@ struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> { ParallelDynamicStitchOpCPU<type>) TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH); +TF_CALL_variant(REGISTER_DYNAMIC_STITCH); #undef REGISTER_DYNAMIC_STITCH #if GOOGLE_CUDA diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 5168ad3b18..0f1170bb42 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -69,9 +70,22 @@ def capture_value(tensor_map, value, dtype, name): captured_value = graph_placeholder( dtype=dtype or value.dtype, shape=value.shape, name=name) if captured_value.dtype == dtypes_module.resource: - handle_data = value._handle_data # pylint: disable=protected-access - captured_value._handle_data = handle_data # pylint: disable=protected-access + if ops._USE_C_SHAPES: # pylint: disable=protected-access + if isinstance(value, ops.EagerTensor): + handle_data = value._handle_data # pylint: disable=protected-access + else: + handle_data = resource_variable_ops.get_resource_handle_data(value) + else: + handle_data = value._handle_data # pylint: disable=protected-access if handle_data is not None and handle_data.is_set: + # pylint: disable=protected-access + if ops._USE_C_SHAPES: + pywrap_tensorflow.TFE_SetResourceHandleShapeAndType( + captured_value.graph._c_graph, captured_value._as_tf_output(), + handle_data.SerializeToString()) + else: + captured_value._handle_data = handle_data + # pylint: enable=protected-access # Ensure that shapes and dtypes are propagated. shapes, types = zip(*[(pair.shape, pair.dtype) for pair in handle_data.shape_and_type]) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 65dde75e60..1828c987f4 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function as tf_function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -41,6 +42,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import gradient_descent +@test_util.with_c_shapes class FunctionTest(test.TestCase): def testBasic(self): @@ -615,6 +617,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual([[[[4.0]]]], y.numpy()) +@test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): def testBasic(self): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 70e70abc06..f954b9d6c7 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -464,6 +464,30 @@ def with_c_api(cls): return cls +def with_c_shapes(cls): + """Adds methods that call original methods but with C API shapes enabled. + + Note this enables C shapes in new methods after running the test class's + setup method. + + Args: + cls: class to decorate + + Returns: + cls with new test methods added + """ + # If C shapes are already enabled, don't do anything. Some tests break if the + # same test is run twice, so this allows us to turn on the C shapes by default + # without breaking these tests. + if ops._USE_C_SHAPES: + return cls + + for name, value in cls.__dict__.copy().items(): + if callable(value) and name.startswith("test"): + setattr(cls, name + "WithCShapes", enable_c_shapes(value)) + return cls + + def assert_no_new_pyobjects_executing_eagerly(f): """Decorator for asserting that no new Python objects persist after a test. diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3aedd70f8c..9440f2a4f9 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1604,7 +1604,10 @@ cuda_py_test( "//tensorflow/python:variables", ], shard_count = 4, - tags = ["noasan"], + tags = [ + "noasan", + "notap", + ], ) cuda_py_test( diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 49dd7f9948..4d26b2f46e 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -24,6 +24,8 @@ from tensorflow.core.framework import variable_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -41,6 +43,19 @@ from tensorflow.python.training import checkpointable from tensorflow.python.util import compat +def get_resource_handle_data(graph_op): + assert ops._USE_C_SHAPES # pylint: disable=protected-access + assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck + + with c_api_util.tf_buffer() as buf: + pywrap_tensorflow.TFE_GetResourceHandleShapeAndType( + graph_op.graph._c_graph, graph_op._as_tf_output(), buf) # pylint: disable=protected-access + data = pywrap_tensorflow.TF_GetBuffer(buf) + + return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString( + compat.as_bytes(data)) + + def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): """Creates a variable handle with information to do shape inference.""" container = ops.get_default_graph()._container # pylint: disable=protected-access @@ -73,9 +88,12 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): # shape inference doesn't run in eager mode we copy this data here for when # the handle is captured by an eager mode function. # pylint: disable=protected-access - if h._handle_data is None: - ops.set_shape_and_handle_data_for_outputs(h.op) - handle._handle_data = h._handle_data + if ops._USE_C_SHAPES: + handle._handle_data = get_resource_handle_data(h) + else: + if h._handle_data is None: + ops.set_shape_and_handle_data_for_outputs(h.op) + handle._handle_data = h._handle_data # pylint: enable=protected-access # Clean up our reference cycles to avoid making the garbage collector run. diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5ee55301df..0982a67dee 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -59,6 +59,8 @@ limitations under the License. %rename("%s") TFE_ContextOptionsSetAsync; %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; +%rename("%s") TFE_GetResourceHandleShapeAndType; +%rename("%s") TFE_SetResourceHandleShapeAndType; %{ #include "tensorflow/python/eager/pywrap_tfe.h" |