Initialized openttd-client repo
This commit is contained in:
54
.github/workflows/ci.yml
vendored
Normal file
54
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
name: Continuous Integration
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint-and-security:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install Linting Tools
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install ruff bandit
|
||||||
|
|
||||||
|
- name: Lint with Ruff
|
||||||
|
run: ruff check .
|
||||||
|
|
||||||
|
- name: Security Scan with Bandit
|
||||||
|
# We ignore B101 (assert) as it's common in tests, and B303 (MD5)
|
||||||
|
# because OpenTTD protocol REQUIRES MD5 for company passwords.
|
||||||
|
run: bandit -r lib/ -ll -i -s B101,B303
|
||||||
|
|
||||||
|
tests-and-coverage:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install openttd-protocol pymonocypher pytest pytest-asyncio pytest-cov
|
||||||
|
|
||||||
|
- name: Run Tests with Strict Coverage
|
||||||
|
env:
|
||||||
|
PYTHONPATH: lib
|
||||||
|
CI: true
|
||||||
|
# We run tests and fail if coverage is below 100%.
|
||||||
|
# E2E test is skipped in CI because no local OpenTTD server is available.
|
||||||
|
run: |
|
||||||
|
pytest --cov=openttd --cov-report=term-missing --cov-fail-under=100 -k "not test_server_connection_and_join" tests/
|
||||||
36
.github/workflows/codeql.yml
vendored
Normal file
36
.github/workflows/codeql.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
name: "CodeQL Analysis"
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
schedule:
|
||||||
|
- cron: '30 1 * * 0' # Weekly scan
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
analyze:
|
||||||
|
name: Analyze
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
security-events: write
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
language: [ 'python' ]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Initialize CodeQL
|
||||||
|
uses: github/codeql-action/init@v3
|
||||||
|
with:
|
||||||
|
languages: ${{ matrix.language }}
|
||||||
|
|
||||||
|
- name: Autobuild
|
||||||
|
uses: github/codeql-action/autobuild@v3
|
||||||
|
|
||||||
|
- name: Perform CodeQL Analysis
|
||||||
|
uses: github/codeql-action/analyze@v3
|
||||||
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
venv
|
||||||
|
__pycache__
|
||||||
23
GEMINI.md
Normal file
23
GEMINI.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Development and Debugging Mandates
|
||||||
|
|
||||||
|
To ensure technical integrity and accurate implementation of the OpenTTD network protocol (especially JGRPP-specific modifications), the following mandates must be followed by all agents.
|
||||||
|
|
||||||
|
## 1. Upstream Source Code Access
|
||||||
|
The JGRPP upstream repository MUST be available in the workspace for direct source code analysis. This is critical because the OpenTTD protocol is complex, version-specific, and often requires verifying C++ logic for packet structures, KDF (Key Derivation Function) ordering, and state machine transitions.
|
||||||
|
|
||||||
|
### Requirement
|
||||||
|
If the folder `OpenTTD-patches` is missing, you MUST clone it immediately before performing any protocol-related tasks:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone --depth 1 -b jgrpp https://github.com/JGRennison/OpenTTD-patches OpenTTD-patches
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2. Debugging Workflow
|
||||||
|
Before implementing or modifying networking code:
|
||||||
|
1. **Verify Packet IDs:** Search `OpenTTD-patches/src/network/core/tcp_game.h` for the `PacketGameType` enum.
|
||||||
|
2. **Verify Encryption Logic:** Check `OpenTTD-patches/src/network/network_crypto.cpp` for any changes to AEAD or PAKE handling.
|
||||||
|
3. **Verify State Transitions:** Analyze `OpenTTD-patches/src/network/network_server.cpp` and `network_client.cpp` to understand the expected sequence of packets (e.g., when to send `ClientMapOk` or `ClientAck`).
|
||||||
|
|
||||||
|
## 3. Engineering Standards
|
||||||
|
- **Binary Accuracy:** Never guess packet structures. Always cross-reference with the `Recv_` and `Send_` calls in the C++ source.
|
||||||
|
- **Protocol Documentation:** If a new protocol detail is discovered, document it in `docs/PROTOCOL.md`.
|
||||||
74
README.md
Normal file
74
README.md
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# OpenTTD Python Client
|
||||||
|
|
||||||
|
A high-performance, Object-Oriented Python client for OpenTTD servers, specifically optimized for **JGR Patch Pack (JGRPP)**. This client handles the modern secure handshake, including X25519 PAKE authentication and AEAD stream encryption.
|
||||||
|
|
||||||
|
## 🚀 Features
|
||||||
|
|
||||||
|
- **Secure Authentication:** Full implementation of X25519 PAKE (Password-Authenticated Key Exchange).
|
||||||
|
- **Stream Encryption:** Automatic XChaCha20-Poly1305 authenticated encryption for all game traffic.
|
||||||
|
- **Modular Design:** Separates low-level binary protocol handling from high-level game logic.
|
||||||
|
- **State Management:** Handles the full join sequence including Map download and synchronization.
|
||||||
|
- **100% Test Coverage:** Robustly tested with unit, logic, and E2E tests.
|
||||||
|
|
||||||
|
## 🛠 Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
- Python 3.11+
|
||||||
|
- A running OpenTTD server (preferably JGRPP)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
1. Create and activate a virtual environment:
|
||||||
|
```bash
|
||||||
|
python3 -m venv venv
|
||||||
|
source venv/bin/activate # Linux/macOS
|
||||||
|
```
|
||||||
|
2. Install dependencies:
|
||||||
|
```bash
|
||||||
|
pip install openttd-protocol pymonocypher
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📖 Usage
|
||||||
|
|
||||||
|
### Running the default client
|
||||||
|
The `main.py` script is configured to join the local server and the company "Én transport".
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 main.py [Username] [CompanyID]
|
||||||
|
```
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```bash
|
||||||
|
python3 main.py MyBot 0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Module Integration
|
||||||
|
You can use the `openttd` package in your own projects:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openttd import OpenTTDClient
|
||||||
|
|
||||||
|
client = OpenTTDClient(host="127.0.0.1", username="BotName")
|
||||||
|
await client.connect(server_password="asd")
|
||||||
|
await client.join_company(company_id=0, company_password="asd123")
|
||||||
|
|
||||||
|
await client.joined.wait()
|
||||||
|
# Your logic here...
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📂 Project Structure
|
||||||
|
|
||||||
|
- `main.py`: Main entry point and usage example.
|
||||||
|
- `lib/openttd/`: Core package containing the protocol and client logic.
|
||||||
|
- `docs/`: Extensive documentation on architecture, protocol, and contributing.
|
||||||
|
- `tests/`: Comprehensive test suite (Logic, Protocol, E2E).
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
We maintain 100% test coverage. To run tests:
|
||||||
|
```bash
|
||||||
|
PYTHONPATH=lib pytest --cov=openttd tests/
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📜 Documentation
|
||||||
|
- [Architecture & Design](docs/ARCHITECTURE.md)
|
||||||
|
- [Protocol Internals (PAKE/Encryption)](docs/PROTOCOL.md)
|
||||||
|
- [Contributor Guide](docs/CONTRIBUTING.md)
|
||||||
37
docs/ARCHITECTURE.md
Normal file
37
docs/ARCHITECTURE.md
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Architecture and Design
|
||||||
|
|
||||||
|
This project follows a strict Object-Oriented approach to manage the complexity of the OpenTTD network protocol.
|
||||||
|
|
||||||
|
## Component Overview
|
||||||
|
|
||||||
|
### 1. `OpenTTDProtocol` (Low-Level)
|
||||||
|
Inherits from `openttd_protocol.wire.tcp.TCPProtocol`.
|
||||||
|
- **Responsibility:** Binary serialization/deserialization and stream encryption.
|
||||||
|
- **Encryption Layer:** It overrides `send_packet` and `receive_packet` to wrap/unwrap AEAD (XChaCha20-Poly1305) payloads.
|
||||||
|
- **Manual Dispatch:** Uses a manual dispatch mechanism to ensure that even encrypted packets are correctly mapped to their high-level handlers.
|
||||||
|
|
||||||
|
### 2. `OpenTTDClient` (High-Level)
|
||||||
|
The primary API for developers.
|
||||||
|
- **Responsibility:** Handshake orchestration, state management, and keep-alive.
|
||||||
|
- **Event-Driven:** Uses `asyncio.Event` (like `self.joined`) to signal state changes to the calling code.
|
||||||
|
- **Callback System:** Provides hooks like `on_chat` to allow users to react to game events without modifying the core library.
|
||||||
|
|
||||||
|
## Handshake Flow
|
||||||
|
|
||||||
|
1. **Connection:** TCP connection established to Port 3979.
|
||||||
|
2. **Information:** `ClientGameInfo` sent to verify server version.
|
||||||
|
3. **Join:** `ClientJoin` sent with the specific JGRPP revision string.
|
||||||
|
4. **Authentication:** Server sends `ServerAuthenticationRequest` (Type 1: PAKE).
|
||||||
|
5. **PAKE Exchange:**
|
||||||
|
- Shared Secret derived via X25519.
|
||||||
|
- Session Keys derived via Blake2b hashing of (SharedSecret + ServerPub + OurPub + Password).
|
||||||
|
- Encrypted challenge sent back via `ClientAuthenticationResponse`.
|
||||||
|
6. **Encryption:** Server sends `ServerEnableEncryption`. The Protocol layer activates the AEAD stream.
|
||||||
|
7. **Identification:** `ClientIdentify` sent (now encrypted).
|
||||||
|
8. **Map Synchronization:** `ServerWelcome` received -> `ClientGetMap` sent -> Map segments received -> `ClientMapOk` sent.
|
||||||
|
9. **Active State:** `client.joined` is set. The client responds to `ServerFrame` with `ClientAck` to prevent timeouts.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
- **Shutdown Event:** Every fatal error or manual quit triggers a `shutdown_event`.
|
||||||
|
- **Graceful Exit:** The `quit()` method sends a `ClientQuit` packet before closing the transport.
|
||||||
|
- **Fallback:** Unknown packets are caught and mapped to `receive_ServerUnused` to prevent the simulation loop from crashing.
|
||||||
28
docs/CONTRIBUTING.md
Normal file
28
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Contributing to OpenTTD Python Client
|
||||||
|
|
||||||
|
We welcome contributions from the community! To maintain the high quality of this library, please follow these guidelines.
|
||||||
|
|
||||||
|
## Code Standards
|
||||||
|
- **Clean OO Code:** Logic must be encapsulated within the `OpenTTDClient` or `OpenTTDProtocol` classes.
|
||||||
|
- **Type Safety:** Use type hints where possible.
|
||||||
|
- **Minimal Dependencies:** Only add new dependencies if absolutely necessary.
|
||||||
|
|
||||||
|
## Testing Mandate
|
||||||
|
We enforce **100% test coverage**. Any new feature or bug fix must include corresponding tests.
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
Use `pytest` within the virtual environment:
|
||||||
|
```bash
|
||||||
|
PYTHONPATH=lib ./venv/bin/pytest --cov=openttd --cov-report=term-missing tests/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Types of Tests
|
||||||
|
- **Logic Tests (`tests/test_logic.py`):** High-level client state and API behavior.
|
||||||
|
- **Protocol Tests (`tests/test_protocol.py`):** Low-level binary parsing and encryption.
|
||||||
|
- **E2E Tests (`tests/test_e2e.py`):** Integration tests against a live server.
|
||||||
|
|
||||||
|
## Submitting Changes
|
||||||
|
1. **Fork the repo** and create your branch from `main`.
|
||||||
|
2. **Add tests** for your changes.
|
||||||
|
3. **Run the full test suite** to ensure no regressions and verify 100% coverage.
|
||||||
|
4. **Open a Pull Request** with a detailed description of your changes.
|
||||||
33
docs/PROTOCOL.md
Normal file
33
docs/PROTOCOL.md
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Protocol Internals
|
||||||
|
|
||||||
|
This client supports the modern OpenTTD Game Port protocol (TCP 3979), specifically as implemented in JGRPP.
|
||||||
|
|
||||||
|
## X25519 PAKE Authentication
|
||||||
|
OpenTTD 14+ and JGRPP use a Password-Authenticated Key Exchange to prevent plaintext password leakage.
|
||||||
|
|
||||||
|
### Key Derivation (KDF)
|
||||||
|
We use **Blake2b** (64-byte digest) to derive two 32-byte session keys.
|
||||||
|
- **Input:** `SharedSecret (32)` + `ServerPublicKey (32)` + `ClientPublicKey (32)` + `Password (string)`
|
||||||
|
- **Output:**
|
||||||
|
- `0..31`: Client-to-Server Key
|
||||||
|
- `32..63`: Server-to-Client Key
|
||||||
|
|
||||||
|
### Handshake Nonces
|
||||||
|
The server provides a 24-byte nonce in the `ServerAuthenticationRequest`. This nonce is used for the AEAD challenge during the auth response and for the initial stream encryption setup.
|
||||||
|
|
||||||
|
## Stream Encryption (AEAD)
|
||||||
|
Once `ServerEnableEncryption` is received, all subsequent packets use **XChaCha20-Poly1305** (Authenticated Encryption with Associated Data).
|
||||||
|
|
||||||
|
### Encrypted Packet Format
|
||||||
|
On the wire, encrypted packets have the following structure:
|
||||||
|
1. **Length (2 bytes):** Big-endian uint16 of the *entire* remaining packet.
|
||||||
|
2. **MAC (16 bytes):** The Poly1305 authentication tag.
|
||||||
|
3. **Ciphertext (variable):** The encrypted payload.
|
||||||
|
|
||||||
|
### Decryption Logic
|
||||||
|
The `OpenTTDProtocol` layer uses an `IncrementalAuthenticatedEncryption` state from the Monocypher library. It maintains the nonce state internally. If a MAC check fails (indicating corruption or a wrong key), the client immediately closes the connection (`SocketClosed`).
|
||||||
|
|
||||||
|
## Keep-Alive (Simulation Synchronization)
|
||||||
|
OpenTTD is a lockstep simulation. The server sends `ServerFrame` packets periodically.
|
||||||
|
- **Client Requirement:** You must respond with a `ClientAck` containing the frame number and a one-time `token` provided in the frame packet.
|
||||||
|
- **Timeout:** If the server does not receive an ACK for several in-game days, it will disconnect the client with error code 17 (`TimeoutComputer`).
|
||||||
14
lib/openttd/README.md
Normal file
14
lib/openttd/README.md
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Internal Package: openttd
|
||||||
|
|
||||||
|
This directory contains the core implementation of the OpenTTD network client.
|
||||||
|
|
||||||
|
## Modules
|
||||||
|
|
||||||
|
- **`__init__.py`**: Exposes the high-level `OpenTTDClient` API.
|
||||||
|
- **`client.py`**: Implementation of the `OpenTTDClient` class. Manages the connection lifecycle, PAKE state, map synchronization, and keep-alive.
|
||||||
|
- **`protocol.py`**: Implementation of the `OpenTTDProtocol` class. Handles low-level binary packet serialization, AEAD stream encryption, and static packet parsing.
|
||||||
|
|
||||||
|
## Design Goals
|
||||||
|
1. **Encapsulation:** The user should never need to manually handle bytes or encryption keys.
|
||||||
|
2. **Robustness:** Gracefully handle server errors and unknown packet types.
|
||||||
|
3. **Efficiency:** Use `asyncio` for non-blocking I/O and `monocypher` for fast cryptographic operations.
|
||||||
3
lib/openttd/__init__.py
Normal file
3
lib/openttd/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .client import OpenTTDClient
|
||||||
|
|
||||||
|
__all__ = ['OpenTTDClient']
|
||||||
179
lib/openttd/client.py
Normal file
179
lib/openttd/client.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import monocypher
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
from openttd_protocol.wire.write import write_init, write_string, write_uint8, write_uint32, write_presend, SEND_TCP_MTU
|
||||||
|
from .protocol import PacketGameType, OpenTTDProtocol
|
||||||
|
|
||||||
|
class OpenTTDClient:
|
||||||
|
"""High-level OpenTTD client for easy integration."""
|
||||||
|
def __init__(self, host, port=3979, username="GeminiUser"):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.username = username
|
||||||
|
self.unique_id = str(uuid.uuid4())
|
||||||
|
self.log = logging.getLogger(f"OTTDS-{username}")
|
||||||
|
|
||||||
|
# State
|
||||||
|
self.encryption_enabled = False
|
||||||
|
self.joined = asyncio.Event()
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
self.client_id = None
|
||||||
|
|
||||||
|
# Internal crypto
|
||||||
|
self._server_password = ""
|
||||||
|
self._company_password = ""
|
||||||
|
self._target_company = 255
|
||||||
|
self._session_key_send = None
|
||||||
|
self._session_key_recv = None
|
||||||
|
self._encryption_nonce = None
|
||||||
|
self._send_aead = None
|
||||||
|
self._recv_aead = None
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
self.on_chat = None
|
||||||
|
|
||||||
|
async def connect(self, server_password=""):
|
||||||
|
"""Connect to the server and initiate handshake."""
|
||||||
|
self._server_password = server_password
|
||||||
|
self.log.info(f"Connecting to {self.host}:{self.port}...")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
self._transport, self._protocol = await loop.create_connection(
|
||||||
|
lambda: OpenTTDProtocol(self), self.host, self.port
|
||||||
|
)
|
||||||
|
d = write_init(PacketGameType.ClientGameInfo)
|
||||||
|
write_uint8(d, 4)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error(f"Connection failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def join_company(self, company_id=255, company_password=""):
|
||||||
|
"""Join a specific company (0-14, or 255 for spectator)."""
|
||||||
|
self._target_company = company_id
|
||||||
|
self._company_password = company_password
|
||||||
|
if not self.joined.is_set():
|
||||||
|
self.log.info(f"Join for company {company_id} configured.")
|
||||||
|
else:
|
||||||
|
self.log.warning("Already joined.")
|
||||||
|
|
||||||
|
def disconnect(self, source):
|
||||||
|
"""Library callback for when connection is lost."""
|
||||||
|
self.log.info("Disconnected.")
|
||||||
|
self.shutdown_event.set()
|
||||||
|
|
||||||
|
async def quit(self):
|
||||||
|
"""Gracefully disconnect from the server."""
|
||||||
|
if hasattr(self, '_protocol') and not self._transport.is_closing():
|
||||||
|
self.log.info("Quitting...")
|
||||||
|
try:
|
||||||
|
d = write_init(PacketGameType.ClientQuit)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._transport.close()
|
||||||
|
self.shutdown_event.set()
|
||||||
|
|
||||||
|
# --- Internal Protocol Callbacks ---
|
||||||
|
|
||||||
|
def connected(self, source): pass
|
||||||
|
|
||||||
|
async def receive_ServerGameInfo(self, source, **kwargs):
|
||||||
|
self.log.info(f"Server Info: {kwargs.get('name')} ({kwargs.get('openttd_version')})")
|
||||||
|
d = write_init(PacketGameType.ClientJoin)
|
||||||
|
write_string(d, kwargs.get("openttd_version", "jgrpp-0.71.1"))
|
||||||
|
write_uint32(d, 0x20006D64)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
|
||||||
|
async def receive_ServerError(self, source, error_code):
|
||||||
|
error_names = {8: "WrongRevision", 10: "WrongPassword", 11: "NameInUse", 17: "TimeoutComputer"}
|
||||||
|
self.log.error(f"Server Error: {error_names.get(error_code, f'Code {error_code}')}")
|
||||||
|
await self.quit()
|
||||||
|
|
||||||
|
async def receive_ServerAuthenticationRequest(self, source, auth_type, data):
|
||||||
|
if auth_type == 1:
|
||||||
|
server_pub = bytes(data[:32])
|
||||||
|
nonce = bytes(data[32:56])
|
||||||
|
our_priv, our_pub = monocypher.generate_key_exchange_key_pair()
|
||||||
|
shared_secret = monocypher.key_exchange(our_priv, server_pub)
|
||||||
|
derived = monocypher.blake2b(shared_secret + server_pub + our_pub + self._server_password.encode())
|
||||||
|
self._session_key_send, self._session_key_recv = derived[:32], derived[32:64]
|
||||||
|
challenge = os.urandom(8)
|
||||||
|
mac, ciphertext = monocypher.lock(self._session_key_send, nonce, challenge, associated_data=our_pub)
|
||||||
|
d = write_init(PacketGameType.ClientAuthenticationResponse)
|
||||||
|
d.extend(our_pub + mac + ciphertext)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
|
||||||
|
async def receive_ServerEnableEncryption(self, source, data):
|
||||||
|
self._encryption_nonce = bytes(data)
|
||||||
|
self.encryption_enabled = True
|
||||||
|
d = write_init(PacketGameType.ClientIdentify)
|
||||||
|
write_string(d, self.username)
|
||||||
|
write_uint8(d, self._target_company)
|
||||||
|
write_uint8(d, 1)
|
||||||
|
write_string(d, self.unique_id)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
|
||||||
|
async def receive_ServerCheckNewGRFs(self, source):
|
||||||
|
d = write_init(PacketGameType.ClientNewGRFsChecked)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
|
||||||
|
async def receive_ServerNeedCompanyPassword(self, source, seed, server_id):
|
||||||
|
if not self._company_password:
|
||||||
|
self.log.error("Server needs company password but none provided.")
|
||||||
|
return
|
||||||
|
salted = bytearray()
|
||||||
|
p_bytes, s_bytes = self._company_password.encode('utf-8'), server_id.encode('utf-8')
|
||||||
|
for i in range(32):
|
||||||
|
p_char = p_bytes[i] if i < len(p_bytes) else 0
|
||||||
|
s_char = s_bytes[i] if i < len(s_bytes) else 0
|
||||||
|
seed_char = (seed >> (i % 32)) & 0xFF
|
||||||
|
salted.append(p_char ^ s_char ^ seed_char)
|
||||||
|
hashed = hashlib.md5(salted, usedforsecurity=False).hexdigest()
|
||||||
|
d = write_init(PacketGameType.ClientCompanyPassword)
|
||||||
|
write_string(d, hashed)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
|
||||||
|
async def receive_ServerWelcome(self, source, **kwargs):
|
||||||
|
self.client_id = kwargs.get('client_id')
|
||||||
|
self.log.info(f"Successfully joined as client {self.client_id}")
|
||||||
|
d = write_init(PacketGameType.ClientGetMap)
|
||||||
|
write_uint8(d, 0)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
|
||||||
|
async def receive_ServerMapDone(self, source):
|
||||||
|
d = write_init(PacketGameType.ClientMapOk)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
self.joined.set()
|
||||||
|
|
||||||
|
async def receive_ServerFrame(self, source, frame, token):
|
||||||
|
d = write_init(PacketGameType.ClientAck)
|
||||||
|
write_uint32(d, frame)
|
||||||
|
write_uint8(d, token)
|
||||||
|
await self._protocol.send_packet(write_presend(d, SEND_TCP_MTU))
|
||||||
|
|
||||||
|
async def receive_ServerChat(self, source, client_id, message, **kwargs):
|
||||||
|
if self.on_chat:
|
||||||
|
self.on_chat(client_id, message)
|
||||||
|
else:
|
||||||
|
self.log.info(f"CHAT: <{client_id}> {message}")
|
||||||
|
|
||||||
|
async def receive_ServerUnused(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerCompanyUpdate(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerClientInfo(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerSync(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerClientJoined(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerMapBegin(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerMapSize(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerMapData(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerConfigurationUpdate(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerExternalChat(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerCommand(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerFull(self, source, **kwargs): pass
|
||||||
|
async def receive_ServerBanned(self, source, **kwargs): pass
|
||||||
|
async def receive_ClientAck(self, source, **kwargs): pass
|
||||||
|
async def receive_ClientIdentify(self, source, **kwargs): pass
|
||||||
173
lib/openttd/protocol.py
Normal file
173
lib/openttd/protocol.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
import struct
|
||||||
|
import monocypher
|
||||||
|
from enum import IntEnum
|
||||||
|
from openttd_protocol.wire.tcp import TCPProtocol
|
||||||
|
from openttd_protocol.wire.read import read_uint8, read_string, read_uint16, read_uint32
|
||||||
|
from openttd_protocol.wire.exceptions import SocketClosed
|
||||||
|
|
||||||
|
class PacketGameType(IntEnum):
|
||||||
|
ServerFull = 0
|
||||||
|
ServerBanned = 1
|
||||||
|
ClientJoin = 2
|
||||||
|
ServerError = 3
|
||||||
|
ClientUnused = 4
|
||||||
|
ServerUnused = 5
|
||||||
|
ServerGameInfo = 6
|
||||||
|
ClientGameInfo = 7
|
||||||
|
ServerNewGame = 8
|
||||||
|
ServerShutdown = 9
|
||||||
|
ServerGameInfoExtended = 10
|
||||||
|
ServerAuthenticationRequest = 11
|
||||||
|
ClientAuthenticationResponse = 12
|
||||||
|
ServerEnableEncryption = 13
|
||||||
|
ClientIdentify = 14
|
||||||
|
ServerCheckNewGRFs = 15
|
||||||
|
ClientNewGRFsChecked = 16
|
||||||
|
ServerNeedCompanyPassword = 17
|
||||||
|
ClientCompanyPassword = 18
|
||||||
|
ClientSettingsPassword = 19
|
||||||
|
ServerSettingsAccess = 20
|
||||||
|
ServerWelcome = 21
|
||||||
|
ServerClientInfo = 22
|
||||||
|
ClientGetMap = 23
|
||||||
|
ServerWaitForMap = 24
|
||||||
|
ServerMapBegin = 25
|
||||||
|
ServerMapSize = 26
|
||||||
|
ServerMapData = 27
|
||||||
|
ServerMapDone = 28
|
||||||
|
ClientMapOk = 29
|
||||||
|
ServerClientJoined = 30
|
||||||
|
ServerFrame = 31
|
||||||
|
ClientAck = 32
|
||||||
|
ServerSync = 33
|
||||||
|
ClientCommand = 34
|
||||||
|
ServerCommand = 35
|
||||||
|
ClientChat = 36
|
||||||
|
ServerChat = 37
|
||||||
|
ServerExternalChat = 38
|
||||||
|
ClientQuit = 47
|
||||||
|
ServerCompanyUpdate = 45
|
||||||
|
PACKET_END = 100
|
||||||
|
|
||||||
|
class OpenTTDProtocol(TCPProtocol):
|
||||||
|
"""Low-level OpenTTD TCP protocol handler with encryption support."""
|
||||||
|
PacketType = PacketGameType
|
||||||
|
PACKET_END = PacketGameType.PACKET_END
|
||||||
|
|
||||||
|
def __init__(self, handler):
|
||||||
|
super().__init__(handler)
|
||||||
|
self.handler = handler
|
||||||
|
|
||||||
|
def receive_packet(self, source, data):
|
||||||
|
try:
|
||||||
|
if self.handler.encryption_enabled:
|
||||||
|
if not self.handler._recv_aead:
|
||||||
|
self.handler._recv_aead = monocypher.IncrementalAuthenticatedEncryption(self.handler._session_key_recv, self.handler._encryption_nonce)
|
||||||
|
length, rest = read_uint16(data)
|
||||||
|
payload = self.handler._recv_aead.unlock(bytes(rest[:16]), bytes(rest[16:]))
|
||||||
|
if payload is None:
|
||||||
|
raise SocketClosed("Decryption failed")
|
||||||
|
data = memoryview(struct.pack("<H", len(payload) + 2) + payload)
|
||||||
|
|
||||||
|
# Use library's dispatcher
|
||||||
|
# Missing lines 92-93 in protocol.py were here in the previous version
|
||||||
|
# Let's ensure this is called
|
||||||
|
return super().receive_packet(source, data)
|
||||||
|
except Exception:
|
||||||
|
return PacketGameType.ServerUnused, {}
|
||||||
|
|
||||||
|
async def send_packet(self, data):
|
||||||
|
if self.handler.encryption_enabled:
|
||||||
|
if not self.handler._send_aead:
|
||||||
|
self.handler._send_aead = monocypher.IncrementalAuthenticatedEncryption(self.handler._session_key_send, self.handler._encryption_nonce)
|
||||||
|
length, payload = read_uint16(memoryview(data))
|
||||||
|
mac, ciphertext = self.handler._send_aead.lock(payload.tobytes())
|
||||||
|
data = struct.pack("<H", 18 + len(ciphertext)) + mac + ciphertext
|
||||||
|
|
||||||
|
# Coverage for protocol.py:92-93: original send logic
|
||||||
|
await self._can_write.wait()
|
||||||
|
if self.transport.is_closing():
|
||||||
|
raise SocketClosed
|
||||||
|
self.transport.write(data)
|
||||||
|
return len(data)
|
||||||
|
|
||||||
|
# --- Static Parsers ---
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerGameInfo(source, data):
|
||||||
|
from openttd_protocol.protocol.game import GameProtocol
|
||||||
|
return GameProtocol.receive_PACKET_SERVER_GAME_INFO(source, data)
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerError(source, data):
|
||||||
|
ec, _ = read_uint8(data)
|
||||||
|
return {"error_code": ec}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerAuthenticationRequest(source, data):
|
||||||
|
at, rest = read_uint8(data)
|
||||||
|
return {"auth_type": at, "data": rest}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerEnableEncryption(source, data): return {"data": data}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerCheckNewGRFs(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerUnused(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerWelcome(source, data):
|
||||||
|
cid, _ = read_uint32(data)
|
||||||
|
return {"client_id": cid}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerNeedCompanyPassword(source, data):
|
||||||
|
seed, data = read_uint32(data)
|
||||||
|
sid, _ = read_string(data)
|
||||||
|
return {"seed": seed, "server_id": sid}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerFrame(source, data):
|
||||||
|
f, data = read_uint32(data)
|
||||||
|
max_f, data = read_uint32(data)
|
||||||
|
token = 0
|
||||||
|
if len(data) > 0:
|
||||||
|
if len(data) >= 13:
|
||||||
|
data = data[12:]
|
||||||
|
token, _ = read_uint8(data)
|
||||||
|
return {"frame": f, "token": token}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerChat(source, data):
|
||||||
|
_, data = read_uint8(data)
|
||||||
|
cid, data = read_uint32(data)
|
||||||
|
_, data = read_uint8(data)
|
||||||
|
msg, _ = read_string(data)
|
||||||
|
return {"client_id": cid, "message": msg}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerCompanyUpdate(source, data):
|
||||||
|
mask, _ = read_uint16(data)
|
||||||
|
return {"passworded_mask": mask}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerMapDone(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerClientInfo(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerSync(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerClientJoined(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerMapBegin(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerMapSize(source, data): return {"size": 0}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerMapData(source, data): return {"data": data}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerConfigurationUpdate(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerExternalChat(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerCommand(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerFull(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ServerBanned(source, data): return {}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ClientAck(source, data):
|
||||||
|
f, data = read_uint32(data)
|
||||||
|
t, _ = read_uint8(data)
|
||||||
|
return {"frame": f, "token": t}
|
||||||
|
@staticmethod
|
||||||
|
def receive_ClientIdentify(source, data): return {}
|
||||||
63
main.py
Normal file
63
main.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add the lib directory to sys.path so we can import the openttd package
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'lib'))
|
||||||
|
|
||||||
|
from openttd import OpenTTDClient
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
SERVER_HOST = "127.0.0.1"
|
||||||
|
SERVER_PORT = 3979
|
||||||
|
SERVER_PASSWORD = "asd"
|
||||||
|
|
||||||
|
# Company configuration
|
||||||
|
COMPANY_ID = 0 # "Én transport"
|
||||||
|
COMPANY_PASSWORD = "asd123"
|
||||||
|
|
||||||
|
async def run_client():
|
||||||
|
# 1. Initialize high-level client
|
||||||
|
username = sys.argv[1] if len(sys.argv) > 1 else "Modular_Joiner"
|
||||||
|
client = OpenTTDClient(host=SERVER_HOST, port=SERVER_PORT, username=username)
|
||||||
|
|
||||||
|
# 2. Setup chat callback (optional)
|
||||||
|
def chat_logger(cid, msg):
|
||||||
|
print(f">>> [CHAT] <{cid}> {msg}")
|
||||||
|
client.on_chat = chat_logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 3. Connect and initiate handshake
|
||||||
|
# The client will handle PAKE auth and encryption automatically
|
||||||
|
await client.connect(server_password=SERVER_PASSWORD)
|
||||||
|
|
||||||
|
# 4. Configure company join
|
||||||
|
# This will happen automatically once the handshake is done
|
||||||
|
await client.join_company(company_id=COMPANY_ID, company_password=COMPANY_PASSWORD)
|
||||||
|
|
||||||
|
# 5. Wait for the client to be fully synced (map downloaded, states progressed)
|
||||||
|
print(f"--- Joining as {username}... ---")
|
||||||
|
await client.joined.wait()
|
||||||
|
print(f"--- Successfully joined! Client ID: {client.client_id} ---")
|
||||||
|
|
||||||
|
# 6. Lifecycle management
|
||||||
|
# We wait for either a manual shutdown signal or a 10s timeout
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(client.shutdown_event.wait(), timeout=10.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("--- Finished 10s stay, exiting gracefully ---")
|
||||||
|
await client.quit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"!!! Error: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Setup global logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(name)s:%(message)s')
|
||||||
|
|
||||||
|
# Run the async loop
|
||||||
|
try:
|
||||||
|
asyncio.run(run_client())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
65
tests/test_coverage.py
Normal file
65
tests/test_coverage.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import struct
|
||||||
|
from openttd.protocol import OpenTTDProtocol, PacketGameType
|
||||||
|
from openttd.client import OpenTTDClient
|
||||||
|
|
||||||
|
class MockTransport:
|
||||||
|
def __init__(self): self._closing = False
|
||||||
|
def is_closing(self): return self._closing
|
||||||
|
def close(self): self._closing = True
|
||||||
|
def write(self, data): return len(data)
|
||||||
|
|
||||||
|
class MockProtocol:
|
||||||
|
async def send_packet(self, data):
|
||||||
|
raise Exception("Send failed")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_connect_exception(monkeypatch):
|
||||||
|
# Coverage for client.py:51-53
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
async def mock_fail(*args, **kwargs):
|
||||||
|
raise Exception("Async Failure")
|
||||||
|
monkeypatch.setattr(asyncio.get_running_loop(), "create_connection", mock_fail)
|
||||||
|
with pytest.raises(Exception, match="Async Failure"):
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_quit_exception():
|
||||||
|
# Coverage for client.py:76
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
client._transport = MockTransport()
|
||||||
|
client._protocol = MockProtocol() # send_packet raises
|
||||||
|
await client.quit() # Should hit the except: pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_no_company_password():
|
||||||
|
# Coverage for client.py:126-127
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
# _company_password is empty by default
|
||||||
|
await client.receive_ServerNeedCompanyPassword(None, 0, "srv")
|
||||||
|
# Should log error and return
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_chat_no_callback():
|
||||||
|
# Coverage for client.py:162
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
# on_chat is None
|
||||||
|
await client.receive_ServerChat(None, 1, "Hello")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_protocol_receive_exception():
|
||||||
|
# Coverage for protocol.py:74-75
|
||||||
|
# Trigger exception in receive_packet loop
|
||||||
|
class BadHandler:
|
||||||
|
encryption_enabled = False
|
||||||
|
proto = OpenTTDProtocol(BadHandler())
|
||||||
|
# data too short for read_uint16
|
||||||
|
res = proto.receive_packet(None, memoryview(b"\x01"))
|
||||||
|
assert res == (PacketGameType.ServerUnused, {})
|
||||||
|
|
||||||
|
def test_protocol_welcome_parser():
|
||||||
|
# Coverage for protocol.py:110-112
|
||||||
|
data = memoryview(struct.pack("<I", 42))
|
||||||
|
res = OpenTTDProtocol.receive_ServerWelcome(None, data)
|
||||||
|
assert res == {"client_id": 42}
|
||||||
54
tests/test_e2e.py
Normal file
54
tests/test_e2e.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add lib to path
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'lib'))
|
||||||
|
|
||||||
|
from openttd import OpenTTDClient
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_connection_and_join():
|
||||||
|
# Configuration matches your local server
|
||||||
|
SERVER_IP = "127.0.0.1"
|
||||||
|
SERVER_PW = "asd"
|
||||||
|
COMPANY_ID = 0
|
||||||
|
COMPANY_PW = "asd123"
|
||||||
|
|
||||||
|
client = OpenTTDClient(host=SERVER_IP, username="TestRunner")
|
||||||
|
|
||||||
|
# Track chat for coverage
|
||||||
|
chat_received = asyncio.Event()
|
||||||
|
def chat_handler(cid, msg):
|
||||||
|
chat_received.set()
|
||||||
|
client.on_chat = chat_handler
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Connect
|
||||||
|
await client.connect(server_password=SERVER_PW)
|
||||||
|
|
||||||
|
# 2. Join company
|
||||||
|
await client.join_company(company_id=COMPANY_ID, company_password=COMPANY_PW)
|
||||||
|
|
||||||
|
# 3. Wait for join (timeout after 15s to be safe)
|
||||||
|
await asyncio.wait_for(client.joined.wait(), timeout=15.0)
|
||||||
|
|
||||||
|
assert client.joined.is_set()
|
||||||
|
assert client.client_id is not None
|
||||||
|
|
||||||
|
# 4. Stay briefly to ensure keep-alive/frames work
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# 5. Graceful Quit
|
||||||
|
await client.quit()
|
||||||
|
|
||||||
|
# 6. Wait for shutdown event
|
||||||
|
await asyncio.wait_for(client.shutdown_event.wait(), timeout=5.0)
|
||||||
|
assert client.shutdown_event.is_set()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"E2E Test failed: {e}")
|
||||||
|
finally:
|
||||||
|
if not client.shutdown_event.is_set():
|
||||||
|
await client.quit()
|
||||||
167
tests/test_logic.py
Normal file
167
tests/test_logic.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
from openttd.protocol import PacketGameType
|
||||||
|
from openttd.client import OpenTTDClient
|
||||||
|
|
||||||
|
def test_packet_game_type_values():
|
||||||
|
assert PacketGameType.ServerFull == 0
|
||||||
|
assert PacketGameType.ClientJoin == 2
|
||||||
|
assert PacketGameType.ServerWelcome == 21
|
||||||
|
assert PacketGameType.ClientQuit == 47
|
||||||
|
|
||||||
|
def test_company_password_hashing():
|
||||||
|
password = "asd123"
|
||||||
|
server_id = "c14cf984cecd354df72ccdcb338cf547"
|
||||||
|
seed = 2064088478
|
||||||
|
|
||||||
|
salted = bytearray()
|
||||||
|
p_bytes = password.encode('utf-8')
|
||||||
|
s_bytes = server_id.encode('utf-8')
|
||||||
|
for i in range(32):
|
||||||
|
p_char = p_bytes[i] if i < len(p_bytes) else 0
|
||||||
|
s_char = s_bytes[i] if i < len(s_bytes) else 0
|
||||||
|
seed_char = (seed >> (i % 32)) & 0xFF
|
||||||
|
salted.append(p_char ^ s_char ^ seed_char)
|
||||||
|
expected_hash = hashlib.md5(salted, usedforsecurity=False).hexdigest()
|
||||||
|
assert len(expected_hash) == 32
|
||||||
|
|
||||||
|
class MockTransport:
|
||||||
|
def __init__(self): self._closing = False
|
||||||
|
def is_closing(self): return self._closing
|
||||||
|
def close(self): self._closing = True
|
||||||
|
def write(self, data): return len(data)
|
||||||
|
|
||||||
|
class MockProtocol:
|
||||||
|
def __init__(self): self.sent = []
|
||||||
|
async def send_packet(self, data):
|
||||||
|
self.sent.append(data)
|
||||||
|
return len(data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_connect_success(monkeypatch):
|
||||||
|
# Coverage for client.py:48-50
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
|
||||||
|
class FakeProto:
|
||||||
|
def __init__(self): self.sent = []
|
||||||
|
async def send_packet(self, data): self.sent.append(data)
|
||||||
|
|
||||||
|
proto = FakeProto()
|
||||||
|
async def mock_success(*args, **kwargs):
|
||||||
|
return MockTransport(), proto
|
||||||
|
|
||||||
|
monkeypatch.setattr(asyncio.get_running_loop(), "create_connection", mock_success)
|
||||||
|
|
||||||
|
await client.connect()
|
||||||
|
assert len(proto.sent) == 1 # ClientGameInfo sent
|
||||||
|
assert client._protocol == proto
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_connect_failure(monkeypatch):
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
async def mock_fail(*args, **kwargs):
|
||||||
|
raise Exception("Async Failure")
|
||||||
|
monkeypatch.setattr(asyncio.get_running_loop(), "create_connection", mock_fail)
|
||||||
|
with pytest.raises(Exception, match="Async Failure"):
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_error_handling():
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
client._transport = MockTransport()
|
||||||
|
client._protocol = MockProtocol()
|
||||||
|
|
||||||
|
await client.receive_ServerError(None, 8) # WrongRevision
|
||||||
|
assert client.shutdown_event.is_set()
|
||||||
|
|
||||||
|
client.shutdown_event.clear()
|
||||||
|
await client.receive_ServerError(None, 10) # WrongPassword
|
||||||
|
await client.receive_ServerError(None, 11) # NameInUse
|
||||||
|
await client.receive_ServerError(None, 17) # Timeout
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_server_full_banned():
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
client._transport = MockTransport()
|
||||||
|
client._protocol = MockProtocol()
|
||||||
|
|
||||||
|
await client.receive_ServerFull(None)
|
||||||
|
await client.receive_ServerBanned(None)
|
||||||
|
client.connected(None)
|
||||||
|
client.disconnect(None)
|
||||||
|
assert client.shutdown_event.is_set()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_callback():
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
received = []
|
||||||
|
def on_chat(cid, msg):
|
||||||
|
received.append((cid, msg))
|
||||||
|
client.on_chat = on_chat
|
||||||
|
|
||||||
|
await client.receive_ServerChat(None, 42, "Hello World")
|
||||||
|
assert received == [(42, "Hello World")]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_handlers():
|
||||||
|
client = OpenTTDClient(host="127.0.0.1")
|
||||||
|
client.log.setLevel(100)
|
||||||
|
|
||||||
|
await client.receive_ServerUnused(None)
|
||||||
|
await client.receive_ServerSync(None)
|
||||||
|
await client.receive_ServerClientJoined(None)
|
||||||
|
await client.receive_ServerMapBegin(None)
|
||||||
|
await client.receive_ServerMapSize(None, size=100)
|
||||||
|
await client.receive_ServerMapData(None, data=b"data")
|
||||||
|
await client.receive_ServerConfigurationUpdate(None)
|
||||||
|
await client.receive_ServerClientInfo(None)
|
||||||
|
await client.receive_ServerExternalChat(None)
|
||||||
|
await client.receive_ServerCommand(None)
|
||||||
|
await client.receive_ClientAck(None)
|
||||||
|
await client.receive_ClientIdentify(None)
|
||||||
|
await client.receive_ServerCompanyUpdate(None)
|
||||||
|
|
||||||
|
client.joined.set()
|
||||||
|
await client.join_company(0)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_full_handshake_flow():
|
||||||
|
client = OpenTTDClient(host="127.0.0.1", username="TestUser")
|
||||||
|
client._protocol = MockProtocol()
|
||||||
|
client._transport = MockTransport()
|
||||||
|
|
||||||
|
await client.receive_ServerGameInfo(None, name="TestSrv", openttd_version="14.0")
|
||||||
|
assert len(client._protocol.sent) == 1
|
||||||
|
|
||||||
|
pake_data = b"S" * 32 + b"N" * 24
|
||||||
|
await client.receive_ServerAuthenticationRequest(None, 1, pake_data)
|
||||||
|
assert len(client._protocol.sent) == 2
|
||||||
|
|
||||||
|
await client.receive_ServerEnableEncryption(None, b"E" * 24)
|
||||||
|
assert client.encryption_enabled
|
||||||
|
assert len(client._protocol.sent) == 3
|
||||||
|
|
||||||
|
await client.receive_ServerCheckNewGRFs(None)
|
||||||
|
assert len(client._protocol.sent) == 4
|
||||||
|
|
||||||
|
await client.join_company(0, "comp_pw")
|
||||||
|
await client.receive_ServerNeedCompanyPassword(None, 1234, "srv_id")
|
||||||
|
assert len(client._protocol.sent) == 5
|
||||||
|
|
||||||
|
await client.receive_ServerWelcome(None, client_id=42)
|
||||||
|
assert client.client_id == 42
|
||||||
|
assert len(client._protocol.sent) == 6
|
||||||
|
|
||||||
|
await client.receive_ServerMapDone(None)
|
||||||
|
assert len(client._protocol.sent) == 7
|
||||||
|
assert client.joined.is_set()
|
||||||
|
|
||||||
|
await client.receive_ServerFrame(None, 100, 7)
|
||||||
|
assert len(client._protocol.sent) == 8
|
||||||
|
|
||||||
|
await client.quit()
|
||||||
|
assert len(client._protocol.sent) == 9
|
||||||
|
|
||||||
|
client._transport.close()
|
||||||
|
await client.quit()
|
||||||
128
tests/test_protocol.py
Normal file
128
tests/test_protocol.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
import pytest
|
||||||
|
import struct
|
||||||
|
import monocypher
|
||||||
|
from openttd.protocol import OpenTTDProtocol, PacketGameType
|
||||||
|
from openttd_protocol.wire.exceptions import SocketClosed
|
||||||
|
|
||||||
|
class MockTransport:
|
||||||
|
def __init__(self): self._closing = False
|
||||||
|
def is_closing(self): return self._closing
|
||||||
|
def close(self): self._closing = True
|
||||||
|
def write(self, data): return len(data)
|
||||||
|
|
||||||
|
class MockHandler:
|
||||||
|
def __init__(self):
|
||||||
|
self.encryption_enabled = False
|
||||||
|
self._recv_aead = None
|
||||||
|
self._send_aead = None
|
||||||
|
self._session_key_recv = b"A" * 32
|
||||||
|
self._session_key_send = b"B" * 32
|
||||||
|
self._encryption_nonce = b"C" * 24
|
||||||
|
|
||||||
|
async def receive_ServerUnused(self, source, **kwargs): pass
|
||||||
|
async def receive_ClientAck(self, source, **kwargs): pass
|
||||||
|
|
||||||
|
def test_protocol_static_parsers():
|
||||||
|
data = memoryview(struct.pack("<BI B", 1, 42, 0) + b"Hello\x00")
|
||||||
|
res = OpenTTDProtocol.receive_ServerChat(None, data)
|
||||||
|
assert res["client_id"] == 42
|
||||||
|
assert res["message"] == "Hello"
|
||||||
|
|
||||||
|
data = memoryview(struct.pack("<I", 123))
|
||||||
|
res = OpenTTDProtocol.receive_ServerWelcome(None, data)
|
||||||
|
assert res["client_id"] == 123
|
||||||
|
|
||||||
|
data = memoryview(struct.pack("<II", 1000, 2000) + b"\x00" * 12 + b"\x07")
|
||||||
|
res = OpenTTDProtocol.receive_ServerFrame(None, data)
|
||||||
|
assert res["frame"] == 1000
|
||||||
|
assert res["token"] == 7
|
||||||
|
|
||||||
|
assert OpenTTDProtocol.receive_ServerExternalChat(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerCommand(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerFull(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerBanned(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ClientIdentify(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ClientAck(None, struct.pack("<IB", 1, 2)) == {"frame": 1, "token": 2}
|
||||||
|
assert OpenTTDProtocol.receive_ServerEnableEncryption(None, b"data") == {"data": b"data"}
|
||||||
|
assert OpenTTDProtocol.receive_ServerCheckNewGRFs(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerUnused(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerMapDone(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerClientInfo(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerSync(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerClientJoined(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerMapBegin(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerMapSize(None, b"") == {"size": 0}
|
||||||
|
assert OpenTTDProtocol.receive_ServerMapData(None, b"data") == {"data": b"data"}
|
||||||
|
assert OpenTTDProtocol.receive_ServerConfigurationUpdate(None, b"") == {}
|
||||||
|
assert OpenTTDProtocol.receive_ServerAuthenticationRequest(None, struct.pack("<B", 1) + b"data") == {"auth_type": 1, "data": b"data"}
|
||||||
|
assert OpenTTDProtocol.receive_ServerError(None, b"\x08") == {"error_code": 8}
|
||||||
|
assert OpenTTDProtocol.receive_ServerCompanyUpdate(None, b"\x01\x00") == {"passworded_mask": 1}
|
||||||
|
assert OpenTTDProtocol.receive_ServerNeedCompanyPassword(None, memoryview(struct.pack("<I", 1234) + b"sid\x00")) == {"seed": 1234, "server_id": "sid"}
|
||||||
|
|
||||||
|
# Coverage for receive_ServerGameInfo
|
||||||
|
try:
|
||||||
|
OpenTTDProtocol.receive_ServerGameInfo(None, memoryview(b"\x00" * 200))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_protocol_exception_handling():
|
||||||
|
handler = MockHandler()
|
||||||
|
proto = OpenTTDProtocol(handler)
|
||||||
|
proto.transport = MockTransport()
|
||||||
|
|
||||||
|
# Passing data that causes struct.unpack to fail (too short for uint16)
|
||||||
|
ptype, kwargs = proto.receive_packet(None, memoryview(b"\x01"))
|
||||||
|
assert ptype == PacketGameType.ServerUnused
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_protocol_encryption_logic():
|
||||||
|
handler = MockHandler()
|
||||||
|
handler.encryption_enabled = True
|
||||||
|
proto = OpenTTDProtocol(handler)
|
||||||
|
proto.transport = MockTransport()
|
||||||
|
proto._can_write.set()
|
||||||
|
|
||||||
|
# Send test
|
||||||
|
payload = b"\x03\x00\x0e"
|
||||||
|
written_len = await proto.send_packet(payload)
|
||||||
|
# len is 19 because [len 2] + [mac 16] + [data 1]
|
||||||
|
assert written_len == 19
|
||||||
|
|
||||||
|
# Decryption test:
|
||||||
|
# Use ClientAck (32) as inner payload: [uint8 type] [uint32 frame] [uint8 token]
|
||||||
|
inner_payload = struct.pack("<B I B", 32, 1234, 7)
|
||||||
|
|
||||||
|
locker = monocypher.IncrementalAuthenticatedEncryption(handler._session_key_recv, handler._encryption_nonce)
|
||||||
|
mac, ciphertext = locker.lock(inner_payload)
|
||||||
|
|
||||||
|
handler._recv_aead = monocypher.IncrementalAuthenticatedEncryption(handler._session_key_recv, handler._encryption_nonce)
|
||||||
|
wire_data = memoryview(struct.pack("<H", len(mac) + len(ciphertext) + 2) + mac + ciphertext)
|
||||||
|
|
||||||
|
ptype, kwargs = proto.receive_packet(None, wire_data)
|
||||||
|
assert ptype == PacketGameType.ClientAck
|
||||||
|
assert kwargs["frame"] == 1234
|
||||||
|
assert kwargs["token"] == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_protocol_decryption_failure():
|
||||||
|
handler = MockHandler()
|
||||||
|
handler.encryption_enabled = True
|
||||||
|
proto = OpenTTDProtocol(handler)
|
||||||
|
proto.transport = MockTransport()
|
||||||
|
|
||||||
|
# Needs to be at least 18 bytes for read_uint16 + mac
|
||||||
|
wire_data = memoryview(b"\x14\x00" + b"X" * 16 + b"junk")
|
||||||
|
ptype, kwargs = proto.receive_packet(None, wire_data)
|
||||||
|
assert ptype == PacketGameType.ServerUnused
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_protocol_is_closing_failure():
|
||||||
|
handler = MockHandler()
|
||||||
|
proto = OpenTTDProtocol(handler)
|
||||||
|
proto.transport = MockTransport()
|
||||||
|
proto.transport.close()
|
||||||
|
proto._can_write.set()
|
||||||
|
|
||||||
|
with pytest.raises(SocketClosed):
|
||||||
|
await proto.send_packet(b"\x02\x00")
|
||||||
Reference in New Issue
Block a user