add multi threading support

Yuzu 2025-04-12 16:27:37 -05:00
parent f57bbdcf5d
commit acc539c342
3 changed files with 55 additions and 14 deletions

View File

@@ -68,10 +68,10 @@ Contributions are welcome! Please open an issue or pull request if you'd like to
 ## general roadmap
 | Task | Status |
 |-------------------------------------------------------------|--------|
-| stablish good sharding support with buckets | ✅ |
+| stablish good sharding support w buckets and worker threads | ✅ |
+| finish multi threading | ✅ |
 | finish the event coverage roadmap | ✅ |
 | proper error handling | ✅ |
-| use the priority queues for handling ratelimits (half done) | ❌ |
 | make the library scalable with a gateway proxy | ❌ |
 | get a cool logo | ❌ |

View File

@@ -37,6 +37,10 @@ connect_queue: ConnectQueue(Shard),
 shards: std.AutoArrayHashMap(usize, Shard),
 handler: GatewayDispatchEvent(*Shard),
+/// where we dispatch work for every thread, threads must be spawned upon shard creation
+/// make sure the address of workers is stable
+workers: std.Thread.Pool = undefined,
 /// configuration settings
 options: SessionOptions,
 log: Log,
@@ -64,6 +68,10 @@ pub const SessionOptions = struct {
 shard_end: usize = 1,
 /// The payload handlers for messages on the shard.
 resharding: ?struct { interval: u64, percentage: usize } = null,
+/// worker threads
+workers_per_shard: usize = 1,
+/// The shard lifespan in milliseconds. If a shard is not connected within this time, it will be closed.
+shard_lifespan: ?u64 = null,
 };

 pub fn init(allocator: mem.Allocator, settings: struct {
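
Both new knobs are plain `SessionOptions` fields with defaults, so multi-threading stays opt-in. A minimal caller-side sketch of setting them (it assumes the `SessionOptions` type shown above is in scope plus a `std` import; the shard counts and the 30-second lifespan are illustrative only):

```zig
const std = @import("std");

// illustrative values only; every other SessionOptions field keeps its default
const opts: SessionOptions = .{
    .total_shards = 4,
    .shard_start = 0,
    .shard_end = 4,
    // the worker pool is later sized workers_per_shard * total_shards
    .workers_per_shard = 2,
    // close a shard that has not connected within 30 seconds
    .shard_lifespan = 30 * std.time.ms_per_s,
};
```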
@@ -78,6 +86,7 @@ pub fn init(allocator: mem.Allocator, settings: struct {
 .allocator = allocator,
 .connect_queue = try ConnectQueue(Shard).init(allocator, concurrency, 5000),
 .shards = .init(allocator),
+.workers = undefined,
 .shard_details = ShardDetails{
 .token = settings.token,
 .intents = settings.intents,
@@ -92,6 +101,7 @@ pub fn init(allocator: mem.Allocator, settings: struct {
 .total_shards = settings.options.total_shards,
 .shard_start = settings.options.shard_start,
 .shard_end = settings.options.shard_end,
+.workers_per_shard = settings.options.workers_per_shard,
 },
 .log = settings.log,
 };
@@ -147,6 +157,13 @@ fn spawnBuckets(self: *Self) ![][]Shard {
 self.logif("{d} buckets created", .{bucket_count});

+// finally defihne threads
+try self.workers.init(.{
+.allocator = self.allocator,
+.n_jobs = self.options.workers_per_shard * self.options.total_shards,
+});

 return buckets;
 }
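
The pool that backs those workers is created here, in place, which is why the `workers` field is declared `undefined` up front and why its address has to stay stable: `std.Thread.Pool` is initialised through a pointer and its worker threads keep referring to it. A standalone sketch of that std.Thread.Pool pattern (recent Zig std assumed; this is not the library's code):

```zig
const std = @import("std");

fn job(i: usize) void {
    std.debug.print("job {d} running on a worker thread\n", .{i});
}

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    // the pool lives at a fixed address and is initialised in place
    var pool: std.Thread.Pool = undefined;
    try pool.init(.{
        .allocator = gpa.allocator(),
        .n_jobs = 4, // here: the equivalent of workers_per_shard * total_shards
    });
    // deinit joins the workers once the queued jobs have drained
    defer pool.deinit();

    for (0..8) |i| try pool.spawn(job, .{i});
}
```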
@@ -163,6 +180,7 @@ fn create(self: *Self, shard_id: usize) !Shard {
 },
 .run = self.handler,
 .log = self.log,
+.sharder_pool = &self.workers,
 });
 try self.shards.put(shard_id, shard);
View File

@@ -89,7 +89,8 @@ session_id: ?[]const u8,
 sequence: std.atomic.Value(isize) = .init(0),
 heart: Heart = .{ .heartbeatInterval = 45000, .lastBeat = 0 },
-///
+// we only need to know whether this shard is part of a thread pool, and if so, initialise it with a pointer thereon
+sharder_pool: ?*std.Thread.Pool = null,
 handler: GatewayDispatchEvent(*Self),
 packets: std.ArrayListUnmanaged(u8),
 inflator: zlib.Decompressor,
@@ -151,6 +152,7 @@ pub fn init(allocator: mem.Allocator, shard_id: usize, total_shards: usize, sett
 options: ShardOptions,
 run: GatewayDispatchEvent(*Self),
 log: Log,
+sharder_pool: ?*std.Thread.Pool = null,
 }) zlib.Error!Self {
 return Self{
 .options = ShardOptions{
@@ -181,6 +183,7 @@ pub fn init(allocator: mem.Allocator, shard_id: usize, total_shards: usize, sett
 settings.options.ratelimit_options.ratelimit_reset_interval,
 Self.calculateSafeRequests(settings.options.ratelimit_options),
 ),
+.sharder_pool = settings.sharder_pool,
 };
 }
@@ -250,24 +253,37 @@ fn readMessage(self: *Self, _: anytype) !void {
 /// The event name for this payload
 t: ?[]const u8 = null,
 };

+// must allocate to avoid race conditions
+const payload = try self.allocator.create(std.json.Value);

 const raw = try std.json.parseFromSlice(GatewayPayloadType, self.allocator, decompressed, .{
 .ignore_unknown_fields = true,
 .max_value_len = 0x1000,
 });
 errdefer raw.deinit();

-const payload = raw.value;
-switch (@as(Opcode, @enumFromInt(payload.op))) {
+// make sure to avoid race conditions
+// we free this payload eventually once our event executes
+payload.* = raw.value.d.?;

+switch (@as(Opcode, @enumFromInt(raw.value.op))) {
 .Dispatch => {
-// maybe use threads and call it instead from there
-if (payload.t) |name| {
-self.sequence.store(payload.s orelse 0, .monotonic);
-try self.handleEvent(name, payload.d.?);
+if (raw.value.t) |some_name| {
+self.sequence.store(raw.value.s orelse 0, .monotonic);

+const name = try self.allocator.alloc(u8, some_name.len);
+std.mem.copyForwards(u8, name, some_name);

+// run thread pool
+if (self.sharder_pool) |sharder_pool| {
+try sharder_pool.spawn(handleEventNoError, .{ self, name, payload });
+} else try self.handleEvent(name, payload.*);
 }
 },
 .Hello => {
 const HelloPayload = struct { heartbeat_interval: u64, _trace: [][]const u8 };
-const parsed = try std.json.parseFromValue(HelloPayload, self.allocator, payload.d.?, .{});
+const parsed = try std.json.parseFromValue(HelloPayload, self.allocator, payload.*, .{});
 defer parsed.deinit();

 const helloPayload = parsed.value;
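
The Dispatch branch is the core of the change: the payload is heap-allocated and the event name is copied before either is handed to the pool, because the pooled job may run after readMessage has moved on and its stack data is gone; when no pool was supplied the shard still handles the event inline. A reduced sketch of that hand-off pattern, with a hypothetical Worker type standing in for Shard (not the library's API; here the job also frees what it owns):

```zig
const std = @import("std");

const Worker = struct {
    allocator: std.mem.Allocator,
    pool: ?*std.Thread.Pool = null,

    fn dispatch(self: *Worker, event_name: []const u8, value: std.json.Value) !void {
        // copy the name: the slice handed to us may be freed once we return
        const name = try self.allocator.dupe(u8, event_name);
        errdefer self.allocator.free(name);

        // heap-allocate the payload so it outlives this stack frame
        const payload = try self.allocator.create(std.json.Value);
        errdefer self.allocator.destroy(payload);
        payload.* = value;

        if (self.pool) |pool| {
            // runs on a worker thread; handleNoError owns name and payload from here
            try pool.spawn(handleNoError, .{ self, name, payload });
        } else {
            try self.handle(name, payload.*);
            self.allocator.free(name);
            self.allocator.destroy(payload);
        }
    }

    fn handleNoError(self: *Worker, name: []const u8, payload: *std.json.Value) void {
        defer self.allocator.free(name);
        defer self.allocator.destroy(payload);
        self.handle(name, payload.*) catch |err|
            std.debug.print("event {s} failed: {s}\n", .{ name, @errorName(err) });
    }

    fn handle(self: *Worker, name: []const u8, payload: std.json.Value) !void {
        _ = self;
        _ = payload;
        std.debug.print("handling {s}\n", .{name});
    }
};
```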
@@ -289,7 +305,7 @@ fn readMessage(self: *Self, _: anytype) !void {
 var prng = std.Random.DefaultPrng.init(0);
 const jitter = std.Random.float(prng.random(), f64);
 self.heart.lastBeat = std.time.milliTimestamp();
-const heartbeat_writer = try std.Thread.spawn(.{}, Self.heartbeat, .{ self, jitter });
+const heartbeat_writer = try std.Thread.spawn(.{}, heartbeat, .{ self, jitter });
 heartbeat_writer.detach();
 },
 .HeartbeatACK => {
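
The heartbeat, by contrast, stays on its own detached OS thread rather than on the pool, so it keeps ticking no matter how busy the workers are; the only change in this hunk is spawning the free function heartbeat instead of Self.heartbeat. A tiny sketch of that detached-thread pattern (interval and names are illustrative; the real interval arrives in the Hello payload):

```zig
const std = @import("std");

// the loop runs on its own OS thread, independent of the worker pool
fn heartbeatLoop(interval_ms: u64) void {
    while (true) {
        std.time.sleep(interval_ms * std.time.ns_per_ms);
        std.debug.print("sending heartbeat\n", .{});
    }
}

pub fn main() !void {
    const interval_ms: u64 = 1000;
    const t = try std.Thread.spawn(.{}, heartbeatLoop, .{interval_ms});
    t.detach(); // fire-and-forget: never joined, dies with the process
    std.time.sleep(3 * std.time.ns_per_s); // keep main alive briefly for the demo
}
```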
@@ -312,7 +328,7 @@ fn readMessage(self: *Self, _: anytype) !void {
 session_id: []const u8,
 seq: ?isize,
 };
-const parsed = try std.json.parseFromValue(WithSequence, self.allocator, payload.d.?, .{});
+const parsed = try std.json.parseFromValue(WithSequence, self.allocator, payload.*, .{});
 defer parsed.deinit();

 const resume_payload = parsed.value;
@@ -374,8 +390,6 @@ pub const ConnectError =
 pub fn connect(self: *Self) ConnectError!void {
 //std.time.sleep(std.time.ms_per_s * 5);
 self.client = try Self._connect_ws(self.allocator, self.gatewayUrl());
-//const event_listener = try std.Thread.spawn(.{}, Self.readMessage, .{ &self, null });
-//event_listener.join();
 self.readMessage(null) catch unreachable;
 }
@@ -408,6 +422,15 @@ pub fn send(self: *Self, _: bool, data: anytype) SendError!void {
 try self.client.write(try string.toOwnedSlice());
 }

+pub fn handleEventNoError(self: *Self, name: []const u8, payload_ptr: *json.Value) void {
+// log to make sure this executes
+std.debug.print("Shard {d} dispatching {s}\n", .{self.id, name});

+self.handleEvent(name, payload_ptr.*) catch |err| {
+std.debug.print("Shard {d} error: {s}\n", .{self.id, @errorName(err)});
+};
+}

 pub fn handleEvent(self: *Self, name: []const u8, payload: json.Value) !void {
 if (mem.eql(u8, name, "READY")) if (self.handler.ready) |event| {
 const ready = try json.parseFromValue(Types.Ready, self.allocator, payload, .{
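
A note on why the wrapper returns void: std.Thread.Pool.spawn runs the supplied function and discards its result, so a pooled job has nowhere to propagate an error union; catching the error inside handleEventNoError and logging it is the usual way out. handleEvent itself keeps its error-returning signature for the inline, no-pool path.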