Index: src/OFTCPSocket+SOCKS5.m
==================================================================
--- src/OFTCPSocket+SOCKS5.m
+++ src/OFTCPSocket+SOCKS5.m
@@ -15,54 +15,128 @@
  */
 
 #include "config.h"
 
 #import "OFTCPSocket+SOCKS5.h"
+#import "OFDataArray.h"
 
 #import "OFConnectionFailedException.h"
+#import "OFOutOfRangeException.h"
+#import "OFReadFailedException.h"
+#import "OFWriteFailedException.h"
+
+#import "socket_helpers.h"
 
 /* Reference for static linking */
 int _OFTCPSocket_SOCKS5_reference;
+
+/*
+ * Sends the whole buffer on the socket, retrying after short writes.
+ * Throws an OFWriteFailedException if send() reports an error.
+ */
+static void
+send_or_exception(OFTCPSocket *self, int socket, const void *buffer,
+    size_t length)
+{
+	const char *cursor = buffer;
+
+	while (length > 0) {
+		ssize_t ret = send(socket, cursor, length, 0);
+
+		if (ret < 0)
+			@throw [OFWriteFailedException
+			    exceptionWithObject: self
+				requestedLength: length
+					  errNo: of_socket_errno()];
+
+		cursor += ret;
+		length -= (size_t)ret;
+	}
+}
+
+/*
+ * Receives exactly length bytes into buffer, looping over partial reads.
+ * Throws an OFReadFailedException if recv() reports an error or if the peer
+ * closes the connection before length bytes have arrived.
+ */
+static void
+recv_exact(OFTCPSocket *self, int socket, void *buffer, size_t length)
+{
+	char *cursor = buffer;
+
+	while (length > 0) {
+		ssize_t ret = recv(socket, cursor, length, 0);
+
+		if (ret < 0)
+			@throw [OFReadFailedException
+			    exceptionWithObject: self
+				requestedLength: length
+					  errNo: of_socket_errno()];
+
+		/* recv() returns 0 on EOF; without this check we would spin. */
+		if (ret == 0)
+			@throw [OFReadFailedException
+			    exceptionWithObject: self
+				requestedLength: length
+					  errNo: 0];
+
+		cursor += ret;
+		length -= (size_t)ret;
+	}
+}
 
 @implementation OFTCPSocket (SOCKS5)
 - (void)OF_SOCKS5ConnectToHost: (OFString*)host
 			  port: (uint16_t)port
 {
 	const char request[] = { 5, 1, 0, 3 };
-	char reply[256];
-	bool wasWriteBuffered;
+	unsigned char reply[256];
+	size_t hostLength = [host UTF8StringLength];
+	uint8_t hostLengthByte, portBytes[2];
+	void *pool;
+	OFDataArray *connectRequest;
+
+	if (hostLength > 255)
+		@throw [OFOutOfRangeException exception];
 
 	/* 5 1 0 -> no authentication */
-	[self writeBuffer: request
-		   length: 3];
+	send_or_exception(self, _socket, request, 3);
 
-	[self readIntoBuffer: reply
-		 exactLength: 2];
+	recv_exact(self, _socket, reply, 2);
 
 	if (reply[0] != 5 || reply[1] != 0) {
 		[self close];
 		@throw [OFConnectionFailedException
 		    exceptionWithHost: host
 				 port: port
 			       socket: self];
 	}
 
-	wasWriteBuffered = [self isWriteBuffered];
-	[self setWriteBuffered: true];
-
 	/* CONNECT request */
-	[self writeBuffer: request
-		   length: 4];
-	[self writeInt8: [host UTF8StringLength]];
-	[self writeBuffer: [host UTF8String]
-		   length: [host UTF8StringLength]];
-	[self writeBigEndianInt16: port];
-
-	[self flushWriteBuffer];
-	[self setWriteBuffered: wasWriteBuffered];
-
-	[self readIntoBuffer: reply
-		 exactLength: 4];
+	pool = objc_autoreleasePoolPush();
+	connectRequest = [OFDataArray dataArray];
+
+	/* VER=5, CMD=1 (CONNECT), RSV=0, ATYP=3 (domain name) */
+	[connectRequest addItems: request
+			   count: 4];
+
+	hostLengthByte = (uint8_t)hostLength;
+	[connectRequest addItem: &hostLengthByte];
+	[connectRequest addItems: [host UTF8String]
+			   count: hostLength];
+
+	/* Port in network byte order */
+	portBytes[0] = port >> 8;
+	portBytes[1] = port & 0xFF;
+	[connectRequest addItems: portBytes
+			   count: 2];
+
+	send_or_exception(self, _socket,
+	    [connectRequest items], [connectRequest count]);
+
+	objc_autoreleasePoolPop(pool);
+
+	recv_exact(self, _socket, reply, 4);
 
 	if (reply[0] != 5 || reply[1] != 0 || reply[2] != 0) {
 		[self close];
 		@throw [OFConnectionFailedException
 		    exceptionWithHost: host
 				 port: port
@@ -70,26 +144,25 @@
 	}
 
 	/* Skip the rest of the reply */
 	switch (reply[3]) {
 	case 1: /* IPv4 */
-		[self readIntoBuffer: reply
-			 exactLength: 4];
+		recv_exact(self, _socket, reply, 4);
 		break;
 	case 3: /* Domainname */
-		[self readIntoBuffer: reply
-			 exactLength: [self readInt8]];
+		/* reply is unsigned char, so reply[0] <= 255 fits in reply */
+		recv_exact(self, _socket, reply, 1);
+		recv_exact(self, _socket, reply, reply[0]);
 		break;
 	case 4: /* IPv6 */
-		[self readIntoBuffer: reply
-			 exactLength: 16];
+		recv_exact(self, _socket, reply, 16);
 		break;
 	default:
 		[self close];
 		@throw [OFConnectionFailedException
 		    exceptionWithHost: host
 				 port: port
 			       socket: self];
 	}
 
-	[self readBigEndianInt16];
+	recv_exact(self, _socket, reply, 2);
 }
 @end