fix all concurrency issues

This commit is contained in:
Yuzu 2024-11-08 17:58:43 +00:00
parent f6945d8caf
commit aea624ea64
8 changed files with 346 additions and 254 deletions

View File

@ -60,7 +60,7 @@ pub fn build(b: *std.Build) void {
marin.root_module.addImport("zmpl", zmpl.module("zmpl")); marin.root_module.addImport("zmpl", zmpl.module("zmpl"));
marin.root_module.addImport("deque", deque.module("zig-deque")); marin.root_module.addImport("deque", deque.module("zig-deque"));
b.installArtifact(marin); //b.installArtifact(marin);
// test // test
const run_cmd = b.addRunArtifact(marin); const run_cmd = b.addRunArtifact(marin);

View File

@ -1,6 +1,11 @@
const std = @import("std"); const std = @import("std");
const mem = std.mem; const mem = std.mem;
const Deque = @import("deque"); const Deque = @import("deque");
const Discord = @import("types.zig");
pub const debug = std.log.scoped(.@"discord.zig");
pub const Log = union(enum) { yes, no };
/// inspired from: /// inspired from:
/// https://github.com/tiramisulabs/seyfert/blob/main/src/websocket/structures/timeout.ts /// https://github.com/tiramisulabs/seyfert/blob/main/src/websocket/structures/timeout.ts
@ -67,3 +72,124 @@ pub const ConnectQueue = struct {
} }
} }
}; };
fn lessthan(_: void, a: RequestWithPrio, b: RequestWithPrio) void {
return std.math.order(a, b);
}
pub const Bucket = struct {
/// The queue of requests to acquire an available request. Mapped by (shardId, RequestWithPrio)
queue: std.PriorityQueue(RequestWithPrio, void, lessthan),
limit: usize,
refillInterval: u64,
refillAmount: usize,
/// The amount of requests that have been used up already.
used: usize = 0,
/// Whether or not the queue is already processing.
processing: bool = false,
/// Whether the timeout should be killed because there is already one running
shouldStop: bool = false,
/// The timestamp in milliseconds when the next refill is scheduled.
refillsAt: ?u64,
/// comes in handy
m: std.Thread.Mutex = .{},
c: std.Thread.Condition = .{},
fn timeout(self: *Bucket) void {
_ = self;
}
pub fn processQueue() !void {}
pub fn refill() void {}
pub fn acquire(self: *Bucket, rq: RequestWithPrio) !void {
try self.queue.add(rq);
try self.processQueue();
}
};
pub const RequestWithPrio = struct {
callback: *const fn () void,
priority: u32,
};
pub fn GatewayDispatchEvent(comptime T: type) type {
return struct {
// TODO: implement // application_command_permissions_update: null = null,
// TODO: implement // auto_moderation_rule_create: null = null,
// TODO: implement // auto_moderation_rule_update: null = null,
// TODO: implement // auto_moderation_rule_delete: null = null,
// TODO: implement // auto_moderation_action_execution: null = null,
// TODO: implement // channel_create: null = null,
// TODO: implement // channel_update: null = null,
// TODO: implement // channel_delete: null = null,
// TODO: implement // channel_pins_update: null = null,
// TODO: implement // thread_create: null = null,
// TODO: implement // thread_update: null = null,
// TODO: implement // thread_delete: null = null,
// TODO: implement // thread_list_sync: null = null,
// TODO: implement // thread_member_update: null = null,
// TODO: implement // thread_members_update: null = null,
// TODO: implement // guild_audit_log_entry_create: null = null,
// TODO: implement // guild_create: null = null,
// TODO: implement // guild_update: null = null,
// TODO: implement // guild_delete: null = null,
// TODO: implement // guild_ban_add: null = null,
// TODO: implement // guild_ban_remove: null = null,
// TODO: implement // guild_emojis_update: null = null,
// TODO: implement // guild_stickers_update: null = null,
// TODO: implement // guild_integrations_update: null = null,
// TODO: implement // guild_member_add: null = null,
// TODO: implement // guild_member_remove: null = null,
// TODO: implement // guild_member_update: null = null,
// TODO: implement // guild_members_chunk: null = null,
// TODO: implement // guild_role_create: null = null,
// TODO: implement // guild_role_update: null = null,
// TODO: implement // guild_role_delete: null = null,
// TODO: implement // guild_scheduled_event_create: null = null,
// TODO: implement // guild_scheduled_event_update: null = null,
// TODO: implement // guild_scheduled_event_delete: null = null,
// TODO: implement // guild_scheduled_event_user_add: null = null,
// TODO: implement // guild_scheduled_event_user_remove: null = null,
// TODO: implement // integration_create: null = null,
// TODO: implement // integration_update: null = null,
// TODO: implement // integration_delete: null = null,
// TODO: implement // interaction_create: null = null,
// TODO: implement // invite_create: null = null,
// TODO: implement // invite_delete: null = null,
message_create: ?*const fn (save: T, message: Discord.Message) void = undefined,
message_update: ?*const fn (save: T, message: Discord.Message) void = undefined,
message_delete: ?*const fn (save: T, log: Discord.MessageDelete) void = undefined,
message_delete_bulk: ?*const fn (save: T, log: Discord.MessageDeleteBulk) void = undefined,
// TODO: implement // message_delete_bulk: null = null,
// TODO: implement // message_reaction_add: null = null,
// TODO: implement // message_reaction_remove: null = null,
// TODO: implement // message_reaction_remove_all: null = null,
// TODO: implement // message_reaction_remove_emoji: null = null,
// TODO: implement // presence_update: null = null,
// TODO: implement // stage_instance_create: null = null,
// TODO: implement // stage_instance_update: null = null,
// TODO: implement // stage_instance_delete: null = null,
// TODO: implement // typing_start: null = null,
// TODO: implement // user_update: null = null,
// TODO: implement // voice_channel_effect_send: null = null,
// TODO: implement // voice_state_update: null = null,
// TODO: implement // voice_server_update: null = null,
// TODO: implement // webhooks_update: null = null,
// TODO: implement // entitlement_create: null = null,
// TODO: implement // entitlement_update: null = null,
// TODO: implement // entitlement_delete: null = null,
// TODO: implement // message_poll_vote_add: null = null,
// TODO: implement // message_poll_vote_remove: null = null,
ready: ?*const fn (save: T, data: Discord.Ready) void = undefined,
// TODO: implement // resumed: null = null,
any: ?*const fn (save: T, data: []const u8) void = undefined,
};
}

View File

@ -70,8 +70,10 @@ pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) !Discord.M
try mentions.append(try parseUser(allocator, &m.object)); try mentions.append(try parseUser(allocator, &m.object));
} }
std.debug.print("parsing mentions done\n", .{});
// parse member // parse member
const member = try parseMember(allocator, obj.getT(.object, "member").?); const member = if (obj.getT(.object, "member")) |m| try parseMember(allocator, m) else null;
// parse message // parse message
const author = try parseUser(allocator, obj.getT(.object, "author").?); const author = try parseUser(allocator, obj.getT(.object, "author").?);
@ -79,13 +81,10 @@ pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) !Discord.M
// the referenced_message if any // the referenced_message if any
const refmp = try allocator.create(Discord.Message); const refmp = try allocator.create(Discord.Message);
var invalid_ptr = false;
if (obj.getT(.object, "referenced_message")) |m| { if (obj.getT(.object, "referenced_message")) |m| {
refmp.* = try parseMessage(allocator, m); refmp.* = try parseMessage(allocator, m);
} else { } else {
allocator.destroy(refmp); allocator.destroy(refmp);
invalid_ptr = true;
} }
// parse message // parse message
@ -109,7 +108,12 @@ pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) !Discord.M
.mention_channels = &[0]?Discord.ChannelMention{}, .mention_channels = &[0]?Discord.ChannelMention{},
.embeds = &[0]Discord.Embed{}, .embeds = &[0]Discord.Embed{},
.reactions = &[0]?Discord.Reaction{}, .reactions = &[0]?Discord.Reaction{},
.nonce = .{ .string = obj.getT(.string, "nonce").? }, .nonce = if (obj.get("nonce")) |nonce| switch (nonce.*) {
.integer => |n| .{ .int = @as(isize, @intCast(n.value)) },
.string => |n| .{ .string = n.value },
.Null => null,
else => unreachable,
} else null,
.webhook_id = try Snowflake.fromMaybe(obj.getT(.string, "webhook_id")), .webhook_id = try Snowflake.fromMaybe(obj.getT(.string, "webhook_id")),
.activity = null, .activity = null,
.application = null, .application = null,
@ -126,7 +130,7 @@ pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) !Discord.M
.position = if (obj.getT(.integer, "position")) |p| @as(isize, @intCast(p)) else null, .position = if (obj.getT(.integer, "position")) |p| @as(isize, @intCast(p)) else null,
.poll = null, .poll = null,
.call = null, .call = null,
.referenced_message = if (invalid_ptr) null else refmp, .referenced_message = refmp,
}; };
return message; return message;
} }

View File

@ -1,7 +1,41 @@
const ConnectQueue = @import("internal.zig").ConnectQueue;
const Intents = @import("types.zig").Intents; const Intents = @import("types.zig").Intents;
const GatewayBotInfo = @import("shared.zig").GatewayBotInfo; const GatewayBotInfo = @import("shared.zig").GatewayBotInfo;
const IdentifyProperties = @import("shared.zig").IdentifyProperties; const Shared = @import("shared.zig");
const IdentifyProperties = Shared.IdentifyProperties;
const Internal = @import("internal.zig");
const ConnectQueue = Internal.ConnectQueue;
const GatewayDispatchEvent = Internal.GatewayDispatchEvent;
const Shard = @import("shard.zig");
const std = @import("std");
const mem = std.mem;
const debug = Internal.debug;
const Self = @This();
token: []const u8,
intents: Intents,
allocator: mem.Allocator,
connectQueue: ConnectQueue,
shards: std.AutoArrayHashMap(usize, Shard),
run: GatewayDispatchEvent(*Self),
/// spawn buckets in order
/// https://discord.com/developers/docs/events/gateway#sharding-max-concurrency
fn spawnBuckers(self: *Self) !void {
_ = self;
}
/// creates a shard and stores it
fn create(self: *Self, shard_id: usize) !Shard {
const shard = try self.shards.getOrPutValue(shard_id, try Shard.login(self.allocator, .{
.token = self.token,
.intents = self.intents,
.run = self.run,
.log = self.log,
}));
return shard;
}
pub const ShardDetails = struct { pub const ShardDetails = struct {
/// Bot token which is used to connect to Discord */ /// Bot token which is used to connect to Discord */

View File

@ -1,32 +1,37 @@
const std = @import("std");
const json = std.json;
const mem = std.mem;
const http = std.http;
const ws = @import("ws"); const ws = @import("ws");
const builtin = @import("builtin"); const builtin = @import("builtin");
const HttpClient = @import("tls12").HttpClient; const HttpClient = @import("tls12").HttpClient;
const std = @import("std");
const net = std.net; const net = std.net;
const crypto = std.crypto; const crypto = std.crypto;
const tls = std.crypto.tls; const tls = std.crypto.tls;
const json = std.json;
const mem = std.mem;
const http = std.http;
// todo use this to read compressed messages // todo use this to read compressed messages
const zlib = @import("zlib"); const zlib = @import("zlib");
const zmpl = @import("zmpl"); const zmpl = @import("zmpl");
const Discord = @import("types.zig");
const Parser = @import("parser.zig"); const Parser = @import("parser.zig");
const debug = std.log.scoped(.@"discord.zig");
const Self = @This(); const Self = @This();
const Discord = @import("types.zig");
const GatewayPayload = Discord.GatewayPayload; const GatewayPayload = Discord.GatewayPayload;
const Opcode = Discord.GatewayOpcodes; const Opcode = Discord.GatewayOpcodes;
const Intents = Discord.Intents; const Intents = Discord.Intents;
const Shared = @import("shared.zig"); const Shared = @import("shared.zig");
const IdentifyProperties = Shared.IdentifyProperties; const IdentifyProperties = Shared.IdentifyProperties;
const GatewayInfo = Shared.GatewayInfo; const GatewayInfo = Shared.GatewayInfo;
const GatewayBotInfo = Shared.GatewayBotInfo; const GatewayBotInfo = Shared.GatewayBotInfo;
const GatewaySessionStartLimit = Shared.GatewaySessionStartLimit; const GatewaySessionStartLimit = Shared.GatewaySessionStartLimit;
const Internal = @import("internal.zig");
const Log = Internal.Log;
const GatewayDispatchEvent = Internal.GatewayDispatchEvent;
const ShardSocketCloseCodes = enum(u16) { const ShardSocketCloseCodes = enum(u16) {
Shutdown = 3000, Shutdown = 3000,
ZombiedConnection = 3010, ZombiedConnection = 3010,
@ -34,80 +39,7 @@ const ShardSocketCloseCodes = enum(u16) {
const BASE_URL = "https://discord.com/api/v10"; const BASE_URL = "https://discord.com/api/v10";
pub const GatewayDispatchEvent = struct { pub const FetchReq = struct {
// TODO: implement // application_command_permissions_update: null = null,
// TODO: implement // auto_moderation_rule_create: null = null,
// TODO: implement // auto_moderation_rule_update: null = null,
// TODO: implement // auto_moderation_rule_delete: null = null,
// TODO: implement // auto_moderation_action_execution: null = null,
// TODO: implement // channel_create: null = null,
// TODO: implement // channel_update: null = null,
// TODO: implement // channel_delete: null = null,
// TODO: implement // channel_pins_update: null = null,
// TODO: implement // thread_create: null = null,
// TODO: implement // thread_update: null = null,
// TODO: implement // thread_delete: null = null,
// TODO: implement // thread_list_sync: null = null,
// TODO: implement // thread_member_update: null = null,
// TODO: implement // thread_members_update: null = null,
// TODO: implement // guild_audit_log_entry_create: null = null,
// TODO: implement // guild_create: null = null,
// TODO: implement // guild_update: null = null,
// TODO: implement // guild_delete: null = null,
// TODO: implement // guild_ban_add: null = null,
// TODO: implement // guild_ban_remove: null = null,
// TODO: implement // guild_emojis_update: null = null,
// TODO: implement // guild_stickers_update: null = null,
// TODO: implement // guild_integrations_update: null = null,
// TODO: implement // guild_member_add: null = null,
// TODO: implement // guild_member_remove: null = null,
// TODO: implement // guild_member_update: null = null,
// TODO: implement // guild_members_chunk: null = null,
// TODO: implement // guild_role_create: null = null,
// TODO: implement // guild_role_update: null = null,
// TODO: implement // guild_role_delete: null = null,
// TODO: implement // guild_scheduled_event_create: null = null,
// TODO: implement // guild_scheduled_event_update: null = null,
// TODO: implement // guild_scheduled_event_delete: null = null,
// TODO: implement // guild_scheduled_event_user_add: null = null,
// TODO: implement // guild_scheduled_event_user_remove: null = null,
// TODO: implement // integration_create: null = null,
// TODO: implement // integration_update: null = null,
// TODO: implement // integration_delete: null = null,
// TODO: implement // interaction_create: null = null,
// TODO: implement // invite_create: null = null,
// TODO: implement // invite_delete: null = null,
message_create: ?*const fn (message: Discord.Message) void = undefined,
message_update: ?*const fn (message: Discord.Message) void = undefined,
message_delete: ?*const fn (log: Discord.MessageDelete) void = undefined,
message_delete_bulk: ?*const fn (log: Discord.MessageDeleteBulk) void = undefined,
// TODO: implement // message_delete_bulk: null = null,
// TODO: implement // message_reaction_add: null = null,
// TODO: implement // message_reaction_remove: null = null,
// TODO: implement // message_reaction_remove_all: null = null,
// TODO: implement // message_reaction_remove_emoji: null = null,
// TODO: implement // presence_update: null = null,
// TODO: implement // stage_instance_create: null = null,
// TODO: implement // stage_instance_update: null = null,
// TODO: implement // stage_instance_delete: null = null,
// TODO: implement // typing_start: null = null,
// TODO: implement // user_update: null = null,
// TODO: implement // voice_channel_effect_send: null = null,
// TODO: implement // voice_state_update: null = null,
// TODO: implement // voice_server_update: null = null,
// TODO: implement // webhooks_update: null = null,
// TODO: implement // entitlement_create: null = null,
// TODO: implement // entitlement_update: null = null,
// TODO: implement // entitlement_delete: null = null,
// TODO: implement // message_poll_vote_add: null = null,
// TODO: implement // message_poll_vote_remove: null = null,
ready: ?*const fn (data: Discord.Ready) void = undefined,
// TODO: implement // resumed: null = null,
any: ?*const fn (data: []const u8) void = undefined,
};
const FetchReq = struct {
allocator: mem.Allocator, allocator: mem.Allocator,
token: []const u8, token: []const u8,
client: HttpClient, client: HttpClient,
@ -128,10 +60,10 @@ const FetchReq = struct {
self.body.deinit(); self.body.deinit();
} }
pub fn makeRequest(self: *FetchReq, method: http.Method, path: []const u8, body: ?[]const u8) !HttpClient.FetchResult { pub fn makeRequest(self: *FetchReq, method: http.Method, path: []const u8, to_post: ?[]const u8) !HttpClient.FetchResult {
var fetch_options = HttpClient.FetchOptions{ var fetch_options = HttpClient.FetchOptions{
.location = HttpClient.FetchOptions.Location{ .location = HttpClient.FetchOptions.Location{
.url = path, .url = try std.fmt.allocPrint(self.allocator, "{s}{s}", .{ BASE_URL, path }),
}, },
.extra_headers = &[_]http.Header{ .extra_headers = &[_]http.Header{
http.Header{ .name = "Accept", .value = "application/json" }, http.Header{ .name = "Accept", .value = "application/json" },
@ -142,8 +74,8 @@ const FetchReq = struct {
.response_storage = .{ .dynamic = &self.body }, .response_storage = .{ .dynamic = &self.body },
}; };
if (body != null) { if (to_post != null) {
fetch_options.payload = body; fetch_options.payload = to_post;
} }
const res = try self.client.fetch(fetch_options); const res = try self.client.fetch(fetch_options);
@ -160,13 +92,14 @@ const _default_properties = IdentifyProperties{
const Heart = struct { const Heart = struct {
heartbeatInterval: u64, heartbeatInterval: u64,
ack: bool, ack: bool,
/// useful for calculating ping /// useful for calculating ping and resuming
lastBeat: u64, lastBeat: i64,
}; };
client: ws.Client, client: ws.Client,
token: []const u8, token: []const u8,
intents: Intents, intents: Intents,
//heart: Heart = //heart: Heart =
allocator: mem.Allocator, allocator: mem.Allocator,
resume_gateway_url: ?[]const u8 = null, resume_gateway_url: ?[]const u8 = null,
@ -174,16 +107,17 @@ info: GatewayBotInfo,
properties: IdentifyProperties = _default_properties, properties: IdentifyProperties = _default_properties,
session_id: ?[]const u8, session_id: ?[]const u8,
sequence: isize, sequence: std.atomic.Value(isize) = std.atomic.Value(isize).init(0),
heart: Heart = .{ .heartbeatInterval = 45000, .ack = false, .lastBeat = 0 }, heart: Heart = .{ .heartbeatInterval = 45000, .ack = false, .lastBeat = 0 },
/// ///
handler: GatewayDispatchEvent, handler: GatewayDispatchEvent(*Self),
packets: std.ArrayList(u8), packets: std.ArrayList(u8),
inflator: zlib.Decompressor, inflator: zlib.Decompressor,
///useful for closing the conn ///useful for closing the conn
mutex: std.Thread.Mutex = .{}, ws_mutex: std.Thread.Mutex = .{},
rw_mutex: std.Thread.RwLock = .{},
log: Log = .no, log: Log = .no,
/// caller must free the data /// caller must free the data
@ -193,17 +127,17 @@ fn parseJson(self: *Self, raw: []const u8) !zmpl.Data {
return data; return data;
} }
pub inline fn resumable(self: *Self) bool { pub fn resumable(self: *Self) bool {
return self.resume_gateway_url != null and return self.resume_gateway_url != null and
self.session_id != null and self.session_id != null and
self.getSequence() > 0; self.sequence.load(.monotonic) > 0;
} }
pub fn resume_(self: *Self) !void { pub fn resume_(self: *Self) !void {
const data = .{ .op = @intFromEnum(Opcode.Resume), .d = .{ const data = .{ .op = @intFromEnum(Opcode.Resume), .d = .{
.token = self.token, .token = self.token,
.session_id = self.session_id, .session_id = self.session_id,
.seq = self.getSequence(), .seq = self.sequence.load(.monotonic),
} }; } };
try self.send(data); try self.send(data);
@ -240,19 +174,17 @@ fn identify(self: *Self, properties: ?IdentifyProperties) !void {
} }
} }
const Log = union(enum) { yes, no };
// asks /gateway/bot initializes both the ws client and the http client // asks /gateway/bot initializes both the ws client and the http client
pub fn login(allocator: mem.Allocator, args: struct { pub fn login(allocator: mem.Allocator, args: struct {
token: []const u8, token: []const u8,
intents: Intents, intents: Intents,
run: GatewayDispatchEvent, run: GatewayDispatchEvent(*Self),
log: Log, log: Log,
}) !Self { }) !Self {
var req = FetchReq.init(allocator, args.token); var req = FetchReq.init(allocator, args.token);
defer req.deinit(); defer req.deinit();
const res = try req.makeRequest(.GET, BASE_URL ++ "/gateway/bot", null); const res = try req.makeRequest(.GET, "/gateway/bot", null);
const body = try req.body.toOwnedSlice(); const body = try req.body.toOwnedSlice();
defer allocator.free(body); defer allocator.free(body);
@ -272,7 +204,6 @@ pub fn login(allocator: mem.Allocator, args: struct {
// maybe there is a better way to do this // maybe there is a better way to do this
.client = try Self._connect_ws(allocator, url), .client = try Self._connect_ws(allocator, url),
.session_id = undefined, .session_id = undefined,
.sequence = 0,
.info = parsed.value, .info = parsed.value,
.handler = args.run, .handler = args.run,
.log = args.log, .log = args.log,
@ -322,31 +253,30 @@ pub fn readMessage(self: *Self, _: anytype) !void {
if (!std.mem.endsWith(u8, msg.data, &[4]u8{ 0x00, 0x00, 0xFF, 0xFF })) if (!std.mem.endsWith(u8, msg.data, &[4]u8{ 0x00, 0x00, 0xFF, 0xFF }))
continue; continue;
// self.logif("{b}\n", .{self.packets.items});
const buf = try self.packets.toOwnedSlice(); const buf = try self.packets.toOwnedSlice();
const decompressed = try self.inflator.decompressAllAlloc(buf); const decompressed = try self.inflator.decompressAllAlloc(buf);
defer self.allocator.free(decompressed);
const raw = try json.parseFromSlice(struct { const raw = try json.parseFromSlice(struct {
/// opcode for the payload
op: isize, op: isize,
/// Event data
d: json.Value, d: json.Value,
/// Sequence isize, used for resuming sessions and heartbeats
s: ?i64, s: ?i64,
/// The event name for this payload
t: ?[]const u8, t: ?[]const u8,
}, self.allocator, decompressed, .{}); }, self.allocator, decompressed, .{});
defer raw.deinit();
const payload = raw.value; const payload = raw.value;
switch (@as(Opcode, @enumFromInt(payload.op))) { switch (@as(Opcode, @enumFromInt(payload.op))) {
Opcode.Dispatch => { Opcode.Dispatch => {
self.setSequence(payload.s orelse 0);
// maybe use threads and call it instead from there // maybe use threads and call it instead from there
if (payload.t) |name| try self.handleEvent(name, decompressed); if (payload.t) |name| {
self.logif("logging event {s}", .{name});
self.sequence.store(payload.s orelse 0, .monotonic);
try self.handleEvent(name, decompressed);
}
}, },
Opcode.Hello => { Opcode.Hello => {
{
const HelloPayload = struct { heartbeat_interval: u64, _trace: [][]const u8 }; const HelloPayload = struct { heartbeat_interval: u64, _trace: [][]const u8 };
const parsed = try json.parseFromValue(HelloPayload, self.allocator, payload.d, .{}); const parsed = try json.parseFromValue(HelloPayload, self.allocator, payload.d, .{});
defer parsed.deinit(); defer parsed.deinit();
@ -362,36 +292,31 @@ pub fn readMessage(self: *Self, _: anytype) !void {
.lastBeat = 0, .lastBeat = 0,
}; };
self.logif("starting heart beater. seconds:{d}...", .{self.heart.heartbeatInterval});
try self.heartbeat();
var prng = std.Random.DefaultPrng.init(0);
const jitter = std.Random.float(prng.random(), f64);
const thread = try std.Thread.spawn(.{}, Self.heartbeat_wait, .{ self, jitter });
thread.detach();
if (self.resumable()) { if (self.resumable()) {
try self.resume_(); try self.resume_();
return; return;
} else { } else {
try self.identify(self.properties); try self.identify(self.properties);
} }
}
var prng = std.Random.DefaultPrng.init(0);
const jitter = std.Random.float(prng.random(), f64);
const heartbeat_writer = try std.Thread.spawn(.{}, Self.heartbeat, .{ self, jitter });
heartbeat_writer.detach();
}, },
Opcode.HeartbeatACK => { Opcode.HeartbeatACK => {
// perhaps this needs a mutex? // perhaps this needs a mutex?
self.logif("got heartbeat ack", .{}); self.logif("got heartbeat ack", .{});
self.rw_mutex.lock();
self.mutex.lock(); defer self.rw_mutex.unlock();
defer self.mutex.unlock(); self.heart.lastBeat = std.time.milliTimestamp();
self.heart.ack = true;
}, },
Opcode.Heartbeat => { Opcode.Heartbeat => {
self.logif("sending requested heartbeat", .{}); self.logif("sending requested heartbeat", .{});
try self.heartbeat(); self.ws_mutex.lock();
defer self.ws_mutex.unlock();
try self.send(.{ .op = @intFromEnum(Opcode.Heartbeat), .d = self.sequence.load(.monotonic) });
}, },
Opcode.Reconnect => { Opcode.Reconnect => {
self.logif("reconnecting", .{}); self.logif("reconnecting", .{});
@ -403,15 +328,13 @@ pub fn readMessage(self: *Self, _: anytype) !void {
session_id: []const u8, session_id: []const u8,
seq: ?isize, seq: ?isize,
}; };
{
const parsed = try json.parseFromValue(WithSequence, self.allocator, payload.d, .{}); const parsed = try json.parseFromValue(WithSequence, self.allocator, payload.d, .{});
defer parsed.deinit(); defer parsed.deinit();
const resume_payload = parsed.value; const resume_payload = parsed.value;
self.setSequence(resume_payload.seq orelse 0); self.sequence.store(resume_payload.seq orelse 0, .monotonic);
self.session_id = resume_payload.session_id; self.session_id = resume_payload.session_id;
}
}, },
Opcode.InvalidSession => {}, Opcode.InvalidSession => {},
else => { else => {
@ -421,13 +344,11 @@ pub fn readMessage(self: *Self, _: anytype) !void {
} }
} }
pub fn heartbeat(self: *Self) !void { pub fn heartbeat(self: *Self, initial_jitter: f64) !void {
const data = .{ .op = @intFromEnum(Opcode.Heartbeat), .d = if (self.getSequence() > 0) self.getSequence() else null }; var jitter = initial_jitter;
try self.send(data); while (true) {
} // basecase
pub fn heartbeat_wait(self: *Self, jitter: f64) !void {
if (jitter == 1.0) { if (jitter == 1.0) {
// self.logif("zzz for {d}", .{self.heart.heartbeatInterval}); // self.logif("zzz for {d}", .{self.heart.heartbeatInterval});
std.Thread.sleep(std.time.ns_per_ms * self.heart.heartbeatInterval); std.Thread.sleep(std.time.ns_per_ms * self.heart.heartbeatInterval);
@ -435,20 +356,28 @@ pub fn heartbeat_wait(self: *Self, jitter: f64) !void {
const timeout = @as(f64, @floatFromInt(self.heart.heartbeatInterval)) * jitter; const timeout = @as(f64, @floatFromInt(self.heart.heartbeatInterval)) * jitter;
self.logif("zzz for {d} and jitter {d}", .{ @as(u64, @intFromFloat(timeout)), jitter }); self.logif("zzz for {d} and jitter {d}", .{ @as(u64, @intFromFloat(timeout)), jitter });
std.Thread.sleep(std.time.ns_per_ms * @as(u64, @intFromFloat(timeout))); std.Thread.sleep(std.time.ns_per_ms * @as(u64, @intFromFloat(timeout)));
self.logif("end timeout", .{});
} }
self.logif(">> ♥ and ack received: {}", .{self.heart.ack}); self.logif(">> ♥ and ack received: {d}", .{self.heart.lastBeat});
if (self.heart.ack) { self.rw_mutex.lock();
const last = self.heart.lastBeat;
self.rw_mutex.unlock();
const seq = self.sequence.load(.monotonic);
self.logif("sending unrequested heartbeat", .{}); self.logif("sending unrequested heartbeat", .{});
try self.heartbeat(); self.ws_mutex.lock();
try self.client.readTimeout(1000); try self.send(.{ .op = @intFromEnum(Opcode.Heartbeat), .d = seq });
} else { self.ws_mutex.unlock();
if (last > (5100 * self.heart.heartbeatInterval)) {
self.close(ShardSocketCloseCodes.ZombiedConnection, "Zombied connection") catch unreachable; self.close(ShardSocketCloseCodes.ZombiedConnection, "Zombied connection") catch unreachable;
@panic("zombied conn\n"); @panic("zombied conn\n");
} }
return heartbeat_wait(self, 1.0); jitter = 1.0;
}
} }
pub inline fn reconnect(self: *Self) !void { pub inline fn reconnect(self: *Self) !void {
@ -457,9 +386,6 @@ pub inline fn reconnect(self: *Self) !void {
} }
pub fn connect(self: *Self) !void { pub fn connect(self: *Self) !void {
self.mutex.lock();
defer self.mutex.unlock();
//std.time.sleep(std.time.ms_per_s * 5); //std.time.sleep(std.time.ms_per_s * 5);
self.client = try Self._connect_ws(self.allocator, self.gatewayUrl()); self.client = try Self._connect_ws(self.allocator, self.gatewayUrl());
} }
@ -469,9 +395,6 @@ pub fn disconnect(self: *Self) !void {
} }
pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) !void { pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) !void {
self.mutex.lock();
defer self.mutex.unlock();
self.logif("cooked closing ws conn...\n", .{}); self.logif("cooked closing ws conn...\n", .{});
// Implement reconnection logic here // Implement reconnection logic here
try self.client.close(.{ try self.client.close(.{
@ -486,22 +409,11 @@ pub fn send(self: *Self, data: anytype) !void {
var string = std.ArrayList(u8).init(fba.allocator()); var string = std.ArrayList(u8).init(fba.allocator());
try std.json.stringify(data, .{}, string.writer()); try std.json.stringify(data, .{}, string.writer());
//self.logif("{s}\n", .{string.items}); self.logif("{s}\n", .{string.items});
try self.client.write(try string.toOwnedSlice()); try self.client.write(try string.toOwnedSlice());
} }
pub inline fn getSequence(self: *Self) isize {
return self.sequence;
}
pub inline fn setSequence(self: *Self, new: isize) void {
self.mutex.lock();
defer self.mutex.unlock();
self.sequence = new;
}
pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void { pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
if (std.ascii.eqlIgnoreCase(name, "ready")) { if (std.ascii.eqlIgnoreCase(name, "ready")) {
var attempt = try self.parseJson(payload); var attempt = try self.parseJson(payload);
@ -564,7 +476,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
else => unreachable, else => unreachable,
}; };
} }
if (self.handler.ready) |event| event(ready); if (self.handler.ready) |event| event(self, ready);
} }
if (std.ascii.eqlIgnoreCase(name, "message_delete")) { if (std.ascii.eqlIgnoreCase(name, "message_delete")) {
@ -578,7 +490,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
.guild_id = try Shared.Snowflake.fromMaybe(obj.getT(.string, "guild_id")), .guild_id = try Shared.Snowflake.fromMaybe(obj.getT(.string, "guild_id")),
}; };
if (self.handler.message_delete) |event| event(data); if (self.handler.message_delete) |event| event(self, data);
} }
if (std.ascii.eqlIgnoreCase(name, "message_delete_bulk")) { if (std.ascii.eqlIgnoreCase(name, "message_delete_bulk")) {
@ -587,6 +499,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
const obj = attempt.getT(.object, "d").?; const obj = attempt.getT(.object, "d").?;
var ids = std.ArrayList([]const u8).init(self.allocator); var ids = std.ArrayList([]const u8).init(self.allocator);
defer ids.deinit();
while (obj.getT(.array, "ids").?.iterator().next()) |id| { while (obj.getT(.array, "ids").?.iterator().next()) |id| {
ids.append(id.string.value) catch unreachable; ids.append(id.string.value) catch unreachable;
@ -598,7 +511,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
.guild_id = try Shared.Snowflake.fromMaybe(obj.getT(.string, "guild_id")), .guild_id = try Shared.Snowflake.fromMaybe(obj.getT(.string, "guild_id")),
}; };
if (self.handler.message_delete_bulk) |event| event(data); if (self.handler.message_delete_bulk) |event| event(self, data);
} }
if (std.ascii.eqlIgnoreCase(name, "message_update")) { if (std.ascii.eqlIgnoreCase(name, "message_update")) {
@ -607,9 +520,9 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
const obj = attempt.getT(.object, "d").?; const obj = attempt.getT(.object, "d").?;
const message = try Parser.parseMessage(self.allocator, obj); const message = try Parser.parseMessage(self.allocator, obj);
defer if (message.referenced_message) |mptr| self.allocator.destroy(mptr); //defer if (message.referenced_message) |mptr| self.allocator.destroy(mptr);
if (self.handler.message_update) |event| event(message); if (self.handler.message_update) |event| event(self, message);
} }
if (std.ascii.eqlIgnoreCase(name, "message_create")) { if (std.ascii.eqlIgnoreCase(name, "message_create")) {
@ -617,20 +530,23 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
defer attempt.deinit(); defer attempt.deinit();
const obj = attempt.getT(.object, "d").?; const obj = attempt.getT(.object, "d").?;
self.logif("it worked {s}", .{name});
const message = try Parser.parseMessage(self.allocator, obj); const message = try Parser.parseMessage(self.allocator, obj);
defer if (message.referenced_message) |mptr| self.allocator.destroy(mptr); //defer if (message.referenced_message) |mptr| self.allocator.destroy(mptr);
self.logif("it worked {s} {?s}", .{ name, message.content });
if (self.handler.message_create) |event| event(message); if (self.handler.message_create) |event| event(self, message);
} else { } else {
if (self.handler.any) |anyEvent| anyEvent(payload); if (self.handler.any) |anyEvent| anyEvent(self, payload);
} }
} }
pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []const u8, password: []const u8, run: GatewayDispatchEvent, log: Log }) !Self { pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []const u8, password: []const u8, run: GatewayDispatchEvent(*Self), log: Log }) !Self {
const AUTH_LOGIN = "https://discord.com/api/v9/auth/login"; const AUTH_LOGIN = "https://discord.com/api/v9/auth/login";
const WS_CONNECT = "gateway.discord.gg"; const WS_CONNECT = "gateway.discord.gg";
var body = std.ArrayList(u8).init(allocator); var body = std.ArrayList(u8).init(allocator);
defer body.deinit();
const AuthLoginResponse = struct { user_id: []const u8, token: []const u8, user_settings: struct { locale: []const u8, theme: []const u8 } }; const AuthLoginResponse = struct { user_id: []const u8, token: []const u8, user_settings: struct { locale: []const u8, theme: []const u8 } };
@ -654,9 +570,8 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons
var client = HttpClient{ .allocator = allocator }; var client = HttpClient{ .allocator = allocator };
defer client.deinit(); defer client.deinit();
const res = try client.fetch(fetch_options); _ = try client.fetch(fetch_options);
if (res.status == std.http.Status.ok) {
const response = try std.json.parseFromSliceLeaky(AuthLoginResponse, allocator, try body.toOwnedSlice(), .{}); const response = try std.json.parseFromSliceLeaky(AuthLoginResponse, allocator, try body.toOwnedSlice(), .{});
return .{ return .{
@ -666,7 +581,6 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons
// maybe there is a better way to do this // maybe there is a better way to do this
.client = try Self._connect_ws(allocator, WS_CONNECT), .client = try Self._connect_ws(allocator, WS_CONNECT),
.session_id = undefined, .session_id = undefined,
.sequence = 0,
.info = GatewayBotInfo{ .url = "wss://" ++ WS_CONNECT, .shards = 0, .session_start_limit = null }, .info = GatewayBotInfo{ .url = "wss://" ++ WS_CONNECT, .shards = 0, .session_start_limit = null },
.handler = settings.run, .handler = settings.run,
.log = settings.log, .log = settings.log,
@ -689,14 +603,11 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons
.client_event_source = null, .client_event_source = null,
}, },
}; };
} else {
return error.effn;
}
} }
inline fn logif(self: *Self, comptime format: []const u8, args: anytype) void { inline fn logif(self: *Self, comptime format: []const u8, args: anytype) void {
switch (self.log) { switch (self.log) {
.yes => debug.info(format, args), .yes => Internal.debug.info(format, args),
.no => {}, .no => {},
} }
} }

View File

@ -89,7 +89,7 @@ pub const Snowflake = struct {
return array.slice(); return array.slice();
} }
pub fn value(self: *Snowflake) u64 { pub fn value(self: Snowflake) u64 {
return self.value; return self.id;
} }
}; };

View File

@ -1,25 +1,42 @@
const Shard = @import("discord.zig").Shard; const Shard = @import("discord.zig").Shard;
const Discord = @import("discord.zig").Discord; const Discord = @import("discord.zig").Discord;
const Internal = @import("discord.zig").Internal;
const Intents = Discord.Intents; const Intents = Discord.Intents;
const Thread = std.Thread; const Thread = std.Thread;
const std = @import("std"); const std = @import("std");
fn ready(payload: Discord.Ready) void { fn ready(_: *Shard, payload: Discord.Ready) void {
std.debug.print("logged in as {s}\n", .{payload.user.username}); std.debug.print("logged in as {s}\n", .{payload.user.username});
} }
fn message_create(message: Discord.Message) void { fn message_create(session: *Shard, message: Discord.Message) void {
std.debug.print("captured: {?s} send by {s}\n", .{ message.content, message.author.username }); std.debug.print("captured: {?s} send by {s}\n", .{ message.content, message.author.username });
if (message.content) |mc| if (std.ascii.eqlIgnoreCase(mc, "!hi")) {
var req = Shard.FetchReq.init(session.allocator, session.token);
defer req.deinit();
const payload: Discord.Partial(Discord.CreateMessage) = .{ .content = "Hi, I'm hang man, your personal assistant" };
const json = std.json.stringifyAlloc(session.allocator, payload, .{}) catch unreachable;
defer session.allocator.free(json);
const path = std.fmt.allocPrint(session.allocator, "/channels/{d}/messages", .{message.channel_id.value()}) catch unreachable;
_ = req.makeRequest(
.POST,
path,
json,
) catch unreachable;
};
} }
pub fn main() !void { pub fn main() !void {
const allocator = std.heap.c_allocator; var gpa = std.heap.GeneralPurposeAllocator(.{}){};
const token = std.posix.getenv("TOKEN") orelse unreachable; const allocator = gpa.allocator();
var handler = try Shard.login(allocator, .{ var handler = try Shard.login(allocator, .{
.token = token, .token = std.posix.getenv("TOKEN") orelse unreachable,
.intents = Intents.fromRaw(37379), .intents = Intents.fromRaw(37379),
.run = Shard.GatewayDispatchEvent{ .run = Internal.GatewayDispatchEvent(*Shard){
.message_create = &message_create, .message_create = &message_create,
.ready = &ready, .ready = &ready,
}, },
@ -27,6 +44,6 @@ pub fn main() !void {
}); });
errdefer handler.deinit(); errdefer handler.deinit();
const t = try Thread.spawn(.{}, Shard.readMessage, .{ &handler, null }); const event_listener = try std.Thread.spawn(.{}, Shard.readMessage, .{ &handler, null });
defer t.join(); event_listener.join();
} }

View File

@ -2703,8 +2703,8 @@ pub const Message = struct {
/// Reactions to the message /// Reactions to the message
reactions: []?Reaction, reactions: []?Reaction,
/// Used for validating a message was sent /// Used for validating a message was sent
nonce: union(enum) { nonce: ?union(enum) {
int: ?isize, int: isize,
string: []const u8, string: []const u8,
}, },
/// Whether this message is pinned /// Whether this message is pinned
@ -4750,7 +4750,7 @@ pub const CreateMessage = struct {
/// The message contents (up to 2000 characters) /// The message contents (up to 2000 characters)
content: ?[]const u8, content: ?[]const u8,
/// Can be used to verify a message was sent (up to 25 characters). Value will appear in the Message Create event. /// Can be used to verify a message was sent (up to 25 characters). Value will appear in the Message Create event.
nonce: union(enum) { nonce: ?union(enum) {
string: ?[]const u8, string: ?[]const u8,
integer: isize, integer: isize,
}, },