diff --git a/README.md b/README.md index 57acd97..8b8c098 100644 --- a/README.md +++ b/README.md @@ -68,10 +68,10 @@ Contributions are welcome! Please open an issue or pull request if you'd like to ## general roadmap | Task | Status | |-------------------------------------------------------------|--------| -| stablish good sharding support with buckets | ✅ | +| stablish good sharding support w buckets and worker threads | ✅ | +| finish multi threading | ✅ | | finish the event coverage roadmap | ✅ | | proper error handling | ✅ | -| use the priority queues for handling ratelimits (half done) | ❌ | | make the library scalable with a gateway proxy | ❌ | | get a cool logo | ❌ | diff --git a/src/core.zig b/src/core.zig index 4d897dc..79b9dbb 100644 --- a/src/core.zig +++ b/src/core.zig @@ -37,6 +37,10 @@ connect_queue: ConnectQueue(Shard), shards: std.AutoArrayHashMap(usize, Shard), handler: GatewayDispatchEvent(*Shard), +/// where we dispatch work for every thread, threads must be spawned upon shard creation +/// make sure the address of workers is stable +workers: std.Thread.Pool = undefined, + /// configuration settings options: SessionOptions, log: Log, @@ -64,6 +68,10 @@ pub const SessionOptions = struct { shard_end: usize = 1, /// The payload handlers for messages on the shard. resharding: ?struct { interval: u64, percentage: usize } = null, + /// worker threads + workers_per_shard: usize = 1, + /// The shard lifespan in milliseconds. If a shard is not connected within this time, it will be closed. + shard_lifespan: ?u64 = null, }; pub fn init(allocator: mem.Allocator, settings: struct { @@ -78,6 +86,7 @@ pub fn init(allocator: mem.Allocator, settings: struct { .allocator = allocator, .connect_queue = try ConnectQueue(Shard).init(allocator, concurrency, 5000), .shards = .init(allocator), + .workers = undefined, .shard_details = ShardDetails{ .token = settings.token, .intents = settings.intents, @@ -92,6 +101,7 @@ pub fn init(allocator: mem.Allocator, settings: struct { .total_shards = settings.options.total_shards, .shard_start = settings.options.shard_start, .shard_end = settings.options.shard_end, + .workers_per_shard = settings.options.workers_per_shard, }, .log = settings.log, }; @@ -147,6 +157,13 @@ fn spawnBuckets(self: *Self) ![][]Shard { self.logif("{d} buckets created", .{bucket_count}); + // finally defihne threads + + try self.workers.init(.{ + .allocator = self.allocator, + .n_jobs = self.options.workers_per_shard * self.options.total_shards, + }); + return buckets; } @@ -163,6 +180,7 @@ fn create(self: *Self, shard_id: usize) !Shard { }, .run = self.handler, .log = self.log, + .sharder_pool = &self.workers, }); try self.shards.put(shard_id, shard); diff --git a/src/shard.zig b/src/shard.zig index 4472c33..f5176af 100644 --- a/src/shard.zig +++ b/src/shard.zig @@ -89,7 +89,8 @@ session_id: ?[]const u8, sequence: std.atomic.Value(isize) = .init(0), heart: Heart = .{ .heartbeatInterval = 45000, .lastBeat = 0 }, -/// +// we only need to know whether this shard is part of a thread pool, and if so, initialise it with a pointer thereon +sharder_pool: ?*std.Thread.Pool = null, handler: GatewayDispatchEvent(*Self), packets: std.ArrayListUnmanaged(u8), inflator: zlib.Decompressor, @@ -151,6 +152,7 @@ pub fn init(allocator: mem.Allocator, shard_id: usize, total_shards: usize, sett options: ShardOptions, run: GatewayDispatchEvent(*Self), log: Log, + sharder_pool: ?*std.Thread.Pool = null, }) zlib.Error!Self { return Self{ .options = ShardOptions{ @@ -181,6 +183,7 @@ pub fn init(allocator: mem.Allocator, shard_id: usize, total_shards: usize, sett settings.options.ratelimit_options.ratelimit_reset_interval, Self.calculateSafeRequests(settings.options.ratelimit_options), ), + .sharder_pool = settings.sharder_pool, }; } @@ -250,24 +253,37 @@ fn readMessage(self: *Self, _: anytype) !void { /// The event name for this payload t: ?[]const u8 = null, }; + + // must allocate to avoid race conditions + const payload = try self.allocator.create(std.json.Value); + const raw = try std.json.parseFromSlice(GatewayPayloadType, self.allocator, decompressed, .{ .ignore_unknown_fields = true, .max_value_len = 0x1000, }); errdefer raw.deinit(); - const payload = raw.value; - switch (@as(Opcode, @enumFromInt(payload.op))) { + // make sure to avoid race conditions + // we free this payload eventually once our event executes + payload.* = raw.value.d.?; + + switch (@as(Opcode, @enumFromInt(raw.value.op))) { .Dispatch => { - // maybe use threads and call it instead from there - if (payload.t) |name| { - self.sequence.store(payload.s orelse 0, .monotonic); - try self.handleEvent(name, payload.d.?); + if (raw.value.t) |some_name| { + self.sequence.store(raw.value.s orelse 0, .monotonic); + + const name = try self.allocator.alloc(u8, some_name.len); + std.mem.copyForwards(u8, name, some_name); + + // run thread pool + if (self.sharder_pool) |sharder_pool| { + try sharder_pool.spawn(handleEventNoError, .{ self, name, payload }); + } else try self.handleEvent(name, payload.*); } }, .Hello => { const HelloPayload = struct { heartbeat_interval: u64, _trace: [][]const u8 }; - const parsed = try std.json.parseFromValue(HelloPayload, self.allocator, payload.d.?, .{}); + const parsed = try std.json.parseFromValue(HelloPayload, self.allocator, payload.*, .{}); defer parsed.deinit(); const helloPayload = parsed.value; @@ -289,7 +305,7 @@ fn readMessage(self: *Self, _: anytype) !void { var prng = std.Random.DefaultPrng.init(0); const jitter = std.Random.float(prng.random(), f64); self.heart.lastBeat = std.time.milliTimestamp(); - const heartbeat_writer = try std.Thread.spawn(.{}, Self.heartbeat, .{ self, jitter }); + const heartbeat_writer = try std.Thread.spawn(.{}, heartbeat, .{ self, jitter }); heartbeat_writer.detach(); }, .HeartbeatACK => { @@ -312,7 +328,7 @@ fn readMessage(self: *Self, _: anytype) !void { session_id: []const u8, seq: ?isize, }; - const parsed = try std.json.parseFromValue(WithSequence, self.allocator, payload.d.?, .{}); + const parsed = try std.json.parseFromValue(WithSequence, self.allocator, payload.*, .{}); defer parsed.deinit(); const resume_payload = parsed.value; @@ -374,8 +390,6 @@ pub const ConnectError = pub fn connect(self: *Self) ConnectError!void { //std.time.sleep(std.time.ms_per_s * 5); self.client = try Self._connect_ws(self.allocator, self.gatewayUrl()); - //const event_listener = try std.Thread.spawn(.{}, Self.readMessage, .{ &self, null }); - //event_listener.join(); self.readMessage(null) catch unreachable; } @@ -408,6 +422,15 @@ pub fn send(self: *Self, _: bool, data: anytype) SendError!void { try self.client.write(try string.toOwnedSlice()); } +pub fn handleEventNoError(self: *Self, name: []const u8, payload_ptr: *json.Value) void { + // log to make sure this executes + std.debug.print("Shard {d} dispatching {s}\n", .{self.id, name}); + + self.handleEvent(name, payload_ptr.*) catch |err| { + std.debug.print("Shard {d} error: {s}\n", .{self.id, @errorName(err)}); + }; +} + pub fn handleEvent(self: *Self, name: []const u8, payload: json.Value) !void { if (mem.eql(u8, name, "READY")) if (self.handler.ready) |event| { const ready = try json.parseFromValue(Types.Ready, self.allocator, payload, .{