Simple SOCKS5 Python asyncio proxy
A step-by-step implementation of a simple SOCKS5 proxy server using Python's asyncio.
The other day a coworker asked me if I knew how a proxy works. Aside from the obvious semantics of what a proxy does, I didn’t really know how it works, so I decided that the best way to learn was to implement one myself.
I had never read an RFC before, but the SOCKS5 RFC (RFC 1928) was a very clear and easy read. I recommend you take a look at it; the protocol really isn’t that hard.
The first thing we need to do is start listening on a port.
async def main() -> None:
    """Serve SOCKS5 clients on 127.0.0.1:8484 until the process is stopped."""
    # ``asyncio.start_server`` no longer accepts ``loop=`` (removed in Python
    # 3.10), and calling ``get_event_loop()`` outside a running loop is
    # deprecated; ``asyncio.run`` creates and owns the event loop instead.
    server = await asyncio.start_server(handle_request, '127.0.0.1', 8484)
    async with server:
        await server.serve_forever()


if __name__ == "__main__":
    asyncio.run(main())
Every incoming connection is handled by `handle_request`, which accepts two parameters:
async def handle_request(
    reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Serve a single client connection.

    ``reader``/``writer`` are the stream pair that ``asyncio.start_server``
    passes to its callback for every accepted connection. The body is a
    placeholder in this snippet.
    """
Authentication
The client first sends us two bytes, which we unpack with `struct` using the `!` (network byte order) prefix: the first is the protocol version, and the second is the number of authentication methods the client supports.
header = await reader.readexactly(2)
version, nmethods = struct.unpack("!BB", header)
assert version == Socks5Invariants.VER, # just 5
assert nmethods > 0
There are several authentication methods, but since I was just practicing, I found these to be more than enough to play around with:
class AuthMethod(int, Enum):
    """METHOD codes exchanged during SOCKS5 method negotiation (RFC 1928 §3)."""

    NO_AUTHENTICATION_REQUIRED = 0  # X'00'
    GSSAPI = 1  # X'01'
    USERNAME_PASSWORD = 2  # X'02'
    # X'FF': none of the methods the client offered is acceptable.
    NO_ACCEPTABLE_METHODS = 255
We can get all the methods supported by the client with:
async def get_methods(
    reader: asyncio.StreamReader,
    n: int,
) -> Set[int]:
    """Read the client's ``n`` METHOD bytes and return them as a set.

    Args:
        reader: stream positioned just after the VER/NMETHODS header.
        n: number of method bytes to read (the NMETHODS value).

    Returns:
        The distinct method codes offered by the client.

    Raises:
        asyncio.IncompleteReadError: the client closed the connection before
            sending ``n`` bytes.
    """
    # ``read(1)`` returns ``b""`` at EOF (and ``ord(b"")`` raises TypeError);
    # ``readexactly`` reports a short read explicitly and reads in one call.
    # Iterating over ``bytes`` already yields ints, so no ``ord`` is needed.
    return set(await reader.readexactly(n))
Then we check that the client supports the authentication method we want to use — in this example, no authentication.
methods = await get_methods(reader, nmethods)
assert AuthMethod.NO_AUTHENTICATION_REQUIRED in methods
To proceed without authentication, we tell the client so by sending two bytes: the SOCKS version and the selected authentication method.
writer.write(
struct.pack(
"!BB",
Socks5Invariants.VER,
AuthMethod.NO_AUTHENTICATION_REQUIRED
)
Connection request
Once we acknowledge the authentication method, the client sends us a connection request whose fields are the ones you see in `RequestDetails`. In this request, the client specifies the host and port the proxy should connect to. There is no magic here: we just use `struct` and `socket` to parse the request.
@dataclass
class RequestDetails:
    """A SOCKS5 request as sent by the client (RFC 1928 §4).

    Wire layout::

        +----+-----+-------+------+----------+----------+
        |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
        +----+-----+-------+------+----------+----------+
        | 1  |  1  | X'00' |  1   | Variable |    2     |
        +----+-----+-------+------+----------+----------+
    """

    ver: int
    cmd: Cmd
    rsv: int
    atyp: Atyp
    # NOTE: despite the annotation, IPv4 requests store a dotted-quad ``str``
    # here (from ``socket.inet_ntoa``); DOMAINNAME requests store raw ``bytes``.
    dst_addr: bytes
    dst_port: int

    @classmethod
    async def read(cls, reader: asyncio.StreamReader) -> "RequestDetails":
        """Read one request from *reader* and parse it.

        Raises:
            asyncio.IncompleteReadError: the stream ended mid-request.
            ValueError: an unrecognised CMD code (and, assuming ``Atyp`` is an
                ``Enum``, ATYP code), or a zero-length DOMAINNAME.
            Socks5Unsupported: an ATYP other than IPv4 or DOMAINNAME.
        """
        ver, raw_cmd, rsv, raw_atyp = struct.unpack(
            "!BBBB", await reader.readexactly(4)
        )
        atyp = Atyp(raw_atyp)
        cmd = Cmd(raw_cmd)

        if atyp == Atyp.IP_V4:
            dst_addr = socket.inet_ntoa(await reader.readexactly(4))
        elif atyp == Atyp.DOMAINNAME:
            # One length octet, then that many bytes of name.
            (name_length,) = await reader.readexactly(1)
            dst_addr = await reader.readexactly(name_length)
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            if not dst_addr:
                raise ValueError("empty DOMAINNAME in SOCKS5 request")
        else:
            raise Socks5Unsupported(atyp)

        (dst_port,) = struct.unpack("!H", await reader.readexactly(2))
        return cls(
            ver=ver,
            cmd=cmd,
            rsv=rsv,
            atyp=atyp,
            dst_addr=dst_addr,
            dst_port=dst_port,
        )
A request can carry different commands (`cmd`), but we only expect `CONNECT`. I defined `Cmd` like this:
class Cmd(int, Enum):
    """CMD codes a client may place in a SOCKS5 request (RFC 1928 §4)."""

    CONNECT = 1  # X'01'
    BIND = 2  # X'02'
    UDP_ASSOCIATE = 3  # X'03'

    def __repr__(self) -> str:
        # Show only the member name, e.g. ``CONNECT``.
        label: str = self.name
        return label
Connect to the target/destination server
If we accept the request, we need to connect to the destination server and return the connection details:
@dataclass
class Destination:
    """An open connection to the target server requested by the client.

    ``bind_address``/``bind_port`` are the proxy's local end of that
    connection (the socket's ``sockname``); RFC 1928 §6 reports them back to
    the client as BND.ADDR/BND.PORT.
    """

    request: RequestDetails
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Local IP as returned by getsockname(), e.g. "192.168.1.5" (a str).
    bind_address: str
    bind_port: int

    @classmethod
    async def connect(cls, req: RequestDetails) -> Optional["Destination"]:
        """Open a TCP connection to ``req.dst_addr:req.dst_port``.

        Returns ``None`` if the connection cannot be established for any
        reason.
        """
        try:
            reader, writer = await asyncio.open_connection(
                req.dst_addr, req.dst_port
            )
        except Exception:
            # Deliberately best-effort: every connect failure maps to None.
            return None
        # sockname is (host, port) for IPv4 and (host, port, flow, scope) for IPv6.
        host, port = writer.get_extra_info("sockname")[:2]
        return cls(
            request=req,
            reader=reader,
            writer=writer,
            bind_address=host,
            bind_port=port,
        )

    def get_connected_reply(self) -> bytes:
        """Build the success reply (REP = X'00') sent after connecting.

        Layout: VER | REP | RSV | ATYP | BND.ADDR | BND.PORT, with ATYP fixed
        to IPv4 so BND.ADDR is 4 bytes.

        Raises:
            OSError: ``bind_address`` is not an IPv4 address string (e.g. the
                outbound connection used IPv6).
        """
        # Pack the whole address string; indexing ``bind_address[0]`` would
        # keep only its first character and report a bogus BND.ADDR.
        packed_addr = socket.inet_aton(self.bind_address)
        return struct.pack(
            "!BBBB4sH",
            Socks5Invariants.VER,
            0x00,  # REP: succeeded
            0x00,  # RSV
            Atyp.IP_V4,
            packed_addr,
            self.bind_port,
        )
Once we are connected, we reply to the client with `get_connected_reply`, which follows this layout:
+----+-----+-------+------+----------+----------+
|VER | REP | RSV | ATYP | BND.ADDR | BND.PORT |
+----+-----+-------+------+----------+----------+
| 1 | 1 | X'00' | 1 | Variable | 2 |
+----+-----+-------+------+----------+----------+
Exchange loop
Now we have two connections, one with the client and one with the destination server. Since we sit in the middle, we need to forward everything between them. The following code runs `_rcv_from_origin` and `_rcv_from_destination` concurrently, suppresses errors, and kills connections that last longer than 30 seconds.
class ExchangeLoop:
    """Relay bytes between the client (origin) and the target (destination).

    Both directions are pumped concurrently. The exchange ends as soon as
    either direction finishes (EOF or error), or after ``_DEFAULT_TTL``.
    """

    # Hard cap on how long each relay direction may run.
    _DEFAULT_TTL = timedelta(seconds=30)
    # Exceptions from cancellation or the TTL expiring; treated as a normal end.
    _EXCEPTIONS_TO_IGNORE = (
        asyncio.CancelledError,
        asyncio.TimeoutError,
        TimeoutError,
    )
    # Maximum bytes read from a stream per iteration.
    _CHUNK_SIZE = 4096

    def __init__(self, origin: Origin, destination: Destination) -> None:
        self._orig = origin
        self._dest = destination

    def _expect_cancellations(inner_func):
        """Decorator: return ``None`` instead of raising _EXCEPTIONS_TO_IGNORE."""

        @wraps(inner_func)
        async def outer_expected_cancellations(self, *args, **kwargs):
            try:
                return await inner_func(self, *args, **kwargs)
            except self._EXCEPTIONS_TO_IGNORE:
                return None

        return outer_expected_cancellations

    @staticmethod
    async def _pipe(reader, writer, chunk_size: int) -> None:
        """Copy data from *reader* to *writer* until *reader* hits EOF."""
        while True:
            data = await reader.read(chunk_size)
            if not data:
                break
            writer.write(data)
            # Without drain() a slow receiver lets the write buffer grow
            # without bound.
            await writer.drain()

    @_expect_cancellations
    async def _rcv_from_origin(self) -> None:
        """Forward client -> destination."""
        await self._pipe(self._orig.reader, self._dest.writer, self._CHUNK_SIZE)

    @_expect_cancellations
    async def _rcv_from_destination(self) -> None:
        """Forward destination -> client."""
        await self._pipe(self._dest.reader, self._orig.writer, self._CHUNK_SIZE)

    @_expect_cancellations
    async def run(self) -> None:
        """Run both relay directions until the first one finishes."""
        max_ttl = self._DEFAULT_TTL.total_seconds()
        tasks = [
            asyncio.create_task(asyncio.wait_for(relay, max_ttl))
            for relay in (self._rcv_from_origin(), self._rcv_from_destination())
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()  # no-op for tasks that already finished
            # return_exceptions=True collects CancelledError (a BaseException)
            # and I/O errors such as ConnectionResetError instead of raising,
            # so every task is awaited even if an earlier one failed.
            await asyncio.gather(*tasks, return_exceptions=True)