Browse Source

Improve memory management in Channel (#368)

* Improve memory management in Channel

* Use the gpr family of allocation functions instead of malloc/free (as
  these cannot return `NULL`).
* Explicitly extend lifetime of `argumentWrappers` - Swift doesn't
  guarantee that objects will live to the end of their scope, so
  the optimizer is free to call Channel.Argument.Wrapper.deinit before
  the call to cgrpc_channel_create_secure, which would result in a
  use-after-free.

* Also use gpr_free to free cgrpc_calls

* Fix imports in call.c
Kevin Sweeney 7 years ago
parent
commit
fcb8ab3360
3 changed files with 24 additions and 22 deletions
  1. 3 3
      Sources/CgRPC/shim/call.c
  2. 5 9
      Sources/CgRPC/shim/channel.c
  3. 16 10
      Sources/SwiftGRPC/Core/Channel.swift

+ 3 - 3
Sources/CgRPC/shim/call.c

@@ -16,15 +16,15 @@
 #include "internal.h"
 #include "cgrpc.h"
 
+#include <grpc/support/alloc.h>
+
 #include <stdlib.h>
-#include <string.h>
-#include <assert.h>
 
 void cgrpc_call_destroy(cgrpc_call *call) {
   if (call->call) {
     grpc_call_unref(call->call);
   }
-  free(call);
+  gpr_free(call);
 }
 
 grpc_call_error cgrpc_call_perform(cgrpc_call *call, cgrpc_operations *operations, void *tag) {

+ 5 - 9
Sources/CgRPC/shim/channel.c

@@ -18,15 +18,12 @@
 #include <grpc/support/string_util.h>
 #include <grpc/support/alloc.h>
 
-#include <assert.h>
-#include <stdio.h>
 #include <stdlib.h>
-#include <string.h>
 
 cgrpc_channel *cgrpc_channel_create(const char *address,
                                     grpc_arg *args,
                                     int num_args) {
-  cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
+  cgrpc_channel *c = (cgrpc_channel *) gpr_zalloc(sizeof (cgrpc_channel));
 
   grpc_channel_args channel_args;
   channel_args.args = args;
@@ -43,7 +40,7 @@ cgrpc_channel *cgrpc_channel_create_secure(const char *address,
                                            const char *client_private_key,
                                            grpc_arg *args,
                                            int num_args) {
-  cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
+  cgrpc_channel *c = (cgrpc_channel *) gpr_zalloc(sizeof (cgrpc_channel));
 
   grpc_channel_args channel_args;
   channel_args.args = args;
@@ -66,7 +63,7 @@ cgrpc_channel *cgrpc_channel_create_secure(const char *address,
 cgrpc_channel *cgrpc_channel_create_google(const char *address,
                                            grpc_arg *args,
                                            int num_args) {
-  cgrpc_channel *c = (cgrpc_channel *) malloc(sizeof (cgrpc_channel));
+  cgrpc_channel *c = (cgrpc_channel *) gpr_zalloc(sizeof (cgrpc_channel));
 
   grpc_channel_args channel_args;
   channel_args.args = args;
@@ -83,7 +80,7 @@ cgrpc_channel *cgrpc_channel_create_google(const char *address,
 void cgrpc_channel_destroy(cgrpc_channel *c) {
   grpc_channel_destroy(c->channel);
   c->channel = NULL;
-  free(c);
+  gpr_free(c);
 }
 
 cgrpc_call *cgrpc_channel_create_call(cgrpc_channel *channel,
@@ -105,8 +102,7 @@ cgrpc_call *cgrpc_channel_create_call(cgrpc_channel *channel,
                                                      NULL);
   grpc_slice_unref(host_slice);
   grpc_slice_unref(method_slice);
-  cgrpc_call *call = (cgrpc_call *) malloc(sizeof(cgrpc_call));
-  memset(call, 0, sizeof(cgrpc_call));
+  cgrpc_call *call = (cgrpc_call *) gpr_zalloc(sizeof(cgrpc_call));
   call->call = channel_call;
   return call;
 }

+ 16 - 10
Sources/SwiftGRPC/Core/Channel.swift

@@ -45,12 +45,14 @@ public class Channel {
     gRPC.initialize()
     host = address
     let argumentWrappers = arguments.map { $0.toCArg() }
-    var argumentValues = argumentWrappers.map { $0.wrapped }
 
-    if secure {
-      underlyingChannel = cgrpc_channel_create_secure(address, roots_pem(), nil, nil, &argumentValues, Int32(arguments.count))
-    } else {
-      underlyingChannel = cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
+    underlyingChannel = withExtendedLifetime(argumentWrappers) {
+        var argumentValues = argumentWrappers.map { $0.wrapped }
+        if secure {
+          return cgrpc_channel_create_secure(address, roots_pem(), nil, nil, &argumentValues, Int32(arguments.count))
+        } else {
+          return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
+        }
     }
     completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue
@@ -64,9 +66,11 @@ public class Channel {
     gRPC.initialize()
     host = googleAddress
     let argumentWrappers = arguments.map { $0.toCArg() }
-    var argumentValues = argumentWrappers.map { $0.wrapped }
-
-    underlyingChannel = cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
+    
+    underlyingChannel = withExtendedLifetime(argumentWrappers) {
+        var argumentValues = argumentWrappers.map { $0.wrapped }
+        return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
+    }
 
     completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue
@@ -83,9 +87,11 @@ public class Channel {
     gRPC.initialize()
     host = address
     let argumentWrappers = arguments.map { $0.toCArg() }
-    var argumentValues = argumentWrappers.map { $0.wrapped }
 
-    underlyingChannel = cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
+    underlyingChannel = withExtendedLifetime(argumentWrappers) {
+        var argumentValues = argumentWrappers.map { $0.wrapped }
+        return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
+    }
     completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue
   }