Jelajahi Sumber

propagate malloc failure

eihrul 15 tahun lalu
induk
melakukan
41ccbd2f3f
4 mengubah file dengan 42 tambahan dan 16 penghapusan
  1. 15 3
      host.c
  2. 2 0
      packet.c
  3. 18 7
      peer.c
  4. 7 6
      protocol.c

+ 15 - 3
host.c

@@ -27,13 +27,23 @@
 ENetHost *
 enet_host_create (const ENetAddress * address, size_t peerCount, enet_uint32 incomingBandwidth, enet_uint32 outgoingBandwidth)
 {
-    ENetHost * host = (ENetHost *) enet_malloc (sizeof (ENetHost));
+    ENetHost * host;
     ENetPeer * currentPeer;
 
     if (peerCount > ENET_PROTOCOL_MAXIMUM_PEER_ID)
       return NULL;
 
+    host = (ENetHost *) enet_malloc (sizeof (ENetHost));
+    if (host == NULL)
+      return NULL;
+
     host -> peers = (ENetPeer *) enet_malloc (peerCount * sizeof (ENetPeer));
+    if (host -> peers == NULL)
+    {
+       enet_free (host);
+
+       return NULL;
+    }
     memset (host -> peers, 0, peerCount * sizeof (ENetPeer));
 
     host -> socket = enet_socket_create (ENET_SOCKET_TYPE_DATAGRAM);
@@ -142,10 +152,12 @@ enet_host_connect (ENetHost * host, const ENetAddress * address, size_t channelC
     if (currentPeer >= & host -> peers [host -> peerCount])
       return NULL;
 
-    currentPeer -> state = ENET_PEER_STATE_CONNECTING;
-    currentPeer -> address = * address;
     currentPeer -> channels = (ENetChannel *) enet_malloc (channelCount * sizeof (ENetChannel));
+    if (currentPeer -> channels == NULL)
+      return NULL;
     currentPeer -> channelCount = channelCount;
+    currentPeer -> state = ENET_PEER_STATE_CONNECTING;
+    currentPeer -> address = * address;
     currentPeer -> sessionID = (enet_uint32) enet_rand ();
 
     if (host -> outgoingBandwidth == 0)

+ 2 - 0
packet.c

@@ -20,6 +20,8 @@ ENetPacket *
 enet_packet_create (const void * data, size_t dataLength, enet_uint32 flags)
 {
     ENetPacket * packet = (ENetPacket *) enet_malloc (sizeof (ENetPacket));
+    if (packet == NULL)
+      return NULL;
 
     if (flags & ENET_PACKET_FLAG_NO_ALLOCATE)
       packet -> data = (enet_uint8 *) data;

+ 18 - 7
peer.c

@@ -137,7 +137,8 @@ enet_peer_send (ENetPeer * peer, enet_uint8 channelID, ENetPacket * packet)
          command.sendFragment.totalLength = ENET_HOST_TO_NET_32 (packet -> dataLength);
          command.sendFragment.fragmentOffset = ENET_NET_TO_HOST_32 (fragmentOffset);
 
-         enet_peer_queue_outgoing_command (peer, & command, packet, fragmentOffset, fragmentLength);
+         if (enet_peer_queue_outgoing_command (peer, & command, packet, fragmentOffset, fragmentLength) == NULL)
+           return -1;
       }
 
       return 0;
@@ -167,7 +168,8 @@ enet_peer_send (ENetPeer * peer, enet_uint8 channelID, ENetPacket * packet)
       command.sendUnreliable.dataLength = ENET_HOST_TO_NET_16 (packet -> dataLength);
    }
 
-   enet_peer_queue_outgoing_command (peer, & command, packet, 0, packet -> dataLength);
+   if (enet_peer_queue_outgoing_command (peer, & command, packet, 0, packet -> dataLength) == NULL)
+     return -1;
 
    return 0;
 }
@@ -487,9 +489,11 @@ enet_peer_queue_acknowledgement (ENetPeer * peer, const ENetProtocol * command,
           return NULL;
     }
 
-    peer -> outgoingDataTotal += sizeof (ENetProtocolAcknowledge);
-
     acknowledgement = (ENetAcknowledgement *) enet_malloc (sizeof (ENetAcknowledgement));
+    if (acknowledgement == NULL)
+      return NULL;
+
+    peer -> outgoingDataTotal += sizeof (ENetProtocolAcknowledge);
 
     acknowledgement -> sentTime = sentTime;
     acknowledgement -> command = * command;
@@ -503,12 +507,12 @@ ENetOutgoingCommand *
 enet_peer_queue_outgoing_command (ENetPeer * peer, const ENetProtocol * command, ENetPacket * packet, enet_uint32 offset, enet_uint16 length)
 {
     ENetChannel * channel = & peer -> channels [command -> header.channelID];
-    ENetOutgoingCommand * outgoingCommand;
+    ENetOutgoingCommand * outgoingCommand = (ENetOutgoingCommand *) enet_malloc (sizeof (ENetOutgoingCommand));
+    if (outgoingCommand == NULL)
+      return NULL;
 
     peer -> outgoingDataTotal += enet_protocol_command_size (command -> header.command) + length;
 
-    outgoingCommand = (ENetOutgoingCommand *) enet_malloc (sizeof (ENetOutgoingCommand));
-
     if (command -> header.channelID == 0xFF)
     {
        ++ peer -> outgoingReliableSequenceNumber;
@@ -665,6 +669,8 @@ enet_peer_queue_incoming_command (ENetPeer * peer, const ENetProtocol * command,
     }
 
     incomingCommand = (ENetIncomingCommand *) enet_malloc (sizeof (ENetIncomingCommand));
+    if (incomingCommand == NULL)
+      goto freePacket;
 
     incomingCommand -> reliableSequenceNumber = command -> header.reliableSequenceNumber;
     incomingCommand -> unreliableSequenceNumber = unreliableSequenceNumber & 0xFFFF;
@@ -677,6 +683,11 @@ enet_peer_queue_incoming_command (ENetPeer * peer, const ENetProtocol * command,
     if (fragmentCount > 0)
     { 
        incomingCommand -> fragments = (enet_uint32 *) enet_malloc ((fragmentCount + 31) / 32 * sizeof (enet_uint32));
+       if (incomingCommand -> fragments == NULL)
+       {
+          enet_free (incomingCommand);
+          goto freePacket;
+       }
        memset (incomingCommand -> fragments, 0, (fragmentCount + 31) / 32 * sizeof (enet_uint32));
     }
 

+ 7 - 6
protocol.c

@@ -289,6 +289,10 @@ enet_protocol_handle_connect (ENetHost * host, ENetProtocolHeader * header, ENet
     if (currentPeer >= & host -> peers [host -> peerCount])
       return NULL;
 
+    currentPeer -> channels = (ENetChannel *) enet_malloc (channelCount * sizeof (ENetChannel));
+    if (currentPeer -> channels == NULL)
+      return NULL;
+    currentPeer -> channelCount = channelCount;
     currentPeer -> state = ENET_PEER_STATE_ACKNOWLEDGING_CONNECT;
     currentPeer -> sessionID = command -> connect.sessionID;
     currentPeer -> address = host -> receivedAddress;
@@ -298,8 +302,6 @@ enet_protocol_handle_connect (ENetHost * host, ENetProtocolHeader * header, ENet
     currentPeer -> packetThrottleInterval = ENET_NET_TO_HOST_32 (command -> connect.packetThrottleInterval);
     currentPeer -> packetThrottleAcceleration = ENET_NET_TO_HOST_32 (command -> connect.packetThrottleAcceleration);
     currentPeer -> packetThrottleDeceleration = ENET_NET_TO_HOST_32 (command -> connect.packetThrottleDeceleration);
-    currentPeer -> channels = (ENetChannel *) enet_malloc (channelCount * sizeof (ENetChannel));
-    currentPeer -> channelCount = channelCount;
 
     for (channel = currentPeer -> channels;
          channel < & currentPeer -> channels [channelCount];
@@ -440,15 +442,14 @@ enet_protocol_handle_send_unsequenced (ENetHost * host, ENetPeer * peer, const E
     if (peer -> unsequencedWindow [index / 32] & (1 << (index % 32)))
       return 0;
       
-    peer -> unsequencedWindow [index / 32] |= 1 << (index % 32);
-    
-                        
     packet = enet_packet_create ((const enet_uint8 *) command + sizeof (ENetProtocolSendUnsequenced),
                                  dataLength,
                                  ENET_PACKET_FLAG_UNSEQUENCED);
     if (packet == NULL)
       return -1;
-    
+   
+    peer -> unsequencedWindow [index / 32] |= 1 << (index % 32);
+ 
     enet_peer_queue_incoming_command (peer, command, packet, 0);
     return 0;
 }