From 79dc5b3a0a459cd46853e04b98f20dcb5bf0a426 Mon Sep 17 00:00:00 2001 From: Yuzu Date: Sun, 10 Nov 2024 11:30:51 -0500 Subject: [PATCH] make errors more explicit and overall refactor --- src/discord.zig | 113 +++++++++++++++++++++++------------------------ src/internal.zig | 12 ++--- src/parser.zig | 6 +-- src/shard.zig | 43 +++++++++++------- src/sharder.zig | 14 +++--- src/shared.zig | 4 +- src/test.zig | 19 ++++---- 7 files changed, 109 insertions(+), 102 deletions(-) diff --git a/src/discord.zig b/src/discord.zig index c95a20b..0ae6504 100644 --- a/src/discord.zig +++ b/src/discord.zig @@ -1,10 +1,9 @@ -pub const Discord = @import("types.zig"); -const Intents = Discord.Intents; - +pub usingnamespace @import("types.zig"); pub const Shard = @import("shard.zig"); + pub const Internal = @import("internal.zig"); -const Log = Internal.Log; const GatewayDispatchEvent = Internal.GatewayDispatchEvent; +const Log = Internal.Log; pub const Sharder = @import("sharder.zig"); const SessionOptions = Sharder.SessionOptions; @@ -19,62 +18,62 @@ const mem = std.mem; const http = std.http; const json = std.json; -pub const Client = struct { - allocator: mem.Allocator, - sharder: Sharder, +const Self = @This(); - pub fn init(allocator: mem.Allocator) Client { - return .{ - .allocator = allocator, - .sharder = undefined, - }; +allocator: mem.Allocator, +sharder: Sharder, + +pub fn init(allocator: mem.Allocator) Self { + return .{ + .allocator = allocator, + .sharder = undefined, + }; +} + +pub fn deinit(self: *Self) void { + self.sharder.deinit(); +} + +pub fn start(self: *Self, settings: struct { + token: []const u8, + intents: Self.Intents, + options: struct { + spawn_shard_delay: u64 = 5300, + total_shards: usize = 1, + shard_start: usize = 0, + shard_end: usize = 1, + }, + run: GatewayDispatchEvent(*Shard), + log: Log, +}) !void { + var req = FetchReq.init(self.allocator, settings.token); + defer req.deinit(); + + const res = try req.makeRequest(.GET, "/gateway/bot", null); + const body = try req.body.toOwnedSlice(); + defer self.allocator.free(body); + + // check status idk + if (res.status != http.Status.ok) { + @panic("we are cooked\n"); } - pub fn deinit(self: *Client) void { - self.sharder.deinit(); - } + const parsed = try json.parseFromSlice(GatewayBotInfo, self.allocator, body, .{}); + defer parsed.deinit(); - pub fn start(self: *Client, settings: struct { - token: []const u8, - intents: Intents, - options: struct { - spawn_shard_delay: u64 = 5300, - total_shards: usize = 1, - shard_start: usize = 0, - shard_end: usize = 1, + self.sharder = try Sharder.init(self.allocator, .{ + .token = settings.token, + .intents = settings.intents, + .run = settings.run, + .options = SessionOptions{ + .info = parsed.value, + .shard_start = settings.options.shard_start, + .shard_end = @intCast(parsed.value.shards), + .total_shards = @intCast(parsed.value.shards), + .spawn_shard_delay = settings.options.spawn_shard_delay, }, - run: GatewayDispatchEvent(*Shard), - log: Log, - }) !void { - var req = FetchReq.init(self.allocator, settings.token); - defer req.deinit(); + .log = settings.log, + }); - const res = try req.makeRequest(.GET, "/gateway/bot", null); - const body = try req.body.toOwnedSlice(); - defer self.allocator.free(body); - - // check status idk - if (res.status != http.Status.ok) { - @panic("we are cooked\n"); - } - - const parsed = try json.parseFromSlice(GatewayBotInfo, self.allocator, body, .{}); - defer parsed.deinit(); - - self.sharder = try Sharder.init(self.allocator, .{ - .token = settings.token, - .intents = settings.intents, - .run = settings.run, - .options = SessionOptions{ - .info = parsed.value, - .shard_start = settings.options.shard_start, - .shard_end = @intCast(parsed.value.shards), - .total_shards = @intCast(parsed.value.shards), - .spawn_shard_delay = settings.options.spawn_shard_delay, - }, - .log = settings.log, - }); - - try self.sharder.spawnShards(); - } -}; + try self.sharder.spawnShards(); +} diff --git a/src/internal.zig b/src/internal.zig index 61917ae..379c723 100644 --- a/src/internal.zig +++ b/src/internal.zig @@ -235,10 +235,10 @@ pub fn GatewayDispatchEvent(comptime T: type) type { // TODO: implement // interaction_create: null = null, // TODO: implement // invite_create: null = null, // TODO: implement // invite_delete: null = null, - message_create: ?*const fn (save: T, message: Discord.Message) void = undefined, - message_update: ?*const fn (save: T, message: Discord.Message) void = undefined, - message_delete: ?*const fn (save: T, log: Discord.MessageDelete) void = undefined, - message_delete_bulk: ?*const fn (save: T, log: Discord.MessageDeleteBulk) void = undefined, + message_create: ?*const fn (save: T, message: Discord.Message) anyerror!void = undefined, + message_update: ?*const fn (save: T, message: Discord.Message) anyerror!void = undefined, + message_delete: ?*const fn (save: T, log: Discord.MessageDelete) anyerror!void = undefined, + message_delete_bulk: ?*const fn (save: T, log: Discord.MessageDeleteBulk) anyerror!void = undefined, // TODO: implement // message_delete_bulk: null = null, // TODO: implement // message_reaction_add: null = null, // TODO: implement // message_reaction_remove: null = null, @@ -260,8 +260,8 @@ pub fn GatewayDispatchEvent(comptime T: type) type { // TODO: implement // message_poll_vote_add: null = null, // TODO: implement // message_poll_vote_remove: null = null, - ready: ?*const fn (save: T, data: Discord.Ready) void = undefined, + ready: ?*const fn (save: T, data: Discord.Ready) anyerror!void = undefined, // TODO: implement // resumed: null = null, - any: ?*const fn (save: T, data: []const u8) void = undefined, + any: ?*const fn (save: T, data: []const u8) anyerror!void = undefined, }; } diff --git a/src/parser.zig b/src/parser.zig index 0993144..be35969 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -4,7 +4,7 @@ const std = @import("std"); const mem = std.mem; const Snowflake = @import("shared.zig").Snowflake; -pub fn parseUser(_: mem.Allocator, obj: *zmpl.Data.Object) !Discord.User { +pub fn parseUser(_: mem.Allocator, obj: *zmpl.Data.Object) std.fmt.ParseIntError!Discord.User { const avatar_decoration_data_obj = obj.getT(.object, "avatar_decoration_data"); const user = Discord.User{ .clan = null, @@ -35,7 +35,7 @@ pub fn parseUser(_: mem.Allocator, obj: *zmpl.Data.Object) !Discord.User { return user; } -pub fn parseMember(_: mem.Allocator, obj: *zmpl.Data.Object) !Discord.Member { +pub fn parseMember(_: mem.Allocator, obj: *zmpl.Data.Object) std.fmt.ParseIntError!Discord.Member { const avatar_decoration_data_member_obj = obj.getT(.object, "avatar_decoration_data"); const member = Discord.Member{ .deaf = obj.getT(.boolean, "deaf"), @@ -59,7 +59,7 @@ pub fn parseMember(_: mem.Allocator, obj: *zmpl.Data.Object) !Discord.Member { } /// caller must free the received referenced_message if any -pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) !Discord.Message { +pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) (mem.Allocator.Error || std.fmt.ParseIntError)!Discord.Message { // parse mentions const mentions_obj = obj.getT(.array, "mentions").?; diff --git a/src/shard.zig b/src/shard.zig index 27a1016..8c7ab99 100644 --- a/src/shard.zig +++ b/src/shard.zig @@ -34,9 +34,7 @@ const GatewayDispatchEvent = Internal.GatewayDispatchEvent; const Bucket = Internal.Bucket; const default_identify_properties = Internal.default_identify_properties; -const FetchRequest = @import("http.zig").FetchReq; - -const ShardSocketCloseCodes = enum(u16) { +pub const ShardSocketCloseCodes = enum(u16) { Shutdown = 3000, ZombiedConnection = 3010, }; @@ -83,8 +81,10 @@ ws_mutex: std.Thread.Mutex = .{}, rw_mutex: std.Thread.RwLock = .{}, log: Log = .no, +pub const JsonResolutionError = std.fmt.ParseIntError || std.fmt.ParseFloatError || json.ParseFromValueError || json.ParseError(json.Scanner); + /// caller must free the data -fn parseJson(self: *Self, raw: []const u8) !zmpl.Data { +fn parseJson(self: *Self, raw: []const u8) JsonResolutionError!zmpl.Data { var data = zmpl.Data.init(self.allocator); try data.fromJson(raw); return data; @@ -96,7 +96,7 @@ pub fn resumable(self: *Self) bool { self.sequence.load(.monotonic) > 0; } -pub fn resume_(self: *Self) !void { +pub fn resume_(self: *Self) SendError!void { const data = .{ .op = @intFromEnum(Opcode.Resume), .d = .{ .token = self.details.token, .session_id = self.session_id, @@ -111,7 +111,7 @@ inline fn gatewayUrl(self: ?*Self) []const u8 { } /// identifies in order to connect to Discord and get the online status, this shall be done on hello perhaps -fn identify(self: *Self, properties: ?IdentifyProperties) !void { +pub fn identify(self: *Self, properties: ?IdentifyProperties) SendError!void { self.logif("intents: {d}", .{self.details.intents.toRaw()}); if (self.details.intents.toRaw() != 0) { @@ -205,6 +205,8 @@ pub fn deinit(self: *Self) void { self.logif("killing the whole bot", .{}); } +const ReadMessageError = mem.Allocator.Error || zlib.Error || json.ParseError(json.Scanner) || json.ParseFromValueError; + /// listens for messages fn readMessage(self: *Self, _: anytype) !void { try self.client.readTimeout(0); @@ -313,7 +315,9 @@ fn readMessage(self: *Self, _: anytype) !void { } } -pub fn heartbeat(self: *Self, initial_jitter: f64) !void { +pub const SendHeartbeatError = CloseError || SendError; + +pub fn heartbeat(self: *Self, initial_jitter: f64) SendHeartbeatError!void { var jitter = initial_jitter; while (true) { @@ -341,7 +345,7 @@ pub fn heartbeat(self: *Self, initial_jitter: f64) !void { self.ws_mutex.unlock(); if ((std.time.milliTimestamp() - last) > (5000 * self.heart.heartbeatInterval)) { - self.close(ShardSocketCloseCodes.ZombiedConnection, "Zombied connection") catch unreachable; + try self.close(ShardSocketCloseCodes.ZombiedConnection, "Zombied connection"); @panic("zombied conn\n"); } @@ -349,13 +353,18 @@ pub fn heartbeat(self: *Self, initial_jitter: f64) !void { } } -pub inline fn reconnect(self: *Self) !void { +pub const ReconnectError = ConnectError || CloseError; + +pub fn reconnect(self: *Self) ReconnectError!void { try self.disconnect(); try self.connect(); } pub const ConnectError = - std.net.TcpConnectToAddressError || std.crypto.tls.Client.InitError(std.net.Stream) || std.net.Stream.ReadError || std.net.IPParseError || std.crypto.Certificate.Bundle.RescanError || std.net.TcpConnectToHostError || std.fmt.BufPrintError || mem.Allocator.Error; + net.TcpConnectToAddressError || crypto.tls.Client.InitError(net.Stream) || + net.Stream.ReadError || net.IPParseError || + crypto.Certificate.Bundle.RescanError || net.TcpConnectToHostError || + std.fmt.BufPrintError || mem.Allocator.Error; pub fn connect(self: *Self) ConnectError!void { //std.time.sleep(std.time.ms_per_s * 5); @@ -366,7 +375,7 @@ pub fn connect(self: *Self) ConnectError!void { self.readMessage(null) catch unreachable; } -pub fn disconnect(self: *Self) !void { +pub fn disconnect(self: *Self) CloseError!void { try self.close(ShardSocketCloseCodes.Shutdown, "Shard down request"); } @@ -456,7 +465,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void { else => unreachable, }; } - if (self.handler.ready) |event| event(self, ready); + if (self.handler.ready) |event| try event(self, ready); } if (std.ascii.eqlIgnoreCase(name, "message_delete")) { @@ -470,7 +479,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void { .guild_id = try Shared.Snowflake.fromMaybe(obj.getT(.string, "guild_id")), }; - if (self.handler.message_delete) |event| event(self, data); + if (self.handler.message_delete) |event| try event(self, data); } if (std.ascii.eqlIgnoreCase(name, "message_delete_bulk")) { @@ -491,7 +500,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void { .guild_id = try Shared.Snowflake.fromMaybe(obj.getT(.string, "guild_id")), }; - if (self.handler.message_delete_bulk) |event| event(self, data); + if (self.handler.message_delete_bulk) |event| try event(self, data); } if (std.ascii.eqlIgnoreCase(name, "message_update")) { @@ -502,7 +511,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void { const message = try Parser.parseMessage(self.allocator, obj); //defer if (message.referenced_message) |mptr| self.allocator.destroy(mptr); - if (self.handler.message_update) |event| event(self, message); + if (self.handler.message_update) |event| try event(self, message); } if (std.ascii.eqlIgnoreCase(name, "message_create")) { @@ -515,9 +524,9 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void { //defer if (message.referenced_message) |mptr| self.allocator.destroy(mptr); self.logif("it worked {s} {?s}", .{ name, message.content }); - if (self.handler.message_create) |event| event(self, message); + if (self.handler.message_create) |event| try event(self, message); } else { - if (self.handler.any) |anyEvent| anyEvent(self, payload); + if (self.handler.any) |anyEvent| try anyEvent(self, payload); } } diff --git a/src/sharder.zig b/src/sharder.zig index 68395da..114d53d 100644 --- a/src/sharder.zig +++ b/src/sharder.zig @@ -15,8 +15,8 @@ const debug = Internal.debug; pub const discord_epoch = 1420070400000; /// Calculate and return the shard ID for a given guild ID -pub inline fn calculateShardId(guildId: u64, shards: ?usize) u64 { - return (guildId >> 22) % shards orelse 1; +pub inline fn calculateShardId(guild_id: u64, shards: ?usize) u64 { + return (guild_id >> 22) % shards orelse 1; } /// Convert a timestamp to a snowflake. @@ -106,11 +106,11 @@ pub fn forceIdentify(self: *Self, shard_id: usize) !void { return shard.identify(null); } -pub fn disconnect(self: *Self, shard_id: usize) !void { +pub fn disconnect(self: *Self, shard_id: usize) Shard.CloseError!void { return if (self.shards.get(shard_id)) |shard| shard.disconnect(); } -pub fn disconnectAll(self: *Self) !void { +pub fn disconnectAll(self: *Self) Shard.CloseError!void { while (self.shards.iterator().next()) |shard| shard.value_ptr.disconnect(); } @@ -199,10 +199,8 @@ pub fn spawnShards(self: *Self) !void { //self.startResharder(); } -pub fn send(self: *Self, shard_id: usize, data: anytype) !void { - if (self.shards.get(shard_id)) |shard| { - try shard.send(data); - } +pub fn send(self: *Self, shard_id: usize, data: anytype) Shard.SendError!void { + if (self.shards.get(shard_id)) |shard| try shard.send(data); } // SPEC OF THE RESHARDER: diff --git a/src/shared.zig b/src/shared.zig index c1e4d55..e663ea0 100644 --- a/src/shared.zig +++ b/src/shared.zig @@ -70,7 +70,7 @@ pub const ShardDetails = struct { pub const Snowflake = struct { id: u64, - pub fn fromMaybe(raw: ?[]const u8) !?Snowflake { + pub fn fromMaybe(raw: ?[]const u8) std.fmt.ParseIntError!?Snowflake { if (raw) |id| { return .{ .id = try std.fmt.parseInt(u64, id, 10), @@ -78,7 +78,7 @@ pub const Snowflake = struct { } else return null; } - pub fn fromRaw(raw: []const u8) !Snowflake { + pub fn fromRaw(raw: []const u8) std.fmt.ParseIntError!Snowflake { return .{ .id = try std.fmt.parseInt(u64, raw, 10), }; diff --git a/src/test.zig b/src/test.zig index 20747b9..a171748 100644 --- a/src/test.zig +++ b/src/test.zig @@ -1,17 +1,18 @@ -const Client = @import("discord.zig").Client; -const Shard = @import("discord.zig").Shard; -const Discord = @import("discord.zig").Discord; -const Internal = @import("discord.zig").Internal; -const FetchReq = @import("discord.zig").FetchReq; +const Discord = @import("discord.zig"); + +const Shard = Discord.Shard; +const Internal = Discord.Internal; +const FetchReq = Discord.FetchReq; const Intents = Discord.Intents; const Thread = std.Thread; const std = @import("std"); +const fmt = std.fmt; -fn ready(_: *Shard, payload: Discord.Ready) void { +fn ready(_: *Shard, payload: Discord.Ready) !void { std.debug.print("logged in as {s}\n", .{payload.user.username}); } -fn message_create(session: *Shard, message: Discord.Message) void { +fn message_create(session: *Shard, message: Discord.Message) fmt.AllocPrintError!void { std.debug.print("captured: {?s} send by {s}\n", .{ message.content, message.author.username }); if (message.content) |mc| if (std.ascii.eqlIgnoreCase(mc, "!hi")) { @@ -21,7 +22,7 @@ fn message_create(session: *Shard, message: Discord.Message) void { const payload: Discord.Partial(Discord.CreateMessage) = .{ .content = "Hi, I'm hang man, your personal assistant" }; const json = std.json.stringifyAlloc(session.allocator, payload, .{}) catch unreachable; defer session.allocator.free(json); - const path = std.fmt.allocPrint(session.allocator, "/channels/{d}/messages", .{message.channel_id.value()}) catch unreachable; + const path = try fmt.allocPrint(session.allocator, "/channels/{d}/messages", .{message.channel_id.value()}); _ = req.makeRequest(.POST, path, json) catch unreachable; }; @@ -30,7 +31,7 @@ fn message_create(session: *Shard, message: Discord.Message) void { pub fn main() !void { var tsa = std.heap.ThreadSafeAllocator{ .child_allocator = std.heap.c_allocator }; - var handler = Client.init(tsa.allocator()); + var handler = Discord.init(tsa.allocator()); try handler.start(.{ .token = std.posix.getenv("TOKEN") orelse unreachable, .intents = Intents.fromRaw(37379),