129 lines
5.4 KiB
Python
129 lines
5.4 KiB
Python
import pytest
|
|
import struct
|
|
import monocypher
|
|
from openttd.protocol import OpenTTDProtocol, PacketGameType
|
|
from openttd_protocol.wire.exceptions import SocketClosed
|
|
|
|
class MockTransport:
    """Minimal stand-in for an asyncio transport in these protocol tests.

    It only remembers whether ``close()`` was called, and it reports every
    write as fully accepted. Nothing is actually sent anywhere.
    """

    def __init__(self):
        # Set to True by close(); read back through is_closing().
        self._closing = False

    def is_closing(self):
        """Return True once close() has been called, else False."""
        return self._closing

    def close(self):
        """Mark the transport as closing. No real resource is released."""
        self._closing = True

    def write(self, data):
        """Discard *data* and return its length as the byte count written."""
        return len(data)
|
|
|
|
class MockHandler:
    """Fake handler passed to ``OpenTTDProtocol`` in these tests.

    Encryption starts disabled and no AEAD objects are installed. Tests
    toggle ``encryption_enabled`` and assign ``_recv_aead`` themselves. The
    key material consists of fixed, recognisable byte patterns, so tests can
    build matching ciphers.
    """

    def __init__(self):
        self.encryption_enabled = False
        # No cipher state until a test installs one.
        self._recv_aead = self._send_aead = None
        # Deterministic key material: two 32-byte keys and a 24-byte nonce.
        self._session_key_recv, self._session_key_send = b"A" * 32, b"B" * 32
        self._encryption_nonce = b"C" * 24

    async def receive_ServerUnused(self, source, **kwargs):
        """Accept a ServerUnused packet and do nothing with it."""

    async def receive_ClientAck(self, source, **kwargs):
        """Accept a ClientAck packet and do nothing with it."""
|
|
|
|
def test_protocol_static_parsers():
    """Exercise the per-packet parsers directly, with no protocol instance.

    Each parser is looked up on the class and called with ``None`` in place
    of ``self``. This shows that these parsers never touch instance state.
    """
    # ServerChat payload: uint8, uint32 client_id, uint8, NUL-terminated message.
    chat = OpenTTDProtocol.receive_ServerChat(
        None, memoryview(struct.pack("<BI B", 1, 42, 0) + b"Hello\x00")
    )
    assert chat["client_id"] == 42
    assert chat["message"] == "Hello"

    welcome = OpenTTDProtocol.receive_ServerWelcome(
        None, memoryview(struct.pack("<I", 123))
    )
    assert welcome["client_id"] == 123

    # ServerFrame payload: two uint32s, 12 zero bytes, then the token byte.
    frame = OpenTTDProtocol.receive_ServerFrame(
        None, memoryview(struct.pack("<II", 1000, 2000) + b"\x00" * 12 + b"\x07")
    )
    assert frame["frame"] == 1000
    assert frame["token"] == 7

    # Parsers whose whole result can be compared exactly, in call order:
    # (method name, payload, expected result).
    exact_cases = [
        ("receive_ServerExternalChat", b"", {}),
        ("receive_ServerCommand", b"", {}),
        ("receive_ServerFull", b"", {}),
        ("receive_ServerBanned", b"", {}),
        ("receive_ClientIdentify", b"", {}),
        ("receive_ClientAck", struct.pack("<IB", 1, 2), {"frame": 1, "token": 2}),
        ("receive_ServerEnableEncryption", b"data", {"data": b"data"}),
        ("receive_ServerCheckNewGRFs", b"", {}),
        ("receive_ServerUnused", b"", {}),
        ("receive_ServerMapDone", b"", {}),
        ("receive_ServerClientInfo", b"", {}),
        ("receive_ServerSync", b"", {}),
        ("receive_ServerClientJoined", b"", {}),
        ("receive_ServerMapBegin", b"", {}),
        ("receive_ServerMapSize", b"", {"size": 0}),
        ("receive_ServerMapData", b"data", {"data": b"data"}),
        ("receive_ServerConfigurationUpdate", b"", {}),
        (
            "receive_ServerAuthenticationRequest",
            struct.pack("<B", 1) + b"data",
            {"auth_type": 1, "data": b"data"},
        ),
        ("receive_ServerError", b"\x08", {"error_code": 8}),
        ("receive_ServerCompanyUpdate", b"\x01\x00", {"passworded_mask": 1}),
        (
            "receive_ServerNeedCompanyPassword",
            memoryview(struct.pack("<I", 1234) + b"sid\x00"),
            {"seed": 1234, "server_id": "sid"},
        ),
    ]
    for method_name, payload, expected in exact_cases:
        parser = getattr(OpenTTDProtocol, method_name)
        assert parser(None, payload) == expected, method_name

    # receive_ServerGameInfo is only smoke-called for coverage. The test does
    # not assert whether an all-zero payload parses, so any exception raised
    # here is deliberately ignored.
    # NOTE(review): this asserts nothing; pin the expected result or the
    # specific exception type once it is known.
    try:
        OpenTTDProtocol.receive_ServerGameInfo(None, memoryview(bytes(200)))
    except Exception:
        pass
|
|
|
|
@pytest.mark.asyncio
async def test_protocol_exception_handling():
    """A packet too short to parse is expected to come back as ServerUnused."""
    proto = OpenTTDProtocol(MockHandler())
    proto.transport = MockTransport()

    # One byte cannot hold the uint16 length prefix, so struct.unpack fails
    # inside receive_packet.
    packet_type, _ = proto.receive_packet(None, memoryview(b"\x01"))
    assert packet_type == PacketGameType.ServerUnused
|
|
|
|
@pytest.mark.asyncio
async def test_protocol_encryption_logic():
    """Check encrypted framing: the outgoing size, and decrypting an incoming ClientAck."""
    handler = MockHandler()
    handler.encryption_enabled = True
    proto = OpenTTDProtocol(handler)
    proto.transport = MockTransport()
    proto._can_write.set()

    # --- send path ---
    # Plain packet: little-endian uint16 length (3) followed by a 1-byte type.
    outgoing = b"\x03\x00\x0e"
    sent = await proto.send_packet(outgoing)
    # Encrypted frame: 2-byte length + 16-byte MAC + 1 byte of encrypted data.
    assert sent == 2 + 16 + 1

    # --- receive path ---
    # Inner ClientAck (type 32): uint8 type, uint32 frame, uint8 token.
    plaintext = struct.pack("<B I B", 32, 1234, 7)
    key, nonce = handler._session_key_recv, handler._encryption_nonce

    # Encrypt with one AEAD instance and give the protocol a separate,
    # identically keyed instance. This is presumably so the receiver starts
    # from the same initial state the sender did.
    sender = monocypher.IncrementalAuthenticatedEncryption(key, nonce)
    mac, ciphertext = sender.lock(plaintext)
    handler._recv_aead = monocypher.IncrementalAuthenticatedEncryption(key, nonce)

    # The length prefix counts itself, the MAC and the ciphertext.
    frame_length = 2 + len(mac) + len(ciphertext)
    wire = memoryview(struct.pack("<H", frame_length) + mac + ciphertext)

    packet_type, fields = proto.receive_packet(None, wire)
    assert packet_type == PacketGameType.ClientAck
    assert fields["frame"] == 1234
    assert fields["token"] == 7
|
|
|
|
@pytest.mark.asyncio
async def test_protocol_decryption_failure():
    """An encrypted packet whose MAC does not verify is expected to yield ServerUnused.

    The receive AEAD is installed explicitly, as the successful-decryption
    test does. That way the failure comes from authentication and not from a
    missing (``None``) cipher. The length prefix is also computed from the
    real frame size, so the packet is well-formed apart from its bogus MAC.
    Previously the prefix was a hard-coded 0x14 (20) for a 22-byte frame.
    """
    handler = MockHandler()
    handler.encryption_enabled = True
    proto = OpenTTDProtocol(handler)
    proto.transport = MockTransport()
    handler._recv_aead = monocypher.IncrementalAuthenticatedEncryption(
        handler._session_key_recv, handler._encryption_nonce
    )

    # Frame layout: uint16 length + 16-byte MAC + ciphertext. The MAC is
    # garbage, so it cannot authenticate against the installed key.
    bogus_mac = b"X" * 16
    ciphertext = b"junk"
    frame_length = 2 + len(bogus_mac) + len(ciphertext)
    wire_data = memoryview(struct.pack("<H", frame_length) + bogus_mac + ciphertext)

    ptype, _ = proto.receive_packet(None, wire_data)
    assert ptype == PacketGameType.ServerUnused
|
|
|
|
@pytest.mark.asyncio
async def test_protocol_is_closing_failure():
    """Sending while the transport reports it is closing raises SocketClosed."""
    proto = OpenTTDProtocol(MockHandler())
    proto.transport = MockTransport()
    proto.transport.close()
    # Unblock the writer, so the only obstacle to sending is the closed transport.
    proto._can_write.set()

    with pytest.raises(SocketClosed):
        # A bare length-only packet. Its contents should not matter, because
        # the closed transport is expected to stop the send.
        await proto.send_packet(b"\x02\x00")
|