Simple SOCKS5 Python asyncio proxy

Step by step implementation of a simple SOCKS5 proxy server with asynchronous asyncio python.

Simple SOCKS5 Python asyncio proxy
Photo by Gabriel González / Unsplash

The other day a coworker asked me if I knew how a proxy works. Apart from the obvious idea of what a proxy does, I realized I didn’t know how one actually works. So I decided that the best way to learn was to implement one myself.

I had never read an RFC before, but the SOCKS5 RFC (RFC 1928) was a very clear and easy read. I recommend taking a look at it; the protocol really isn’t that hard.

The first thing we need to do is start listening on a port.

# Create the loop explicitly: calling asyncio.get_event_loop() with no running
# loop is deprecated in recent Python versions.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# start_server() no longer accepts a ``loop`` argument (removed in Python 3.10);
# it uses the loop that runs the coroutine, here ``loop``.
coro = asyncio.start_server(handle_request, '127.0.0.1', 8484)
server = loop.run_until_complete(coro)
try:
    loop.run_forever()
finally:
    # Stop accepting connections and release the listening socket and the
    # loop when run_forever() exits (e.g. on Ctrl+C).
    server.close()
    loop.run_until_complete(server.wait_closed())
    loop.close()

Every incoming connection is handled by handle_request, which receives two parameters — a StreamReader and a StreamWriter for that client:

async def handle_request(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
) -> None:
    """Serve one client connection accepted by ``asyncio.start_server``.

    ``reader`` delivers the bytes the client sends; ``writer`` sends bytes
    back to the client. The body is elided in this excerpt.
    """

Authentication

The client first sends us two bytes: the protocol version and the number of authentication methods it supports.

header = await reader.readexactly(2)
version, nmethods = struct.unpack("!BB", header)
assert version == Socks5Invariants.VER,  # just 5
assert nmethods > 0

Many authentication methods exist, but since this is a practice project, these are more than enough to play around with:

class AuthMethod(int, Enum):
    """METHOD codes used during SOCKS5 method negotiation (RFC 1928, section 3)."""

    NO_AUTHENTICATION_REQUIRED = 0x00
    GSSAPI = 0x01
    USERNAME_PASSWORD = 0x02
    # Per RFC 1928: the server's answer when none of the offered methods is acceptable.
    NO_ACCEPTABLE_METHODS = 0xFF

We can get all the methods supported by the client with:

async def get_methods(
    reader: asyncio.StreamReader,
    n: int
) -> Set[int]:
    """Read the ``n`` authentication-method codes the client offers.

    Args:
        reader: client stream, positioned right after the greeting header.
        n: number of one-byte method codes to read (NMETHODS).

    Returns:
        The distinct method codes as ints; empty when ``n`` is 0 or negative.

    Raises:
        asyncio.IncompleteReadError: the stream ends before ``n`` bytes arrive.
    """
    if n <= 0:
        return set()
    # One readexactly() call instead of n read(1) calls. A truncated greeting
    # now surfaces as IncompleteReadError rather than ord(b"") -> TypeError.
    raw = await reader.readexactly(n)
    # Iterating over bytes yields one int per method code.
    return set(raw)

Then we check that the client supports the authentication method we want to use — in this example, no authentication.

methods = await get_methods(reader, nmethods)
assert AuthMethod.NO_AUTHENTICATION_REQUIRED in methods

If we continue without authentication, we should tell the client that we want to proceed. We do that by sending two bytes, one with the socks version and another with the auth method that was selected.

writer.write(
    struct.pack(
        "!BB",
        Socks5Invariants.VER,
        AuthMethod.NO_AUTHENTICATION_REQUIRED
    )

Connection request

Once we acknowledge the authentication method, the client sends us a connection request. That request contains the fields you will see in RequestDetails; among them, the client specifies the host and port the proxy should connect to. There is no magic here — just struct and socket to parse the request.

@dataclass
class RequestDetails:
    """A parsed SOCKS5 request: VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT."""

    ver: int
    cmd: Cmd
    rsv: int
    atyp: Atyp
    # A dotted-quad ``str`` for IPv4 (from socket.inet_ntoa) and raw ``bytes``
    # for a domain name. NOTE(review): the ``bytes`` annotation covers only
    # the domain-name case.
    dst_addr: bytes
    dst_port: int

    @classmethod
    async def read(cls, reader: asyncio.StreamReader) -> "RequestDetails":
        """Read one request from ``reader`` and return it parsed.

        Raises:
            asyncio.IncompleteReadError: the stream ends mid-request.
            Socks5Unsupported: ATYP is neither IPv4 nor a domain name.
            Presumably ValueError when CMD/ATYP is not a member of the
            ``Cmd``/``Atyp`` enum (``Cmd`` is an Enum; ``Atyp`` assumed to be).
        """
        fixed_part = await reader.readexactly(4)
        ver, raw_cmd, rsv, raw_atyp = struct.unpack("!BBBB", fixed_part)
        atyp = Atyp(raw_atyp)
        cmd = Cmd(raw_cmd)

        if atyp == Atyp.IP_V4:
            dst_addr = socket.inet_ntoa(await reader.readexactly(4))
        elif atyp == Atyp.DOMAINNAME:
            # One length octet, followed by that many bytes of the name.
            (name_len,) = await reader.readexactly(1)
            dst_addr = await reader.readexactly(name_len)
        else:
            raise Socks5Unsupported(atyp)
        assert dst_addr  # rejects a zero-length domain name

        # DST.PORT: two bytes in network byte order.
        (dst_port,) = struct.unpack('!H', await reader.readexactly(2))

        return cls(
            ver=ver, cmd=cmd, rsv=rsv, atyp=atyp,
            dst_addr=dst_addr, dst_port=dst_port,
        )

The client can send different commands (CMD), but we only expect CONNECT. I defined Cmd this way:

class Cmd(int, Enum):
    """CMD codes a client may place in a SOCKS5 request (RFC 1928, section 4)."""

    CONNECT = 0x01
    BIND = 0x02
    UDP_ASSOCIATE = 0x03

    def __repr__(self) -> str:
        # Print just the member name (e.g. ``CONNECT``) in debug output.
        return f"{self.name}"

Connect to the target/destination server

If we accept the request, we need to connect to the destination server and return the connection details:

@dataclass
class Destination:
    """An open proxy-to-destination connection for one client request.

    Attributes:
        request: the parsed client request this connection serves.
        reader, writer: streams of the proxy-to-destination connection.
        bind_address: this proxy's local IP address on that connection, as
            the textual first element of ``getsockname()``.
        bind_port: this proxy's local port on that connection.
    """
    request: RequestDetails
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    bind_address: str
    bind_port: int

    @classmethod
    async def connect(cls, req: RequestDetails) -> Optional["Destination"]:
        """Open a TCP connection to ``req.dst_addr:req.dst_port``.

        Returns:
            The connected ``Destination``, or ``None`` if opening the
            connection raised any exception.
        """
        try:
            reader, writer = await asyncio.open_connection(
                req.dst_addr, req.dst_port
            )
        except Exception:
            # Deliberately best-effort: any failure is reported to the caller
            # as "no destination" rather than raised.
            return None
        # Local (host, port, ...) of the new socket; host is a textual IP.
        sockname = writer.get_extra_info("sockname")

        return cls(
            request=req,
            reader=reader,
            writer=writer,
            bind_address=sockname[0],
            bind_port=sockname[1],
        )

    def get_connected_reply(self) -> bytes:
        """Build the success reply sent to the client once connected.

        Layout: VER | REP=X'00' | RSV=X'00' | ATYP | BND.ADDR | BND.PORT.
        BND.ADDR is 4 bytes (ATYP IPv4) or 16 bytes (ATYP X'04', IPv6), and
        BND.PORT is 2 bytes in network byte order.

        Raises:
            OSError: ``bind_address`` is neither an IPv4 nor an IPv6 literal.
        """
        ver = Socks5Invariants.VER
        # Pack the whole address string; indexing it with [0] would pack only
        # its first character (e.g. "1" -> 0.0.0.1).
        try:
            bnd_addr = socket.inet_pton(socket.AF_INET, self.bind_address)
            atyp = Atyp.IP_V4
        except OSError:
            # The destination was reached over IPv6: RFC 1928 ATYP X'04'.
            bnd_addr = socket.inet_pton(socket.AF_INET6, self.bind_address)
            atyp = 0x04
        return (
            struct.pack("!BBBB", ver, 0, 0, atyp)
            + bnd_addr
            + struct.pack("!H", self.bind_port)
        )

Once connected, we send the client the reply built by get_connected_reply, which has this layout:

+----+-----+-------+------+----------+----------+
|VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
+----+-----+-------+------+----------+----------+
| 1  |  1  | X'00' |  1   | Variable |    2     |
+----+-----+-------+------+----------+----------+

Exchange loop

Now we have two connections, one with the client and one with the destination server. Since we sit in the middle, we need to forward everything between them. The following code runs _rcv_from_origin and _rcv_from_destination concurrently, suppresses cancellation and timeout errors, and cuts off connections that last longer than a fixed TTL (30 seconds by default).

class ExchangeLoop:
    """Relay bytes in both directions between the client and the destination.

    ``run`` pumps client -> destination and destination -> client as two
    concurrent tasks. When either one finishes (EOF, an I/O error, or the TTL
    expiring), the other is cancelled and ``run`` returns. The TTL bounds the
    total lifetime of each direction, not idle time. The streams themselves
    are not closed here.
    """
    _DEFAULT_TTL = timedelta(seconds=30)
    # Swallowed by ``_expect_cancellations``: task cancellation, and the
    # timeout raised by asyncio.wait_for when the TTL expires.
    _EXCEPTIONS_TO_IGNORE = (
        asyncio.CancelledError,
        asyncio.TimeoutError,
        TimeoutError
    )
    # Maximum number of bytes requested per read() call.
    _CHUNK_SIZE = 4096

    def __init__(
        self,
        origin: Origin,
        destination: Destination,
        ttl: Optional[timedelta] = None,
    ) -> None:
        """
        Args:
            origin: client side; its ``reader`` and ``writer`` are used.
            destination: target side; its ``reader`` and ``writer`` are used.
            ttl: maximum lifetime of each relay direction; ``None`` means
                ``_DEFAULT_TTL`` (30 seconds).
        """
        self._orig = origin
        self._dest = destination
        self._ttl = self._DEFAULT_TTL if ttl is None else ttl

    def _expect_cancellations(inner_func):
        # Class-body decorator (not meant to be called on instances): the
        # wrapped coroutine returns None instead of raising any exception
        # listed in _EXCEPTIONS_TO_IGNORE.
        @wraps(inner_func)
        async def outer_expected_cancellations(self, *args, **kwargs):
            try:
                return await inner_func(self, *args, **kwargs)
            except self._EXCEPTIONS_TO_IGNORE:
                return
        return outer_expected_cancellations

    async def _pipe(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Copy everything from ``reader`` to ``writer`` until EOF."""
        while True:
            data = await reader.read(self._CHUNK_SIZE)
            if not data:
                break
            writer.write(data)
            # Backpressure: without drain(), a fast sender and a slow
            # receiver let the write buffer grow without bound.
            await writer.drain()

    @_expect_cancellations
    async def _rcv_from_origin(self) -> None:
        """Forward client bytes to the destination."""
        await self._pipe(self._orig.reader, self._dest.writer)

    @_expect_cancellations
    async def _rcv_from_destination(self) -> None:
        """Forward destination bytes to the client."""
        await self._pipe(self._dest.reader, self._orig.writer)

    @_expect_cancellations
    async def run(self) -> None:
        """Relay traffic until one direction ends, then stop the other.

        Cancellation, TTL expiry and per-direction I/O errors are absorbed;
        ``run`` returns ``None`` in all of those cases.
        """
        max_ttl = self._ttl.total_seconds()
        tasks = [
            asyncio.create_task(asyncio.wait_for(pump, max_ttl))
            for pump in (self._rcv_from_origin(), self._rcv_from_destination())
        ]

        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # cancel() is a no-op on a task that has already finished.
            for task in tasks:
                task.cancel()
            # Await every task even if one ended with CancelledError (not an
            # Exception subclass) or an I/O error: return_exceptions=True
            # collects those as results, so no task is left running.
            await asyncio.gather(*tasks, return_exceptions=True)