diff --git a/build.zig b/build.zig index 3d76f9a..b759fe0 100644 --- a/build.zig +++ b/build.zig @@ -20,11 +20,6 @@ pub fn build(b: *std.Build) void { .optimize = optimize, }); - const zig_tls = b.dependency("zig-tls", .{ - .target = target, - .optimize = optimize, - }); - const zlib = b.dependency("zlib", .{}); const zmpl = b.dependency("zmpl", .{ @@ -48,14 +43,12 @@ pub fn build(b: *std.Build) void { // now install your own executable after it's built correctly dzig.addImport("ws", websocket.module("websocket")); - dzig.addImport("tls12", zig_tls.module("zig-tls12")); dzig.addImport("zlib", zlib.module("zlib")); dzig.addImport("zmpl", zmpl.module("zmpl")); dzig.addImport("deque", deque.module("zig-deque")); marin.root_module.addImport("discord.zig", dzig); marin.root_module.addImport("ws", websocket.module("websocket")); - marin.root_module.addImport("tls12", zig_tls.module("zig-tls12")); marin.root_module.addImport("zlib", zlib.module("zlib")); marin.root_module.addImport("zmpl", zmpl.module("zmpl")); marin.root_module.addImport("deque", deque.module("zig-deque")); diff --git a/build.zig.zon b/build.zig.zon index 7b6294e..bf239e6 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -27,21 +27,17 @@ .url = "https://github.com/magurotuna/zig-deque/archive/refs/heads/main.zip", .hash = "1220d1bedf7d5cfc7475842b3d4e8f03f1308be2e724a036677cceb5c4db13c3da3d", }, - .@"zig-tls" = .{ - .url = "https://github.com/yuzudev/zig-tls12/archive/refs/heads/main.zip", - .hash = "122079aad9eebc5945e207715c50722bbeca34e7b7a5876b3a322fad6468b28f2ed3", - }, - .websocket = .{ - .url = "https://github.com/yuzudev/websocket.zig/archive/refs/heads/master.zip", - .hash = "122062d9dda015ba25780d00697e0e2c1cbc3ffa5b5eae55c9ea28b39b6d99ef1d85", - }, .zlib = .{ .url = "https://github.com/yuzudev/zig-zlib/archive/refs/heads/main.zip", .hash = "1220cd041e8d04f1da9d6f46d0438f4e6809b113ba3454fffdaae96b59d2b35a6b2b", }, .zmpl = .{ .url = "https://github.com/jetzig-framework/zmpl/archive/refs/heads/main.zip", - .hash = "12209a9d6f652ce712448da1bd41517cceef8eebf86884c1bb29c02fc3a933713afd", + .hash = "1220798a4647e3b0766aad653830a2601e11c567ba6bfe83e526eb91d04a6c45f7d8", + }, + .websocket = .{ + .url = "https://github.com/yuzudev/websocket.zig/archive/refs/heads/master.zip", + .hash = "12207c03624f9f5a1c444bde3d484a9b1e927a902067fded98364b714de412d318e0", }, }, .paths = .{ diff --git a/src/discord.zig b/src/discord.zig index 9b5e1a6..c95a20b 100644 --- a/src/discord.zig +++ b/src/discord.zig @@ -1,4 +1,80 @@ pub const Discord = @import("types.zig"); +const Intents = Discord.Intents; + pub const Shard = @import("shard.zig"); pub const Internal = @import("internal.zig"); -pub const Session = @import("session.zig"); +const Log = Internal.Log; +const GatewayDispatchEvent = Internal.GatewayDispatchEvent; + +pub const Sharder = @import("sharder.zig"); +const SessionOptions = Sharder.SessionOptions; + +pub const Shared = @import("shared.zig"); +const GatewayBotInfo = Shared.GatewayBotInfo; + +pub const FetchReq = @import("http.zig").FetchReq; + +const std = @import("std"); +const mem = std.mem; +const http = std.http; +const json = std.json; + +pub const Client = struct { + allocator: mem.Allocator, + sharder: Sharder, + + pub fn init(allocator: mem.Allocator) Client { + return .{ + .allocator = allocator, + .sharder = undefined, + }; + } + + pub fn deinit(self: *Client) void { + self.sharder.deinit(); + } + + pub fn start(self: *Client, settings: struct { + token: []const u8, + intents: Intents, + options: struct { + spawn_shard_delay: u64 = 5300, + total_shards: usize = 1, + shard_start: usize = 0, + shard_end: usize = 1, + }, + run: GatewayDispatchEvent(*Shard), + log: Log, + }) !void { + var req = FetchReq.init(self.allocator, settings.token); + defer req.deinit(); + + const res = try req.makeRequest(.GET, "/gateway/bot", null); + const body = try req.body.toOwnedSlice(); + defer self.allocator.free(body); + + // check status idk + if (res.status != http.Status.ok) { + @panic("we are cooked\n"); + } + + const parsed = try json.parseFromSlice(GatewayBotInfo, self.allocator, body, .{}); + defer parsed.deinit(); + + self.sharder = try Sharder.init(self.allocator, .{ + .token = settings.token, + .intents = settings.intents, + .run = settings.run, + .options = SessionOptions{ + .info = parsed.value, + .shard_start = settings.options.shard_start, + .shard_end = @intCast(parsed.value.shards), + .total_shards = @intCast(parsed.value.shards), + .spawn_shard_delay = settings.options.spawn_shard_delay, + }, + .log = settings.log, + }); + + try self.sharder.spawnShards(); + } +}; diff --git a/src/http.zig b/src/http.zig new file mode 100644 index 0000000..cb0d289 --- /dev/null +++ b/src/http.zig @@ -0,0 +1,49 @@ +const std = @import("std"); +const mem = std.mem; +const http = std.http; + +pub const BASE_URL = "https://discord.com/api/v10"; + +pub const FetchReq = struct { + allocator: mem.Allocator, + token: []const u8, + client: http.Client, + body: std.ArrayList(u8), + + pub fn init(allocator: mem.Allocator, token: []const u8) FetchReq { + const client = http.Client{ .allocator = allocator }; + return FetchReq{ + .allocator = allocator, + .client = client, + .body = std.ArrayList(u8).init(allocator), + .token = token, + }; + } + + pub fn deinit(self: *FetchReq) void { + self.client.deinit(); + self.body.deinit(); + } + + pub fn makeRequest(self: *FetchReq, method: http.Method, path: []const u8, to_post: ?[]const u8) !http.Client.FetchResult { + var fetch_options = http.Client.FetchOptions{ + .location = http.Client.FetchOptions.Location{ + .url = try std.fmt.allocPrint(self.allocator, "{s}{s}", .{ BASE_URL, path }), + }, + .extra_headers = &[_]http.Header{ + http.Header{ .name = "Accept", .value = "application/json" }, + http.Header{ .name = "Content-Type", .value = "application/json" }, + http.Header{ .name = "Authorization", .value = self.token }, + }, + .method = method, + .response_storage = .{ .dynamic = &self.body }, + }; + + if (to_post != null) { + fetch_options.payload = to_post; + } + + const res = try self.client.fetch(fetch_options); + return res; + } +}; diff --git a/src/internal.zig b/src/internal.zig index e248e20..61917ae 100644 --- a/src/internal.zig +++ b/src/internal.zig @@ -1,89 +1,105 @@ const std = @import("std"); const mem = std.mem; -const Deque = @import("deque"); +const Deque = @import("deque").Deque; const Discord = @import("types.zig"); +const builtin = @import("builtin"); +const IdentifyProperties = @import("shared.zig").IdentifyProperties; pub const debug = std.log.scoped(.@"discord.zig"); pub const Log = union(enum) { yes, no }; +pub const default_identify_properties = IdentifyProperties{ + .os = @tagName(builtin.os.tag), + .browser = "discord.zig", + .device = "discord.zig", +}; + /// inspired from: /// https://github.com/tiramisulabs/seyfert/blob/main/src/websocket/structures/timeout.ts -pub const ConnectQueue = struct { - dequeue: Deque(*const fn () void), - allocator: mem.Allocator, - remaining: usize, - interval_time: u64 = 5000, - running: bool, - concurrency: usize = 1, - - pub fn init(allocator: mem.Allocator, concurrency: usize, interval_time: u64) !ConnectQueue { - return .{ - .allocator = allocator, - .dequeue = try Deque(*const fn () void).init(allocator), - .remaining = concurrency, - .interval_time = interval_time, - .concurrency = concurrency, +pub fn ConnectQueue(comptime T: type) type { + return struct { + pub const RequestWithShard = struct { + callback: *const fn (self: *RequestWithShard) anyerror!void, + shard: T, }; - } - pub fn deinit(self: *ConnectQueue) void { - self.dequeue.deinit(); - } + dequeue: Deque(RequestWithShard), + allocator: mem.Allocator, + remaining: usize, + interval_time: u64 = 5000, + running: bool = false, + concurrency: usize = 1, - pub fn push(self: *ConnectQueue, callback: *const fn () void) !void { - if (self.remaining == 0) { - return self.dequeue.pushBack(callback); - } - self.remaining -= 1; - - if (!self.running) { - self.startInterval(); - self.running = true; + pub fn init(allocator: mem.Allocator, concurrency: usize, interval_time: u64) !ConnectQueue(T) { + return .{ + .allocator = allocator, + .dequeue = try Deque(RequestWithShard).init(allocator), + .remaining = concurrency, + .interval_time = interval_time, + .concurrency = concurrency, + }; } - if (self.dequeue.items.len < self.concurrency) { - @call(.auto, callback, .{}); - return; + pub fn deinit(self: *ConnectQueue(T)) void { + self.dequeue.deinit(); } - return self.dequeue.pushBack(callback); - } + pub fn push(self: *ConnectQueue(T), req: RequestWithShard) !void { + if (self.remaining == 0) { + return self.dequeue.pushBack(req); + } + self.remaining -= 1; - fn startInterval(self: *ConnectQueue) void { - while (self.running) { - std.Thread.sleep(std.time.ns_per_ms * (self.interval_time / self.concurrency)); - const callback: ?*const fn () void = self.dequeue.popFront(); + if (!self.running) { + try self.startInterval(); + self.running = true; + } - while (self.dequeue.items.len == 0 and callback == null) {} - - if (callback) |cb| { - @call(.auto, cb, .{}); + if (self.dequeue.len() < self.concurrency) { + // perhaps store this? + const ptr = try self.allocator.create(RequestWithShard); + ptr.* = req; + try @call(.auto, req.callback, .{ptr}); return; } - if (self.remaining < self.concurrency) { - self.remaining += 1; - } + return self.dequeue.pushBack(req); + } - if (self.dequeue.len() == 0) { - self.running = false; + fn startInterval(self: *ConnectQueue(T)) !void { + while (self.running) { + std.Thread.sleep(std.time.ns_per_ms * (self.interval_time / self.concurrency)); + const req: ?RequestWithShard = self.dequeue.popFront(); + + while (self.dequeue.len() == 0 and req == null) {} + + if (req) |r| { + const ptr = try self.allocator.create(RequestWithShard); + ptr.* = r; + try @call(.auto, r.callback, .{ptr}); + return; + } + + if (self.remaining < self.concurrency) { + self.remaining += 1; + } + + if (self.dequeue.len() == 0) { + self.running = false; + } } } - } -}; - -fn lessthan(_: void, a: RequestWithPrio, b: RequestWithPrio) void { - return std.math.order(a, b); + }; } pub const Bucket = struct { /// The queue of requests to acquire an available request. Mapped by (shardId, RequestWithPrio) - queue: std.PriorityQueue(RequestWithPrio, void, lessthan), + queue: std.PriorityQueue(RequestWithPrio, void, Bucket.lessthan), limit: usize, - refillInterval: u64, - refillAmount: usize, + refill_interval: u64, + refill_amount: usize, /// The amount of requests that have been used up already. used: usize = 0, @@ -92,21 +108,82 @@ pub const Bucket = struct { processing: bool = false, /// Whether the timeout should be killed because there is already one running - shouldStop: bool = false, + should_stop: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), /// The timestamp in milliseconds when the next refill is scheduled. - refillsAt: ?u64, + refills_at: ?u64 = null, - /// comes in handy - m: std.Thread.Mutex = .{}, - c: std.Thread.Condition = .{}, + pub const RequestWithPrio = struct { + callback: *const fn () void, + priority: u32 = 1, + }; - fn timeout(self: *Bucket) void { - _ = self; + fn lessthan(_: void, a: RequestWithPrio, b: RequestWithPrio) std.math.Order { + return std.math.order(a.priority, b.priority); } - pub fn processQueue() !void {} - pub fn refill() void {} + pub fn init(allocator: mem.Allocator, limit: usize, refill_interval: u64, refill_amount: usize) Bucket { + return .{ + .queue = std.PriorityQueue(RequestWithPrio, void, lessthan).init(allocator, {}), + .limit = limit, + .refill_interval = refill_interval, + .refill_amount = refill_amount, + }; + } + + fn remaining(self: *Bucket) usize { + if (self.limit < self.used) { + return 0; + } else { + return self.limit - self.used; + } + } + + pub fn refill(self: *Bucket) std.Thread.SpawnError!void { + // Lower the used amount by the refill amount + self.used = if (self.refill_amount > self.used) 0 else self.used - self.refill_amount; + + // Reset the refills_at timestamp since it just got refilled + self.refills_at = null; + + if (self.used > 0) { + if (self.should_stop.load(.monotonic) == true) { + self.should_stop.store(false, .monotonic); + } + const thread = try std.Thread.spawn(.{}, Bucket.timeout, .{self}); + thread.detach; + self.refills_at = std.time.milliTimestamp() + self.refill_interval; + } + } + + fn timeout(self: *Bucket) void { + while (!self.should_stop.load(.monotonic)) { + self.refill(); + std.time.sleep(std.time.ns_per_ms * self.refill_interval); + } + } + + pub fn processQueue(self: *Bucket) std.Thread.SpawnError!void { + if (self.processing) return; + + while (self.queue.remove()) |first_element| { + if (self.remaining() != 0) { + first_element.callback(); + self.used += 1; + + if (!self.should_stop.load(.monotonic)) { + const thread = try std.Thread.spawn(.{}, Bucket.timeout, .{self}); + thread.detach; + self.refills_at = std.time.milliTimestamp() + self.refill_interval; + } + } else if (self.refills_at) |ra| { + const now = std.time.milliTimestamp(); + if (ra > now) std.time.sleep(std.time.ns_per_ms * (ra - now)); + } + } + + self.processing = false; + } pub fn acquire(self: *Bucket, rq: RequestWithPrio) !void { try self.queue.add(rq); @@ -114,11 +191,6 @@ pub const Bucket = struct { } }; -pub const RequestWithPrio = struct { - callback: *const fn () void, - priority: u32, -}; - pub fn GatewayDispatchEvent(comptime T: type) type { return struct { // TODO: implement // application_command_permissions_update: null = null, diff --git a/src/parser.zig b/src/parser.zig index 250a468..0993144 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -70,8 +70,6 @@ pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) !Discord.M try mentions.append(try parseUser(allocator, &m.object)); } - std.debug.print("parsing mentions done\n", .{}); - // parse member const member = if (obj.getT(.object, "member")) |m| try parseMember(allocator, m) else null; diff --git a/src/session.zig b/src/session.zig deleted file mode 100644 index e238719..0000000 --- a/src/session.zig +++ /dev/null @@ -1,80 +0,0 @@ -const Intents = @import("types.zig").Intents; -const GatewayBotInfo = @import("shared.zig").GatewayBotInfo; -const Shared = @import("shared.zig"); -const IdentifyProperties = Shared.IdentifyProperties; -const Internal = @import("internal.zig"); -const ConnectQueue = Internal.ConnectQueue; -const GatewayDispatchEvent = Internal.GatewayDispatchEvent; -const Shard = @import("shard.zig"); -const std = @import("std"); -const mem = std.mem; -const debug = Internal.debug; - -const Self = @This(); - -token: []const u8, -intents: Intents, -allocator: mem.Allocator, -connectQueue: ConnectQueue, -shards: std.AutoArrayHashMap(usize, Shard), -run: GatewayDispatchEvent(*Self), - -/// spawn buckets in order -/// https://discord.com/developers/docs/events/gateway#sharding-max-concurrency -fn spawnBuckers(self: *Self) !void { - _ = self; -} - -/// creates a shard and stores it -fn create(self: *Self, shard_id: usize) !Shard { - const shard = try self.shards.getOrPutValue(shard_id, try Shard.login(self.allocator, .{ - .token = self.token, - .intents = self.intents, - .run = self.run, - .log = self.log, - })); - - return shard; -} - -pub const ShardDetails = struct { - /// Bot token which is used to connect to Discord */ - token: []const u8, - /// - /// The URL of the gateway which should be connected to. - /// - url: []const u8 = "wss://gateway.discord.gg", - /// - /// The gateway version which should be used. - /// @default 10 - /// - version: ?usize = 10, - /// - /// The calculated intent value of the events which the shard should receive. - /// - intents: Intents, - /// - /// Identify properties to use - /// - properties: ?IdentifyProperties, -}; - -pub const SessionOptions = struct { - /// Important data which is used by the manager to connect shards to the gateway. */ - info: GatewayBotInfo, - /// Delay in milliseconds to wait before spawning next shard. OPTIMAL IS ABOVE 5100. YOU DON'T WANT TO HIT THE RATE LIMIT!!! - spawnShardDelay: ?u64 = 5300, - /// Total amount of shards your bot uses. Useful for zero-downtime updates or resharding. - totalShards: ?usize = 1, - shardStart: ?usize, - shardEnd: ?usize, - /// - /// The payload handlers for messages on the shard. - /// TODO: - /// handlePayload: (shardId: number, packet: GatewayDispatchPayload): unknown; - /// - resharding: ?struct { - interval: u64, - percentage: usize, - }, -}; diff --git a/src/shard.zig b/src/shard.zig index 9a53eb4..27a1016 100644 --- a/src/shard.zig +++ b/src/shard.zig @@ -1,6 +1,5 @@ const ws = @import("ws"); const builtin = @import("builtin"); -const HttpClient = @import("tls12").HttpClient; const std = @import("std"); const net = std.net; @@ -27,68 +26,21 @@ const IdentifyProperties = Shared.IdentifyProperties; const GatewayInfo = Shared.GatewayInfo; const GatewayBotInfo = Shared.GatewayBotInfo; const GatewaySessionStartLimit = Shared.GatewaySessionStartLimit; +const ShardDetails = Shared.ShardDetails; const Internal = @import("internal.zig"); const Log = Internal.Log; const GatewayDispatchEvent = Internal.GatewayDispatchEvent; +const Bucket = Internal.Bucket; +const default_identify_properties = Internal.default_identify_properties; + +const FetchRequest = @import("http.zig").FetchReq; const ShardSocketCloseCodes = enum(u16) { Shutdown = 3000, ZombiedConnection = 3010, }; -const BASE_URL = "https://discord.com/api/v10"; - -pub const FetchReq = struct { - allocator: mem.Allocator, - token: []const u8, - client: HttpClient, - body: std.ArrayList(u8), - - pub fn init(allocator: mem.Allocator, token: []const u8) FetchReq { - const client = HttpClient{ .allocator = allocator }; - return FetchReq{ - .allocator = allocator, - .client = client, - .body = std.ArrayList(u8).init(allocator), - .token = token, - }; - } - - pub fn deinit(self: *FetchReq) void { - self.client.deinit(); - self.body.deinit(); - } - - pub fn makeRequest(self: *FetchReq, method: http.Method, path: []const u8, to_post: ?[]const u8) !HttpClient.FetchResult { - var fetch_options = HttpClient.FetchOptions{ - .location = HttpClient.FetchOptions.Location{ - .url = try std.fmt.allocPrint(self.allocator, "{s}{s}", .{ BASE_URL, path }), - }, - .extra_headers = &[_]http.Header{ - http.Header{ .name = "Accept", .value = "application/json" }, - http.Header{ .name = "Content-Type", .value = "application/json" }, - http.Header{ .name = "Authorization", .value = self.token }, - }, - .method = method, - .response_storage = .{ .dynamic = &self.body }, - }; - - if (to_post != null) { - fetch_options.payload = to_post; - } - - const res = try self.client.fetch(fetch_options); - return res; - } -}; - -const _default_properties = IdentifyProperties{ - .os = @tagName(builtin.os.tag), - .browser = "discord.zig", - .device = "discord.zig", -}; - const Heart = struct { heartbeatInterval: u64, ack: bool, @@ -96,16 +48,27 @@ const Heart = struct { lastBeat: i64, }; +const RatelimitOptions = struct { + max_requests_per_ratelimit_tick: ?usize = 120, + ratelimit_reset_interval: u64 = 60000, +}; + +pub const ShardOptions = struct { + ratelimit_options: RatelimitOptions = .{}, +}; + +id: usize, + client: ws.Client, -token: []const u8, -intents: Intents, +details: ShardDetails, //heart: Heart = allocator: mem.Allocator, resume_gateway_url: ?[]const u8 = null, info: GatewayBotInfo, +bucket: Bucket, +ratelimit_options: RatelimitOptions, -properties: IdentifyProperties = _default_properties, session_id: ?[]const u8, sequence: std.atomic.Value(isize) = std.atomic.Value(isize).init(0), heart: Heart = .{ .heartbeatInterval = 45000, .ack = false, .lastBeat = 0 }, @@ -135,86 +98,87 @@ pub fn resumable(self: *Self) bool { pub fn resume_(self: *Self) !void { const data = .{ .op = @intFromEnum(Opcode.Resume), .d = .{ - .token = self.token, + .token = self.details.token, .session_id = self.session_id, .seq = self.sequence.load(.monotonic), } }; - try self.send(data); + try self.send(false, data); } inline fn gatewayUrl(self: ?*Self) []const u8 { return if (self) |s| (s.resume_gateway_url orelse s.info.url)["wss://".len..] else "gateway.discord.gg"; } -// identifies in order to connect to Discord and get the online status, this shall be done on hello perhaps +/// identifies in order to connect to Discord and get the online status, this shall be done on hello perhaps fn identify(self: *Self, properties: ?IdentifyProperties) !void { - self.logif("intents: {d}", .{self.intents.toRaw()}); + self.logif("intents: {d}", .{self.details.intents.toRaw()}); - if (self.intents.toRaw() != 0) { + if (self.details.intents.toRaw() != 0) { const data = .{ .op = @intFromEnum(Opcode.Identify), .d = .{ - .intents = self.intents.toRaw(), - .properties = properties orelse Self._default_properties, - .token = self.token, + .intents = self.details.intents.toRaw(), + .properties = properties orelse default_identify_properties, + .token = self.details.token, }, }; - try self.send(data); + try self.send(false, data); } else { const data = .{ .op = @intFromEnum(Opcode.Identify), .d = .{ .capabilities = 30717, - .properties = properties orelse Self._default_properties, - .token = self.token, + .properties = properties orelse default_identify_properties, + .token = self.details.token, }, }; - try self.send(data); + try self.send(false, data); } } -// asks /gateway/bot initializes both the ws client and the http client -pub fn login(allocator: mem.Allocator, args: struct { +pub fn init(allocator: mem.Allocator, shard_id: usize, settings: struct { token: []const u8, intents: Intents, + options: ShardOptions, run: GatewayDispatchEvent(*Self), log: Log, -}) !Self { - var req = FetchReq.init(allocator, args.token); - defer req.deinit(); - - const res = try req.makeRequest(.GET, "/gateway/bot", null); - const body = try req.body.toOwnedSlice(); - defer allocator.free(body); - - // check status idk - if (res.status != http.Status.ok) { - @panic("we are cooked\n"); - } - - const parsed = try json.parseFromSlice(GatewayBotInfo, allocator, body, .{}); - defer parsed.deinit(); - const url = parsed.value.url["wss://".len..]; - - var self: Self = .{ +}) zlib.Error!Self { + return Self{ + .info = .{ .url = "wss://gateway.discord.gg", .shards = 1, .session_start_limit = null }, + .id = shard_id, .allocator = allocator, - .token = args.token, - .intents = args.intents, + .details = ShardDetails{ + .token = settings.token, + .intents = settings.intents, + }, + .client = undefined, // maybe there is a better way to do this - .client = try Self._connect_ws(allocator, url), .session_id = undefined, - .info = parsed.value, - .handler = args.run, - .log = args.log, + .handler = settings.run, + .log = settings.log, .packets = std.ArrayList(u8).init(allocator), .inflator = try zlib.Decompressor.init(allocator, .{ .header = .zlib_or_gzip }), + .bucket = Bucket.init( + allocator, + Self.calculateSafeRequests(settings.options.ratelimit_options), + settings.options.ratelimit_options.ratelimit_reset_interval, + Self.calculateSafeRequests(settings.options.ratelimit_options), + ), + .ratelimit_options = settings.options.ratelimit_options, }; +} - const event_listener = try std.Thread.spawn(.{}, Self.readMessage, .{ &self, null }); - event_listener.join(); +inline fn calculateSafeRequests(options: RatelimitOptions) usize { + const safe_requests = + @as(f64, @floatFromInt(options.max_requests_per_ratelimit_tick orelse 120)) - + @ceil(@as(f64, @floatFromInt(options.ratelimit_reset_interval)) / 30000.0) * 2; - return self; + if (safe_requests < 0) { + return 0; + } + + return @intFromFloat(safe_requests); } inline fn _connect_ws(allocator: mem.Allocator, url: []const u8) !ws.Client { @@ -241,8 +205,8 @@ pub fn deinit(self: *Self) void { self.logif("killing the whole bot", .{}); } -// listens for messages -pub fn readMessage(self: *Self, _: anytype) !void { +/// listens for messages +fn readMessage(self: *Self, _: anytype) !void { try self.client.readTimeout(0); while (true) { @@ -301,7 +265,7 @@ pub fn readMessage(self: *Self, _: anytype) !void { try self.resume_(); return; } else { - try self.identify(self.properties); + try self.identify(self.details.properties); } var prng = std.Random.DefaultPrng.init(0); @@ -321,7 +285,7 @@ pub fn readMessage(self: *Self, _: anytype) !void { self.logif("sending requested heartbeat", .{}); self.ws_mutex.lock(); defer self.ws_mutex.unlock(); - try self.send(.{ .op = @intFromEnum(Opcode.Heartbeat), .d = self.sequence.load(.monotonic) }); + try self.send(false, .{ .op = @intFromEnum(Opcode.Heartbeat), .d = self.sequence.load(.monotonic) }); }, Opcode.Reconnect => { self.logif("reconnecting", .{}); @@ -373,7 +337,7 @@ pub fn heartbeat(self: *Self, initial_jitter: f64) !void { const seq = self.sequence.load(.monotonic); self.logif("sending unrequested heartbeat", .{}); self.ws_mutex.lock(); - try self.send(.{ .op = @intFromEnum(Opcode.Heartbeat), .d = seq }); + try self.send(false, .{ .op = @intFromEnum(Opcode.Heartbeat), .d = seq }); self.ws_mutex.unlock(); if ((std.time.milliTimestamp() - last) > (5000 * self.heart.heartbeatInterval)) { @@ -390,16 +354,25 @@ pub inline fn reconnect(self: *Self) !void { try self.connect(); } -pub fn connect(self: *Self) !void { +pub const ConnectError = + std.net.TcpConnectToAddressError || std.crypto.tls.Client.InitError(std.net.Stream) || std.net.Stream.ReadError || std.net.IPParseError || std.crypto.Certificate.Bundle.RescanError || std.net.TcpConnectToHostError || std.fmt.BufPrintError || mem.Allocator.Error; + +pub fn connect(self: *Self) ConnectError!void { //std.time.sleep(std.time.ms_per_s * 5); self.client = try Self._connect_ws(self.allocator, self.gatewayUrl()); + //const event_listener = try std.Thread.spawn(.{}, Self.readMessage, .{ &self, null }); + //event_listener.join(); + + self.readMessage(null) catch unreachable; } pub fn disconnect(self: *Self) !void { try self.close(ShardSocketCloseCodes.Shutdown, "Shard down request"); } -pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) !void { +pub const CloseError = mem.Allocator.Error || error{ReasonTooLong}; + +pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) CloseError!void { self.logif("cooked closing ws conn...\n", .{}); // Implement reconnection logic here try self.client.close(.{ @@ -408,7 +381,9 @@ pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) !void }); } -pub fn send(self: *Self, data: anytype) !void { +pub const SendError = net.Stream.WriteError || std.ArrayList(u8).Writer.Error; + +pub fn send(self: *Self, _: bool, data: anytype) SendError!void { var buf: [1000]u8 = undefined; var fba = std.heap.FixedBufferAllocator.init(&buf); var string = std.ArrayList(u8).init(fba.allocator()); @@ -546,6 +521,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void { } } +/// highly experimental, do not use pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []const u8, password: []const u8, run: GatewayDispatchEvent(*Self), log: Log }) !Self { const AUTH_LOGIN = "https://discord.com/api/v9/auth/login"; const WS_CONNECT = "gateway.discord.gg"; @@ -555,8 +531,8 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons const AuthLoginResponse = struct { user_id: []const u8, token: []const u8, user_settings: struct { locale: []const u8, theme: []const u8 } }; - var fetch_options = HttpClient.FetchOptions{ - .location = HttpClient.FetchOptions.Location{ + var fetch_options = http.Client.FetchOptions{ + .location = http.Client.FetchOptions.Location{ .url = AUTH_LOGIN, }, .extra_headers = &[_]http.Header{ @@ -572,7 +548,7 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons .password = settings.password, }, .{}); - var client = HttpClient{ .allocator = allocator }; + var client = http.Client{ .allocator = allocator }; defer client.deinit(); _ = try client.fetch(fetch_options); @@ -581,8 +557,10 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons return .{ .allocator = allocator, - .token = response.token, - .intents = @bitCast(@as(u28, @intCast(0))), + .details = ShardDetails{ + .token = response.token, + .intents = @bitCast(@as(u28, @intCast(0))), + }, // maybe there is a better way to do this .client = try Self._connect_ws(allocator, WS_CONNECT), .session_id = undefined, diff --git a/src/sharder.zig b/src/sharder.zig new file mode 100644 index 0000000..68395da --- /dev/null +++ b/src/sharder.zig @@ -0,0 +1,227 @@ +const Intents = @import("types.zig").Intents; +const GatewayBotInfo = @import("shared.zig").GatewayBotInfo; +const Shared = @import("shared.zig"); +const IdentifyProperties = Shared.IdentifyProperties; +const ShardDetails = Shared.ShardDetails; +const Internal = @import("internal.zig"); +const ConnectQueue = Internal.ConnectQueue; +const GatewayDispatchEvent = Internal.GatewayDispatchEvent; +const Log = @import("internal.zig").Log; +const Shard = @import("shard.zig"); +const std = @import("std"); +const mem = std.mem; +const debug = Internal.debug; + +pub const discord_epoch = 1420070400000; + +/// Calculate and return the shard ID for a given guild ID +pub inline fn calculateShardId(guildId: u64, shards: ?usize) u64 { + return (guildId >> 22) % shards orelse 1; +} + +/// Convert a timestamp to a snowflake. +pub inline fn snowflakeToTimestamp(id: u64) u64 { + return (id >> 22) + discord_epoch; +} + +const Self = @This(); + +shard_details: ShardDetails, +allocator: mem.Allocator, + +/// Queue for managing shard connections +connect_queue: ConnectQueue(Shard), +shards: std.AutoArrayHashMap(usize, Shard), +handler: GatewayDispatchEvent(*Shard), + +/// configuration settings +options: SessionOptions, +log: Log, + +pub const ShardData = struct { + /// resume seq to resume connections + resume_seq: ?usize, + + /// resume_gateway_url is the url to resume the connection + /// https://discord.com/developers/docs/topics/gateway#ready-event + resume_gateway_url: ?[]const u8, + + /// session_id is the unique session id of the gateway + session_id: ?[]const u8, +}; + +pub const SessionOptions = struct { + /// Important data which is used by the manager to connect shards to the gateway. */ + info: GatewayBotInfo, + /// Delay in milliseconds to wait before spawning next shard. OPTIMAL IS ABOVE 5100. YOU DON'T WANT TO HIT THE RATE LIMIT!!! + spawn_shard_delay: ?u64 = 5300, + /// Total amount of shards your bot uses. Useful for zero-downtime updates or resharding. + total_shards: usize = 1, + shard_start: usize = 0, + shard_end: usize = 1, + /// The payload handlers for messages on the shard. + resharding: ?struct { interval: u64, percentage: usize } = null, +}; + +pub fn init(allocator: mem.Allocator, settings: struct { + token: []const u8, + intents: Intents, + options: SessionOptions, + run: GatewayDispatchEvent(*Shard), + log: Log, +}) mem.Allocator.Error!Self { + const concurrency = settings.options.info.session_start_limit.?.max_concurrency; + return .{ + .allocator = allocator, + .connect_queue = try ConnectQueue(Shard).init(allocator, concurrency, 5000), + .shards = .init(allocator), + .shard_details = ShardDetails{ + .token = settings.token, + .intents = settings.intents, + }, + .handler = settings.run, + .options = .{ + .info = .{ + .url = settings.options.info.url, + .shards = settings.options.info.shards, + .session_start_limit = settings.options.info.session_start_limit, + }, + .total_shards = settings.options.total_shards, + .shard_start = settings.options.shard_start, + .shard_end = settings.options.shard_end, + }, + .log = settings.log, + }; +} + +pub fn deinit(self: *Self) void { + self.connect_queue.deinit(); + self.shards.deinit(); +} + +pub fn forceIdentify(self: *Self, shard_id: usize) !void { + self.logif("#{d} force identify", .{shard_id}); + const shard = try self.create(shard_id); + + return shard.identify(null); +} + +pub fn disconnect(self: *Self, shard_id: usize) !void { + return if (self.shards.get(shard_id)) |shard| shard.disconnect(); +} + +pub fn disconnectAll(self: *Self) !void { + while (self.shards.iterator().next()) |shard| shard.value_ptr.disconnect(); +} + +/// spawn buckets in order +/// Log bucket preparation +/// Divide shards into chunks based on concurrency +/// Assign each shard to a bucket +/// Return list of buckets +/// https://discord.com/developers/docs/events/gateway#sharding-max-concurrency +fn spawnBuckets(self: *Self) ![][]Shard { + const concurrency = self.options.info.session_start_limit.?.max_concurrency; + + self.logif("{d}-{d}", .{ self.options.shard_start, self.options.shard_end }); + + const range = std.math.sub(usize, self.options.shard_start, self.options.shard_end) catch 1; + const bucket_count = (range + concurrency - 1) / concurrency; + + self.logif("#0 preparing buckets", .{}); + + const buckets = try self.allocator.alloc([]Shard, bucket_count); + + for (buckets, 0..) |*bucket, i| { + const bucket_size = if ((i + 1) * concurrency > range) range - (i * concurrency) else concurrency; + + bucket.* = try self.allocator.alloc(Shard, bucket_size); + + for (bucket.*, 0..) |*shard, j| { + shard.* = try self.create(self.options.shard_start + i * concurrency + j); + } + } + + self.logif("{d} buckets created", .{bucket_count}); + + return buckets; +} + +/// creates a shard and stores it +fn create(self: *Self, shard_id: usize) !Shard { + if (self.shards.get(shard_id)) |s| return s; + + const shard: Shard = try Shard.init(self.allocator, shard_id, .{ + .token = self.shard_details.token, + .intents = self.shard_details.intents, + .options = Shard.ShardOptions{}, + .run = self.handler, + .log = self.log, + }); + + try self.shards.put(shard_id, shard); + + return shard; +} + +pub fn resume_(self: *Self, shard_id: usize, shard_data: ShardData) void { + if (self.shards.contains(shard_id)) return error.CannotOverrideExistingShard; + + const shard = self.create(shard_id); + + shard.data = shard_data; + + return self.connect_queue.push(.{ + .shard = shard, + .callback = &callback, + }); +} + +fn callback(self: *ConnectQueue(Shard).RequestWithShard) anyerror!void { + try self.shard.connect(); +} + +pub fn spawnShards(self: *Self) !void { + const buckets = try self.spawnBuckets(); + + self.logif("Spawning shards", .{}); + + for (buckets) |bucket| { + for (bucket) |shard| { + self.logif("adding {d} to connect queue", .{shard.id}); + try self.connect_queue.push(.{ + .shard = shard, + .callback = &callback, + }); + } + } + + //self.startResharder(); +} + +pub fn send(self: *Self, shard_id: usize, data: anytype) !void { + if (self.shards.get(shard_id)) |shard| { + try shard.send(data); + } +} + +// SPEC OF THE RESHARDER: +// Class Self +// +// Method startResharder(): +// If resharding interval is not set or shard bounds are not valid: +// Exit +// Set up periodic check for resharding: +// If new shards are required: +// Log resharding process +// Update options with new shard settings +// Disconnect old shards and clear them from manager +// Spawn shards again with updated configuration +// + +inline fn logif(self: *Self, comptime format: []const u8, args: anytype) void { + switch (self.log) { + .yes => Internal.debug.info(format, args), + .no => {}, + } +} diff --git a/src/shared.zig b/src/shared.zig index 7c2f586..c1e4d55 100644 --- a/src/shared.zig +++ b/src/shared.zig @@ -1,17 +1,13 @@ +const Intents = @import("types.zig").Intents; +const default_identify_properties = @import("internal.zig").default_identify_properties; const std = @import("std"); pub const IdentifyProperties = struct { - /// /// Operating system the shard runs on. - /// os: []const u8, - /// /// The "browser" where this shard is running on. - /// browser: []const u8, - /// /// The device on which the shard is running. - /// device: []const u8, system_locale: ?[]const u8 = null, // TODO parse this @@ -48,20 +44,29 @@ pub const GatewaySessionStartLimit = struct { /// https://discord.com/developers/docs/topics/gateway#get-gateway-bot pub const GatewayBotInfo = struct { url: []const u8, - /// /// The recommended number of shards to use when connecting /// /// See https://discord.com/developers/docs/topics/gateway#sharding - /// shards: u32, - /// /// Information on the current session start limit /// /// See https://discord.com/developers/docs/topics/gateway#session-start-limit-object - /// session_start_limit: ?GatewaySessionStartLimit, }; +pub const ShardDetails = struct { + /// Bot token which is used to connect to Discord */ + token: []const u8, + /// The URL of the gateway which should be connected to. + url: []const u8 = "wss://gateway.discord.gg", + /// The gateway version which should be used. + version: ?usize = 10, + /// The calculated intent value of the events which the shard should receive. + intents: Intents, + /// Identify properties to use + properties: IdentifyProperties = default_identify_properties, +}; + pub const Snowflake = struct { id: u64, diff --git a/src/test.zig b/src/test.zig index 310bd3c..20747b9 100644 --- a/src/test.zig +++ b/src/test.zig @@ -1,6 +1,8 @@ +const Client = @import("discord.zig").Client; const Shard = @import("discord.zig").Shard; const Discord = @import("discord.zig").Discord; const Internal = @import("discord.zig").Internal; +const FetchReq = @import("discord.zig").FetchReq; const Intents = Discord.Intents; const Thread = std.Thread; const std = @import("std"); @@ -13,7 +15,7 @@ fn message_create(session: *Shard, message: Discord.Message) void { std.debug.print("captured: {?s} send by {s}\n", .{ message.content, message.author.username }); if (message.content) |mc| if (std.ascii.eqlIgnoreCase(mc, "!hi")) { - var req = Shard.FetchReq.init(session.allocator, session.token); + var req = FetchReq.init(session.allocator, session.details.token); defer req.deinit(); const payload: Discord.Partial(Discord.CreateMessage) = .{ .content = "Hi, I'm hang man, your personal assistant" }; @@ -28,14 +30,13 @@ fn message_create(session: *Shard, message: Discord.Message) void { pub fn main() !void { var tsa = std.heap.ThreadSafeAllocator{ .child_allocator = std.heap.c_allocator }; - var handler = try Shard.login(tsa.allocator(), .{ + var handler = Client.init(tsa.allocator()); + try handler.start(.{ .token = std.posix.getenv("TOKEN") orelse unreachable, .intents = Intents.fromRaw(37379), - .run = Internal.GatewayDispatchEvent(*Shard){ - .message_create = &message_create, - .ready = &ready, - }, + .run = .{ .message_create = &message_create, .ready = &ready }, .log = .yes, + .options = .{}, }); errdefer handler.deinit(); }