added sharding support

This commit is contained in:
Yuzu 2024-11-09 20:48:50 -05:00
parent 425c8646d7
commit ed67067f6b
11 changed files with 609 additions and 294 deletions

View File

@ -20,11 +20,6 @@ pub fn build(b: *std.Build) void {
.optimize = optimize, .optimize = optimize,
}); });
const zig_tls = b.dependency("zig-tls", .{
.target = target,
.optimize = optimize,
});
const zlib = b.dependency("zlib", .{}); const zlib = b.dependency("zlib", .{});
const zmpl = b.dependency("zmpl", .{ const zmpl = b.dependency("zmpl", .{
@ -48,14 +43,12 @@ pub fn build(b: *std.Build) void {
// now install your own executable after it's built correctly // now install your own executable after it's built correctly
dzig.addImport("ws", websocket.module("websocket")); dzig.addImport("ws", websocket.module("websocket"));
dzig.addImport("tls12", zig_tls.module("zig-tls12"));
dzig.addImport("zlib", zlib.module("zlib")); dzig.addImport("zlib", zlib.module("zlib"));
dzig.addImport("zmpl", zmpl.module("zmpl")); dzig.addImport("zmpl", zmpl.module("zmpl"));
dzig.addImport("deque", deque.module("zig-deque")); dzig.addImport("deque", deque.module("zig-deque"));
marin.root_module.addImport("discord.zig", dzig); marin.root_module.addImport("discord.zig", dzig);
marin.root_module.addImport("ws", websocket.module("websocket")); marin.root_module.addImport("ws", websocket.module("websocket"));
marin.root_module.addImport("tls12", zig_tls.module("zig-tls12"));
marin.root_module.addImport("zlib", zlib.module("zlib")); marin.root_module.addImport("zlib", zlib.module("zlib"));
marin.root_module.addImport("zmpl", zmpl.module("zmpl")); marin.root_module.addImport("zmpl", zmpl.module("zmpl"));
marin.root_module.addImport("deque", deque.module("zig-deque")); marin.root_module.addImport("deque", deque.module("zig-deque"));

View File

@ -27,21 +27,17 @@
.url = "https://github.com/magurotuna/zig-deque/archive/refs/heads/main.zip", .url = "https://github.com/magurotuna/zig-deque/archive/refs/heads/main.zip",
.hash = "1220d1bedf7d5cfc7475842b3d4e8f03f1308be2e724a036677cceb5c4db13c3da3d", .hash = "1220d1bedf7d5cfc7475842b3d4e8f03f1308be2e724a036677cceb5c4db13c3da3d",
}, },
.@"zig-tls" = .{
.url = "https://github.com/yuzudev/zig-tls12/archive/refs/heads/main.zip",
.hash = "122079aad9eebc5945e207715c50722bbeca34e7b7a5876b3a322fad6468b28f2ed3",
},
.websocket = .{
.url = "https://github.com/yuzudev/websocket.zig/archive/refs/heads/master.zip",
.hash = "122062d9dda015ba25780d00697e0e2c1cbc3ffa5b5eae55c9ea28b39b6d99ef1d85",
},
.zlib = .{ .zlib = .{
.url = "https://github.com/yuzudev/zig-zlib/archive/refs/heads/main.zip", .url = "https://github.com/yuzudev/zig-zlib/archive/refs/heads/main.zip",
.hash = "1220cd041e8d04f1da9d6f46d0438f4e6809b113ba3454fffdaae96b59d2b35a6b2b", .hash = "1220cd041e8d04f1da9d6f46d0438f4e6809b113ba3454fffdaae96b59d2b35a6b2b",
}, },
.zmpl = .{ .zmpl = .{
.url = "https://github.com/jetzig-framework/zmpl/archive/refs/heads/main.zip", .url = "https://github.com/jetzig-framework/zmpl/archive/refs/heads/main.zip",
.hash = "12209a9d6f652ce712448da1bd41517cceef8eebf86884c1bb29c02fc3a933713afd", .hash = "1220798a4647e3b0766aad653830a2601e11c567ba6bfe83e526eb91d04a6c45f7d8",
},
.websocket = .{
.url = "https://github.com/yuzudev/websocket.zig/archive/refs/heads/master.zip",
.hash = "12207c03624f9f5a1c444bde3d484a9b1e927a902067fded98364b714de412d318e0",
}, },
}, },
.paths = .{ .paths = .{

View File

@ -1,4 +1,80 @@
pub const Discord = @import("types.zig"); pub const Discord = @import("types.zig");
const Intents = Discord.Intents;
pub const Shard = @import("shard.zig"); pub const Shard = @import("shard.zig");
pub const Internal = @import("internal.zig"); pub const Internal = @import("internal.zig");
pub const Session = @import("session.zig"); const Log = Internal.Log;
const GatewayDispatchEvent = Internal.GatewayDispatchEvent;
pub const Sharder = @import("sharder.zig");
const SessionOptions = Sharder.SessionOptions;
pub const Shared = @import("shared.zig");
const GatewayBotInfo = Shared.GatewayBotInfo;
pub const FetchReq = @import("http.zig").FetchReq;
const std = @import("std");
const mem = std.mem;
const http = std.http;
const json = std.json;
pub const Client = struct {
allocator: mem.Allocator,
sharder: Sharder,
pub fn init(allocator: mem.Allocator) Client {
return .{
.allocator = allocator,
.sharder = undefined,
};
}
pub fn deinit(self: *Client) void {
self.sharder.deinit();
}
pub fn start(self: *Client, settings: struct {
token: []const u8,
intents: Intents,
options: struct {
spawn_shard_delay: u64 = 5300,
total_shards: usize = 1,
shard_start: usize = 0,
shard_end: usize = 1,
},
run: GatewayDispatchEvent(*Shard),
log: Log,
}) !void {
var req = FetchReq.init(self.allocator, settings.token);
defer req.deinit();
const res = try req.makeRequest(.GET, "/gateway/bot", null);
const body = try req.body.toOwnedSlice();
defer self.allocator.free(body);
// check status idk
if (res.status != http.Status.ok) {
@panic("we are cooked\n");
}
const parsed = try json.parseFromSlice(GatewayBotInfo, self.allocator, body, .{});
defer parsed.deinit();
self.sharder = try Sharder.init(self.allocator, .{
.token = settings.token,
.intents = settings.intents,
.run = settings.run,
.options = SessionOptions{
.info = parsed.value,
.shard_start = settings.options.shard_start,
.shard_end = @intCast(parsed.value.shards),
.total_shards = @intCast(parsed.value.shards),
.spawn_shard_delay = settings.options.spawn_shard_delay,
},
.log = settings.log,
});
try self.sharder.spawnShards();
}
};

49
src/http.zig Normal file
View File

@ -0,0 +1,49 @@
const std = @import("std");
const mem = std.mem;
const http = std.http;
pub const BASE_URL = "https://discord.com/api/v10";
pub const FetchReq = struct {
allocator: mem.Allocator,
token: []const u8,
client: http.Client,
body: std.ArrayList(u8),
pub fn init(allocator: mem.Allocator, token: []const u8) FetchReq {
const client = http.Client{ .allocator = allocator };
return FetchReq{
.allocator = allocator,
.client = client,
.body = std.ArrayList(u8).init(allocator),
.token = token,
};
}
pub fn deinit(self: *FetchReq) void {
self.client.deinit();
self.body.deinit();
}
pub fn makeRequest(self: *FetchReq, method: http.Method, path: []const u8, to_post: ?[]const u8) !http.Client.FetchResult {
var fetch_options = http.Client.FetchOptions{
.location = http.Client.FetchOptions.Location{
.url = try std.fmt.allocPrint(self.allocator, "{s}{s}", .{ BASE_URL, path }),
},
.extra_headers = &[_]http.Header{
http.Header{ .name = "Accept", .value = "application/json" },
http.Header{ .name = "Content-Type", .value = "application/json" },
http.Header{ .name = "Authorization", .value = self.token },
},
.method = method,
.response_storage = .{ .dynamic = &self.body },
};
if (to_post != null) {
fetch_options.payload = to_post;
}
const res = try self.client.fetch(fetch_options);
return res;
}
};

View File

@ -1,89 +1,105 @@
const std = @import("std"); const std = @import("std");
const mem = std.mem; const mem = std.mem;
const Deque = @import("deque"); const Deque = @import("deque").Deque;
const Discord = @import("types.zig"); const Discord = @import("types.zig");
const builtin = @import("builtin");
const IdentifyProperties = @import("shared.zig").IdentifyProperties;
pub const debug = std.log.scoped(.@"discord.zig"); pub const debug = std.log.scoped(.@"discord.zig");
pub const Log = union(enum) { yes, no }; pub const Log = union(enum) { yes, no };
pub const default_identify_properties = IdentifyProperties{
.os = @tagName(builtin.os.tag),
.browser = "discord.zig",
.device = "discord.zig",
};
/// inspired from: /// inspired from:
/// https://github.com/tiramisulabs/seyfert/blob/main/src/websocket/structures/timeout.ts /// https://github.com/tiramisulabs/seyfert/blob/main/src/websocket/structures/timeout.ts
pub const ConnectQueue = struct { pub fn ConnectQueue(comptime T: type) type {
dequeue: Deque(*const fn () void), return struct {
allocator: mem.Allocator, pub const RequestWithShard = struct {
remaining: usize, callback: *const fn (self: *RequestWithShard) anyerror!void,
interval_time: u64 = 5000, shard: T,
running: bool,
concurrency: usize = 1,
pub fn init(allocator: mem.Allocator, concurrency: usize, interval_time: u64) !ConnectQueue {
return .{
.allocator = allocator,
.dequeue = try Deque(*const fn () void).init(allocator),
.remaining = concurrency,
.interval_time = interval_time,
.concurrency = concurrency,
}; };
}
pub fn deinit(self: *ConnectQueue) void { dequeue: Deque(RequestWithShard),
self.dequeue.deinit(); allocator: mem.Allocator,
} remaining: usize,
interval_time: u64 = 5000,
running: bool = false,
concurrency: usize = 1,
pub fn push(self: *ConnectQueue, callback: *const fn () void) !void { pub fn init(allocator: mem.Allocator, concurrency: usize, interval_time: u64) !ConnectQueue(T) {
if (self.remaining == 0) { return .{
return self.dequeue.pushBack(callback); .allocator = allocator,
} .dequeue = try Deque(RequestWithShard).init(allocator),
self.remaining -= 1; .remaining = concurrency,
.interval_time = interval_time,
if (!self.running) { .concurrency = concurrency,
self.startInterval(); };
self.running = true;
} }
if (self.dequeue.items.len < self.concurrency) { pub fn deinit(self: *ConnectQueue(T)) void {
@call(.auto, callback, .{}); self.dequeue.deinit();
return;
} }
return self.dequeue.pushBack(callback); pub fn push(self: *ConnectQueue(T), req: RequestWithShard) !void {
} if (self.remaining == 0) {
return self.dequeue.pushBack(req);
}
self.remaining -= 1;
fn startInterval(self: *ConnectQueue) void { if (!self.running) {
while (self.running) { try self.startInterval();
std.Thread.sleep(std.time.ns_per_ms * (self.interval_time / self.concurrency)); self.running = true;
const callback: ?*const fn () void = self.dequeue.popFront(); }
while (self.dequeue.items.len == 0 and callback == null) {} if (self.dequeue.len() < self.concurrency) {
// perhaps store this?
if (callback) |cb| { const ptr = try self.allocator.create(RequestWithShard);
@call(.auto, cb, .{}); ptr.* = req;
try @call(.auto, req.callback, .{ptr});
return; return;
} }
if (self.remaining < self.concurrency) { return self.dequeue.pushBack(req);
self.remaining += 1; }
}
if (self.dequeue.len() == 0) { fn startInterval(self: *ConnectQueue(T)) !void {
self.running = false; while (self.running) {
std.Thread.sleep(std.time.ns_per_ms * (self.interval_time / self.concurrency));
const req: ?RequestWithShard = self.dequeue.popFront();
while (self.dequeue.len() == 0 and req == null) {}
if (req) |r| {
const ptr = try self.allocator.create(RequestWithShard);
ptr.* = r;
try @call(.auto, r.callback, .{ptr});
return;
}
if (self.remaining < self.concurrency) {
self.remaining += 1;
}
if (self.dequeue.len() == 0) {
self.running = false;
}
} }
} }
} };
};
fn lessthan(_: void, a: RequestWithPrio, b: RequestWithPrio) void {
return std.math.order(a, b);
} }
pub const Bucket = struct { pub const Bucket = struct {
/// The queue of requests to acquire an available request. Mapped by (shardId, RequestWithPrio) /// The queue of requests to acquire an available request. Mapped by (shardId, RequestWithPrio)
queue: std.PriorityQueue(RequestWithPrio, void, lessthan), queue: std.PriorityQueue(RequestWithPrio, void, Bucket.lessthan),
limit: usize, limit: usize,
refillInterval: u64, refill_interval: u64,
refillAmount: usize, refill_amount: usize,
/// The amount of requests that have been used up already. /// The amount of requests that have been used up already.
used: usize = 0, used: usize = 0,
@ -92,21 +108,82 @@ pub const Bucket = struct {
processing: bool = false, processing: bool = false,
/// Whether the timeout should be killed because there is already one running /// Whether the timeout should be killed because there is already one running
shouldStop: bool = false, should_stop: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
/// The timestamp in milliseconds when the next refill is scheduled. /// The timestamp in milliseconds when the next refill is scheduled.
refillsAt: ?u64, refills_at: ?u64 = null,
/// comes in handy pub const RequestWithPrio = struct {
m: std.Thread.Mutex = .{}, callback: *const fn () void,
c: std.Thread.Condition = .{}, priority: u32 = 1,
};
fn timeout(self: *Bucket) void { fn lessthan(_: void, a: RequestWithPrio, b: RequestWithPrio) std.math.Order {
_ = self; return std.math.order(a.priority, b.priority);
} }
pub fn processQueue() !void {} pub fn init(allocator: mem.Allocator, limit: usize, refill_interval: u64, refill_amount: usize) Bucket {
pub fn refill() void {} return .{
.queue = std.PriorityQueue(RequestWithPrio, void, lessthan).init(allocator, {}),
.limit = limit,
.refill_interval = refill_interval,
.refill_amount = refill_amount,
};
}
fn remaining(self: *Bucket) usize {
if (self.limit < self.used) {
return 0;
} else {
return self.limit - self.used;
}
}
pub fn refill(self: *Bucket) std.Thread.SpawnError!void {
// Lower the used amount by the refill amount
self.used = if (self.refill_amount > self.used) 0 else self.used - self.refill_amount;
// Reset the refills_at timestamp since it just got refilled
self.refills_at = null;
if (self.used > 0) {
if (self.should_stop.load(.monotonic) == true) {
self.should_stop.store(false, .monotonic);
}
const thread = try std.Thread.spawn(.{}, Bucket.timeout, .{self});
thread.detach;
self.refills_at = std.time.milliTimestamp() + self.refill_interval;
}
}
fn timeout(self: *Bucket) void {
while (!self.should_stop.load(.monotonic)) {
self.refill();
std.time.sleep(std.time.ns_per_ms * self.refill_interval);
}
}
pub fn processQueue(self: *Bucket) std.Thread.SpawnError!void {
if (self.processing) return;
while (self.queue.remove()) |first_element| {
if (self.remaining() != 0) {
first_element.callback();
self.used += 1;
if (!self.should_stop.load(.monotonic)) {
const thread = try std.Thread.spawn(.{}, Bucket.timeout, .{self});
thread.detach;
self.refills_at = std.time.milliTimestamp() + self.refill_interval;
}
} else if (self.refills_at) |ra| {
const now = std.time.milliTimestamp();
if (ra > now) std.time.sleep(std.time.ns_per_ms * (ra - now));
}
}
self.processing = false;
}
pub fn acquire(self: *Bucket, rq: RequestWithPrio) !void { pub fn acquire(self: *Bucket, rq: RequestWithPrio) !void {
try self.queue.add(rq); try self.queue.add(rq);
@ -114,11 +191,6 @@ pub const Bucket = struct {
} }
}; };
pub const RequestWithPrio = struct {
callback: *const fn () void,
priority: u32,
};
pub fn GatewayDispatchEvent(comptime T: type) type { pub fn GatewayDispatchEvent(comptime T: type) type {
return struct { return struct {
// TODO: implement // application_command_permissions_update: null = null, // TODO: implement // application_command_permissions_update: null = null,

View File

@ -70,8 +70,6 @@ pub fn parseMessage(allocator: mem.Allocator, obj: *zmpl.Data.Object) !Discord.M
try mentions.append(try parseUser(allocator, &m.object)); try mentions.append(try parseUser(allocator, &m.object));
} }
std.debug.print("parsing mentions done\n", .{});
// parse member // parse member
const member = if (obj.getT(.object, "member")) |m| try parseMember(allocator, m) else null; const member = if (obj.getT(.object, "member")) |m| try parseMember(allocator, m) else null;

View File

@ -1,80 +0,0 @@
const Intents = @import("types.zig").Intents;
const GatewayBotInfo = @import("shared.zig").GatewayBotInfo;
const Shared = @import("shared.zig");
const IdentifyProperties = Shared.IdentifyProperties;
const Internal = @import("internal.zig");
const ConnectQueue = Internal.ConnectQueue;
const GatewayDispatchEvent = Internal.GatewayDispatchEvent;
const Shard = @import("shard.zig");
const std = @import("std");
const mem = std.mem;
const debug = Internal.debug;
const Self = @This();
token: []const u8,
intents: Intents,
allocator: mem.Allocator,
connectQueue: ConnectQueue,
shards: std.AutoArrayHashMap(usize, Shard),
run: GatewayDispatchEvent(*Self),
/// spawn buckets in order
/// https://discord.com/developers/docs/events/gateway#sharding-max-concurrency
fn spawnBuckers(self: *Self) !void {
_ = self;
}
/// creates a shard and stores it
fn create(self: *Self, shard_id: usize) !Shard {
const shard = try self.shards.getOrPutValue(shard_id, try Shard.login(self.allocator, .{
.token = self.token,
.intents = self.intents,
.run = self.run,
.log = self.log,
}));
return shard;
}
pub const ShardDetails = struct {
/// Bot token which is used to connect to Discord */
token: []const u8,
///
/// The URL of the gateway which should be connected to.
///
url: []const u8 = "wss://gateway.discord.gg",
///
/// The gateway version which should be used.
/// @default 10
///
version: ?usize = 10,
///
/// The calculated intent value of the events which the shard should receive.
///
intents: Intents,
///
/// Identify properties to use
///
properties: ?IdentifyProperties,
};
pub const SessionOptions = struct {
/// Important data which is used by the manager to connect shards to the gateway. */
info: GatewayBotInfo,
/// Delay in milliseconds to wait before spawning next shard. OPTIMAL IS ABOVE 5100. YOU DON'T WANT TO HIT THE RATE LIMIT!!!
spawnShardDelay: ?u64 = 5300,
/// Total amount of shards your bot uses. Useful for zero-downtime updates or resharding.
totalShards: ?usize = 1,
shardStart: ?usize,
shardEnd: ?usize,
///
/// The payload handlers for messages on the shard.
/// TODO:
/// handlePayload: (shardId: number, packet: GatewayDispatchPayload): unknown;
///
resharding: ?struct {
interval: u64,
percentage: usize,
},
};

View File

@ -1,6 +1,5 @@
const ws = @import("ws"); const ws = @import("ws");
const builtin = @import("builtin"); const builtin = @import("builtin");
const HttpClient = @import("tls12").HttpClient;
const std = @import("std"); const std = @import("std");
const net = std.net; const net = std.net;
@ -27,68 +26,21 @@ const IdentifyProperties = Shared.IdentifyProperties;
const GatewayInfo = Shared.GatewayInfo; const GatewayInfo = Shared.GatewayInfo;
const GatewayBotInfo = Shared.GatewayBotInfo; const GatewayBotInfo = Shared.GatewayBotInfo;
const GatewaySessionStartLimit = Shared.GatewaySessionStartLimit; const GatewaySessionStartLimit = Shared.GatewaySessionStartLimit;
const ShardDetails = Shared.ShardDetails;
const Internal = @import("internal.zig"); const Internal = @import("internal.zig");
const Log = Internal.Log; const Log = Internal.Log;
const GatewayDispatchEvent = Internal.GatewayDispatchEvent; const GatewayDispatchEvent = Internal.GatewayDispatchEvent;
const Bucket = Internal.Bucket;
const default_identify_properties = Internal.default_identify_properties;
const FetchRequest = @import("http.zig").FetchReq;
const ShardSocketCloseCodes = enum(u16) { const ShardSocketCloseCodes = enum(u16) {
Shutdown = 3000, Shutdown = 3000,
ZombiedConnection = 3010, ZombiedConnection = 3010,
}; };
const BASE_URL = "https://discord.com/api/v10";
pub const FetchReq = struct {
allocator: mem.Allocator,
token: []const u8,
client: HttpClient,
body: std.ArrayList(u8),
pub fn init(allocator: mem.Allocator, token: []const u8) FetchReq {
const client = HttpClient{ .allocator = allocator };
return FetchReq{
.allocator = allocator,
.client = client,
.body = std.ArrayList(u8).init(allocator),
.token = token,
};
}
pub fn deinit(self: *FetchReq) void {
self.client.deinit();
self.body.deinit();
}
pub fn makeRequest(self: *FetchReq, method: http.Method, path: []const u8, to_post: ?[]const u8) !HttpClient.FetchResult {
var fetch_options = HttpClient.FetchOptions{
.location = HttpClient.FetchOptions.Location{
.url = try std.fmt.allocPrint(self.allocator, "{s}{s}", .{ BASE_URL, path }),
},
.extra_headers = &[_]http.Header{
http.Header{ .name = "Accept", .value = "application/json" },
http.Header{ .name = "Content-Type", .value = "application/json" },
http.Header{ .name = "Authorization", .value = self.token },
},
.method = method,
.response_storage = .{ .dynamic = &self.body },
};
if (to_post != null) {
fetch_options.payload = to_post;
}
const res = try self.client.fetch(fetch_options);
return res;
}
};
const _default_properties = IdentifyProperties{
.os = @tagName(builtin.os.tag),
.browser = "discord.zig",
.device = "discord.zig",
};
const Heart = struct { const Heart = struct {
heartbeatInterval: u64, heartbeatInterval: u64,
ack: bool, ack: bool,
@ -96,16 +48,27 @@ const Heart = struct {
lastBeat: i64, lastBeat: i64,
}; };
const RatelimitOptions = struct {
max_requests_per_ratelimit_tick: ?usize = 120,
ratelimit_reset_interval: u64 = 60000,
};
pub const ShardOptions = struct {
ratelimit_options: RatelimitOptions = .{},
};
id: usize,
client: ws.Client, client: ws.Client,
token: []const u8, details: ShardDetails,
intents: Intents,
//heart: Heart = //heart: Heart =
allocator: mem.Allocator, allocator: mem.Allocator,
resume_gateway_url: ?[]const u8 = null, resume_gateway_url: ?[]const u8 = null,
info: GatewayBotInfo, info: GatewayBotInfo,
bucket: Bucket,
ratelimit_options: RatelimitOptions,
properties: IdentifyProperties = _default_properties,
session_id: ?[]const u8, session_id: ?[]const u8,
sequence: std.atomic.Value(isize) = std.atomic.Value(isize).init(0), sequence: std.atomic.Value(isize) = std.atomic.Value(isize).init(0),
heart: Heart = .{ .heartbeatInterval = 45000, .ack = false, .lastBeat = 0 }, heart: Heart = .{ .heartbeatInterval = 45000, .ack = false, .lastBeat = 0 },
@ -135,86 +98,87 @@ pub fn resumable(self: *Self) bool {
pub fn resume_(self: *Self) !void { pub fn resume_(self: *Self) !void {
const data = .{ .op = @intFromEnum(Opcode.Resume), .d = .{ const data = .{ .op = @intFromEnum(Opcode.Resume), .d = .{
.token = self.token, .token = self.details.token,
.session_id = self.session_id, .session_id = self.session_id,
.seq = self.sequence.load(.monotonic), .seq = self.sequence.load(.monotonic),
} }; } };
try self.send(data); try self.send(false, data);
} }
inline fn gatewayUrl(self: ?*Self) []const u8 { inline fn gatewayUrl(self: ?*Self) []const u8 {
return if (self) |s| (s.resume_gateway_url orelse s.info.url)["wss://".len..] else "gateway.discord.gg"; return if (self) |s| (s.resume_gateway_url orelse s.info.url)["wss://".len..] else "gateway.discord.gg";
} }
// identifies in order to connect to Discord and get the online status, this shall be done on hello perhaps /// identifies in order to connect to Discord and get the online status, this shall be done on hello perhaps
fn identify(self: *Self, properties: ?IdentifyProperties) !void { fn identify(self: *Self, properties: ?IdentifyProperties) !void {
self.logif("intents: {d}", .{self.intents.toRaw()}); self.logif("intents: {d}", .{self.details.intents.toRaw()});
if (self.intents.toRaw() != 0) { if (self.details.intents.toRaw() != 0) {
const data = .{ const data = .{
.op = @intFromEnum(Opcode.Identify), .op = @intFromEnum(Opcode.Identify),
.d = .{ .d = .{
.intents = self.intents.toRaw(), .intents = self.details.intents.toRaw(),
.properties = properties orelse Self._default_properties, .properties = properties orelse default_identify_properties,
.token = self.token, .token = self.details.token,
}, },
}; };
try self.send(data); try self.send(false, data);
} else { } else {
const data = .{ const data = .{
.op = @intFromEnum(Opcode.Identify), .op = @intFromEnum(Opcode.Identify),
.d = .{ .d = .{
.capabilities = 30717, .capabilities = 30717,
.properties = properties orelse Self._default_properties, .properties = properties orelse default_identify_properties,
.token = self.token, .token = self.details.token,
}, },
}; };
try self.send(data); try self.send(false, data);
} }
} }
// asks /gateway/bot initializes both the ws client and the http client pub fn init(allocator: mem.Allocator, shard_id: usize, settings: struct {
pub fn login(allocator: mem.Allocator, args: struct {
token: []const u8, token: []const u8,
intents: Intents, intents: Intents,
options: ShardOptions,
run: GatewayDispatchEvent(*Self), run: GatewayDispatchEvent(*Self),
log: Log, log: Log,
}) !Self { }) zlib.Error!Self {
var req = FetchReq.init(allocator, args.token); return Self{
defer req.deinit(); .info = .{ .url = "wss://gateway.discord.gg", .shards = 1, .session_start_limit = null },
.id = shard_id,
const res = try req.makeRequest(.GET, "/gateway/bot", null);
const body = try req.body.toOwnedSlice();
defer allocator.free(body);
// check status idk
if (res.status != http.Status.ok) {
@panic("we are cooked\n");
}
const parsed = try json.parseFromSlice(GatewayBotInfo, allocator, body, .{});
defer parsed.deinit();
const url = parsed.value.url["wss://".len..];
var self: Self = .{
.allocator = allocator, .allocator = allocator,
.token = args.token, .details = ShardDetails{
.intents = args.intents, .token = settings.token,
.intents = settings.intents,
},
.client = undefined,
// maybe there is a better way to do this // maybe there is a better way to do this
.client = try Self._connect_ws(allocator, url),
.session_id = undefined, .session_id = undefined,
.info = parsed.value, .handler = settings.run,
.handler = args.run, .log = settings.log,
.log = args.log,
.packets = std.ArrayList(u8).init(allocator), .packets = std.ArrayList(u8).init(allocator),
.inflator = try zlib.Decompressor.init(allocator, .{ .header = .zlib_or_gzip }), .inflator = try zlib.Decompressor.init(allocator, .{ .header = .zlib_or_gzip }),
.bucket = Bucket.init(
allocator,
Self.calculateSafeRequests(settings.options.ratelimit_options),
settings.options.ratelimit_options.ratelimit_reset_interval,
Self.calculateSafeRequests(settings.options.ratelimit_options),
),
.ratelimit_options = settings.options.ratelimit_options,
}; };
}
const event_listener = try std.Thread.spawn(.{}, Self.readMessage, .{ &self, null }); inline fn calculateSafeRequests(options: RatelimitOptions) usize {
event_listener.join(); const safe_requests =
@as(f64, @floatFromInt(options.max_requests_per_ratelimit_tick orelse 120)) -
@ceil(@as(f64, @floatFromInt(options.ratelimit_reset_interval)) / 30000.0) * 2;
return self; if (safe_requests < 0) {
return 0;
}
return @intFromFloat(safe_requests);
} }
inline fn _connect_ws(allocator: mem.Allocator, url: []const u8) !ws.Client { inline fn _connect_ws(allocator: mem.Allocator, url: []const u8) !ws.Client {
@ -241,8 +205,8 @@ pub fn deinit(self: *Self) void {
self.logif("killing the whole bot", .{}); self.logif("killing the whole bot", .{});
} }
// listens for messages /// listens for messages
pub fn readMessage(self: *Self, _: anytype) !void { fn readMessage(self: *Self, _: anytype) !void {
try self.client.readTimeout(0); try self.client.readTimeout(0);
while (true) { while (true) {
@ -301,7 +265,7 @@ pub fn readMessage(self: *Self, _: anytype) !void {
try self.resume_(); try self.resume_();
return; return;
} else { } else {
try self.identify(self.properties); try self.identify(self.details.properties);
} }
var prng = std.Random.DefaultPrng.init(0); var prng = std.Random.DefaultPrng.init(0);
@ -321,7 +285,7 @@ pub fn readMessage(self: *Self, _: anytype) !void {
self.logif("sending requested heartbeat", .{}); self.logif("sending requested heartbeat", .{});
self.ws_mutex.lock(); self.ws_mutex.lock();
defer self.ws_mutex.unlock(); defer self.ws_mutex.unlock();
try self.send(.{ .op = @intFromEnum(Opcode.Heartbeat), .d = self.sequence.load(.monotonic) }); try self.send(false, .{ .op = @intFromEnum(Opcode.Heartbeat), .d = self.sequence.load(.monotonic) });
}, },
Opcode.Reconnect => { Opcode.Reconnect => {
self.logif("reconnecting", .{}); self.logif("reconnecting", .{});
@ -373,7 +337,7 @@ pub fn heartbeat(self: *Self, initial_jitter: f64) !void {
const seq = self.sequence.load(.monotonic); const seq = self.sequence.load(.monotonic);
self.logif("sending unrequested heartbeat", .{}); self.logif("sending unrequested heartbeat", .{});
self.ws_mutex.lock(); self.ws_mutex.lock();
try self.send(.{ .op = @intFromEnum(Opcode.Heartbeat), .d = seq }); try self.send(false, .{ .op = @intFromEnum(Opcode.Heartbeat), .d = seq });
self.ws_mutex.unlock(); self.ws_mutex.unlock();
if ((std.time.milliTimestamp() - last) > (5000 * self.heart.heartbeatInterval)) { if ((std.time.milliTimestamp() - last) > (5000 * self.heart.heartbeatInterval)) {
@ -390,16 +354,25 @@ pub inline fn reconnect(self: *Self) !void {
try self.connect(); try self.connect();
} }
pub fn connect(self: *Self) !void { pub const ConnectError =
std.net.TcpConnectToAddressError || std.crypto.tls.Client.InitError(std.net.Stream) || std.net.Stream.ReadError || std.net.IPParseError || std.crypto.Certificate.Bundle.RescanError || std.net.TcpConnectToHostError || std.fmt.BufPrintError || mem.Allocator.Error;
pub fn connect(self: *Self) ConnectError!void {
//std.time.sleep(std.time.ms_per_s * 5); //std.time.sleep(std.time.ms_per_s * 5);
self.client = try Self._connect_ws(self.allocator, self.gatewayUrl()); self.client = try Self._connect_ws(self.allocator, self.gatewayUrl());
//const event_listener = try std.Thread.spawn(.{}, Self.readMessage, .{ &self, null });
//event_listener.join();
self.readMessage(null) catch unreachable;
} }
pub fn disconnect(self: *Self) !void { pub fn disconnect(self: *Self) !void {
try self.close(ShardSocketCloseCodes.Shutdown, "Shard down request"); try self.close(ShardSocketCloseCodes.Shutdown, "Shard down request");
} }
pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) !void { pub const CloseError = mem.Allocator.Error || error{ReasonTooLong};
pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) CloseError!void {
self.logif("cooked closing ws conn...\n", .{}); self.logif("cooked closing ws conn...\n", .{});
// Implement reconnection logic here // Implement reconnection logic here
try self.client.close(.{ try self.client.close(.{
@ -408,7 +381,9 @@ pub fn close(self: *Self, code: ShardSocketCloseCodes, reason: []const u8) !void
}); });
} }
pub fn send(self: *Self, data: anytype) !void { pub const SendError = net.Stream.WriteError || std.ArrayList(u8).Writer.Error;
pub fn send(self: *Self, _: bool, data: anytype) SendError!void {
var buf: [1000]u8 = undefined; var buf: [1000]u8 = undefined;
var fba = std.heap.FixedBufferAllocator.init(&buf); var fba = std.heap.FixedBufferAllocator.init(&buf);
var string = std.ArrayList(u8).init(fba.allocator()); var string = std.ArrayList(u8).init(fba.allocator());
@ -546,6 +521,7 @@ pub fn handleEvent(self: *Self, name: []const u8, payload: []const u8) !void {
} }
} }
/// highly experimental, do not use
pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []const u8, password: []const u8, run: GatewayDispatchEvent(*Self), log: Log }) !Self { pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []const u8, password: []const u8, run: GatewayDispatchEvent(*Self), log: Log }) !Self {
const AUTH_LOGIN = "https://discord.com/api/v9/auth/login"; const AUTH_LOGIN = "https://discord.com/api/v9/auth/login";
const WS_CONNECT = "gateway.discord.gg"; const WS_CONNECT = "gateway.discord.gg";
@ -555,8 +531,8 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons
const AuthLoginResponse = struct { user_id: []const u8, token: []const u8, user_settings: struct { locale: []const u8, theme: []const u8 } }; const AuthLoginResponse = struct { user_id: []const u8, token: []const u8, user_settings: struct { locale: []const u8, theme: []const u8 } };
var fetch_options = HttpClient.FetchOptions{ var fetch_options = http.Client.FetchOptions{
.location = HttpClient.FetchOptions.Location{ .location = http.Client.FetchOptions.Location{
.url = AUTH_LOGIN, .url = AUTH_LOGIN,
}, },
.extra_headers = &[_]http.Header{ .extra_headers = &[_]http.Header{
@ -572,7 +548,7 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons
.password = settings.password, .password = settings.password,
}, .{}); }, .{});
var client = HttpClient{ .allocator = allocator }; var client = http.Client{ .allocator = allocator };
defer client.deinit(); defer client.deinit();
_ = try client.fetch(fetch_options); _ = try client.fetch(fetch_options);
@ -581,8 +557,10 @@ pub fn loginWithEmail(allocator: mem.Allocator, settings: struct { login: []cons
return .{ return .{
.allocator = allocator, .allocator = allocator,
.token = response.token, .details = ShardDetails{
.intents = @bitCast(@as(u28, @intCast(0))), .token = response.token,
.intents = @bitCast(@as(u28, @intCast(0))),
},
// maybe there is a better way to do this // maybe there is a better way to do this
.client = try Self._connect_ws(allocator, WS_CONNECT), .client = try Self._connect_ws(allocator, WS_CONNECT),
.session_id = undefined, .session_id = undefined,

227
src/sharder.zig Normal file
View File

@ -0,0 +1,227 @@
const Intents = @import("types.zig").Intents;
const GatewayBotInfo = @import("shared.zig").GatewayBotInfo;
const Shared = @import("shared.zig");
const IdentifyProperties = Shared.IdentifyProperties;
const ShardDetails = Shared.ShardDetails;
const Internal = @import("internal.zig");
const ConnectQueue = Internal.ConnectQueue;
const GatewayDispatchEvent = Internal.GatewayDispatchEvent;
const Log = @import("internal.zig").Log;
const Shard = @import("shard.zig");
const std = @import("std");
const mem = std.mem;
const debug = Internal.debug;
pub const discord_epoch = 1420070400000;
/// Calculate and return the shard ID for a given guild ID
pub inline fn calculateShardId(guildId: u64, shards: ?usize) u64 {
return (guildId >> 22) % shards orelse 1;
}
/// Convert a timestamp to a snowflake.
pub inline fn snowflakeToTimestamp(id: u64) u64 {
return (id >> 22) + discord_epoch;
}
const Self = @This();
shard_details: ShardDetails,
allocator: mem.Allocator,
/// Queue for managing shard connections
connect_queue: ConnectQueue(Shard),
shards: std.AutoArrayHashMap(usize, Shard),
handler: GatewayDispatchEvent(*Shard),
/// configuration settings
options: SessionOptions,
log: Log,
pub const ShardData = struct {
/// resume seq to resume connections
resume_seq: ?usize,
/// resume_gateway_url is the url to resume the connection
/// https://discord.com/developers/docs/topics/gateway#ready-event
resume_gateway_url: ?[]const u8,
/// session_id is the unique session id of the gateway
session_id: ?[]const u8,
};
pub const SessionOptions = struct {
/// Important data which is used by the manager to connect shards to the gateway. */
info: GatewayBotInfo,
/// Delay in milliseconds to wait before spawning next shard. OPTIMAL IS ABOVE 5100. YOU DON'T WANT TO HIT THE RATE LIMIT!!!
spawn_shard_delay: ?u64 = 5300,
/// Total amount of shards your bot uses. Useful for zero-downtime updates or resharding.
total_shards: usize = 1,
shard_start: usize = 0,
shard_end: usize = 1,
/// The payload handlers for messages on the shard.
resharding: ?struct { interval: u64, percentage: usize } = null,
};
pub fn init(allocator: mem.Allocator, settings: struct {
token: []const u8,
intents: Intents,
options: SessionOptions,
run: GatewayDispatchEvent(*Shard),
log: Log,
}) mem.Allocator.Error!Self {
const concurrency = settings.options.info.session_start_limit.?.max_concurrency;
return .{
.allocator = allocator,
.connect_queue = try ConnectQueue(Shard).init(allocator, concurrency, 5000),
.shards = .init(allocator),
.shard_details = ShardDetails{
.token = settings.token,
.intents = settings.intents,
},
.handler = settings.run,
.options = .{
.info = .{
.url = settings.options.info.url,
.shards = settings.options.info.shards,
.session_start_limit = settings.options.info.session_start_limit,
},
.total_shards = settings.options.total_shards,
.shard_start = settings.options.shard_start,
.shard_end = settings.options.shard_end,
},
.log = settings.log,
};
}
pub fn deinit(self: *Self) void {
self.connect_queue.deinit();
self.shards.deinit();
}
pub fn forceIdentify(self: *Self, shard_id: usize) !void {
self.logif("#{d} force identify", .{shard_id});
const shard = try self.create(shard_id);
return shard.identify(null);
}
pub fn disconnect(self: *Self, shard_id: usize) !void {
return if (self.shards.get(shard_id)) |shard| shard.disconnect();
}
pub fn disconnectAll(self: *Self) !void {
while (self.shards.iterator().next()) |shard| shard.value_ptr.disconnect();
}
/// spawn buckets in order
/// Log bucket preparation
/// Divide shards into chunks based on concurrency
/// Assign each shard to a bucket
/// Return list of buckets
/// https://discord.com/developers/docs/events/gateway#sharding-max-concurrency
fn spawnBuckets(self: *Self) ![][]Shard {
const concurrency = self.options.info.session_start_limit.?.max_concurrency;
self.logif("{d}-{d}", .{ self.options.shard_start, self.options.shard_end });
const range = std.math.sub(usize, self.options.shard_start, self.options.shard_end) catch 1;
const bucket_count = (range + concurrency - 1) / concurrency;
self.logif("#0 preparing buckets", .{});
const buckets = try self.allocator.alloc([]Shard, bucket_count);
for (buckets, 0..) |*bucket, i| {
const bucket_size = if ((i + 1) * concurrency > range) range - (i * concurrency) else concurrency;
bucket.* = try self.allocator.alloc(Shard, bucket_size);
for (bucket.*, 0..) |*shard, j| {
shard.* = try self.create(self.options.shard_start + i * concurrency + j);
}
}
self.logif("{d} buckets created", .{bucket_count});
return buckets;
}
/// creates a shard and stores it
fn create(self: *Self, shard_id: usize) !Shard {
if (self.shards.get(shard_id)) |s| return s;
const shard: Shard = try Shard.init(self.allocator, shard_id, .{
.token = self.shard_details.token,
.intents = self.shard_details.intents,
.options = Shard.ShardOptions{},
.run = self.handler,
.log = self.log,
});
try self.shards.put(shard_id, shard);
return shard;
}
pub fn resume_(self: *Self, shard_id: usize, shard_data: ShardData) void {
if (self.shards.contains(shard_id)) return error.CannotOverrideExistingShard;
const shard = self.create(shard_id);
shard.data = shard_data;
return self.connect_queue.push(.{
.shard = shard,
.callback = &callback,
});
}
fn callback(self: *ConnectQueue(Shard).RequestWithShard) anyerror!void {
try self.shard.connect();
}
pub fn spawnShards(self: *Self) !void {
const buckets = try self.spawnBuckets();
self.logif("Spawning shards", .{});
for (buckets) |bucket| {
for (bucket) |shard| {
self.logif("adding {d} to connect queue", .{shard.id});
try self.connect_queue.push(.{
.shard = shard,
.callback = &callback,
});
}
}
//self.startResharder();
}
pub fn send(self: *Self, shard_id: usize, data: anytype) !void {
if (self.shards.get(shard_id)) |shard| {
try shard.send(data);
}
}
// SPEC OF THE RESHARDER:
// Class Self
//
// Method startResharder():
// If resharding interval is not set or shard bounds are not valid:
// Exit
// Set up periodic check for resharding:
// If new shards are required:
// Log resharding process
// Update options with new shard settings
// Disconnect old shards and clear them from manager
// Spawn shards again with updated configuration
//
inline fn logif(self: *Self, comptime format: []const u8, args: anytype) void {
switch (self.log) {
.yes => Internal.debug.info(format, args),
.no => {},
}
}

View File

@ -1,17 +1,13 @@
const Intents = @import("types.zig").Intents;
const default_identify_properties = @import("internal.zig").default_identify_properties;
const std = @import("std"); const std = @import("std");
pub const IdentifyProperties = struct { pub const IdentifyProperties = struct {
///
/// Operating system the shard runs on. /// Operating system the shard runs on.
///
os: []const u8, os: []const u8,
///
/// The "browser" where this shard is running on. /// The "browser" where this shard is running on.
///
browser: []const u8, browser: []const u8,
///
/// The device on which the shard is running. /// The device on which the shard is running.
///
device: []const u8, device: []const u8,
system_locale: ?[]const u8 = null, // TODO parse this system_locale: ?[]const u8 = null, // TODO parse this
@ -48,20 +44,29 @@ pub const GatewaySessionStartLimit = struct {
/// https://discord.com/developers/docs/topics/gateway#get-gateway-bot /// https://discord.com/developers/docs/topics/gateway#get-gateway-bot
pub const GatewayBotInfo = struct { pub const GatewayBotInfo = struct {
url: []const u8, url: []const u8,
///
/// The recommended number of shards to use when connecting /// The recommended number of shards to use when connecting
/// ///
/// See https://discord.com/developers/docs/topics/gateway#sharding /// See https://discord.com/developers/docs/topics/gateway#sharding
///
shards: u32, shards: u32,
///
/// Information on the current session start limit /// Information on the current session start limit
/// ///
/// See https://discord.com/developers/docs/topics/gateway#session-start-limit-object /// See https://discord.com/developers/docs/topics/gateway#session-start-limit-object
///
session_start_limit: ?GatewaySessionStartLimit, session_start_limit: ?GatewaySessionStartLimit,
}; };
pub const ShardDetails = struct {
/// Bot token which is used to connect to Discord */
token: []const u8,
/// The URL of the gateway which should be connected to.
url: []const u8 = "wss://gateway.discord.gg",
/// The gateway version which should be used.
version: ?usize = 10,
/// The calculated intent value of the events which the shard should receive.
intents: Intents,
/// Identify properties to use
properties: IdentifyProperties = default_identify_properties,
};
pub const Snowflake = struct { pub const Snowflake = struct {
id: u64, id: u64,

View File

@ -1,6 +1,8 @@
const Client = @import("discord.zig").Client;
const Shard = @import("discord.zig").Shard; const Shard = @import("discord.zig").Shard;
const Discord = @import("discord.zig").Discord; const Discord = @import("discord.zig").Discord;
const Internal = @import("discord.zig").Internal; const Internal = @import("discord.zig").Internal;
const FetchReq = @import("discord.zig").FetchReq;
const Intents = Discord.Intents; const Intents = Discord.Intents;
const Thread = std.Thread; const Thread = std.Thread;
const std = @import("std"); const std = @import("std");
@ -13,7 +15,7 @@ fn message_create(session: *Shard, message: Discord.Message) void {
std.debug.print("captured: {?s} send by {s}\n", .{ message.content, message.author.username }); std.debug.print("captured: {?s} send by {s}\n", .{ message.content, message.author.username });
if (message.content) |mc| if (std.ascii.eqlIgnoreCase(mc, "!hi")) { if (message.content) |mc| if (std.ascii.eqlIgnoreCase(mc, "!hi")) {
var req = Shard.FetchReq.init(session.allocator, session.token); var req = FetchReq.init(session.allocator, session.details.token);
defer req.deinit(); defer req.deinit();
const payload: Discord.Partial(Discord.CreateMessage) = .{ .content = "Hi, I'm hang man, your personal assistant" }; const payload: Discord.Partial(Discord.CreateMessage) = .{ .content = "Hi, I'm hang man, your personal assistant" };
@ -28,14 +30,13 @@ fn message_create(session: *Shard, message: Discord.Message) void {
pub fn main() !void { pub fn main() !void {
var tsa = std.heap.ThreadSafeAllocator{ .child_allocator = std.heap.c_allocator }; var tsa = std.heap.ThreadSafeAllocator{ .child_allocator = std.heap.c_allocator };
var handler = try Shard.login(tsa.allocator(), .{ var handler = Client.init(tsa.allocator());
try handler.start(.{
.token = std.posix.getenv("TOKEN") orelse unreachable, .token = std.posix.getenv("TOKEN") orelse unreachable,
.intents = Intents.fromRaw(37379), .intents = Intents.fromRaw(37379),
.run = Internal.GatewayDispatchEvent(*Shard){ .run = .{ .message_create = &message_create, .ready = &ready },
.message_create = &message_create,
.ready = &ready,
},
.log = .yes, .log = .yes,
.options = .{},
}); });
errdefer handler.deinit(); errdefer handler.deinit();
} }