commit 9bc0183800e12777113fd39374a161ae06f8c588 Author: rainfall Date: Wed Oct 30 01:15:44 2024 -0500 the king is back diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..acb70f3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +lib/zlib +.zig-cache diff --git a/build.zig b/build.zig new file mode 100644 index 0000000..76ec6a9 --- /dev/null +++ b/build.zig @@ -0,0 +1,96 @@ +const std = @import("std"); + +// Although this function looks imperative, note that its job is to +// declaratively construct a build graph that will be executed by an external +// runner. +pub fn build(b: *std.Build) void { + // these are boiler plate code until you know what you are doing + // and you need to add additional options + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + // this is your own program + const exe = b.addExecutable(.{ + // the name of your project + .name = "oculus-2", + // your main function + .root_source_file = b.path("src/main.zig"), + // references the ones you declared above + .target = target, + .optimize = optimize, + .link_libc = true, + }); + + const test_comp = b.addTest(.{ + .root_source_file = b.path("src/test.zig"), + .target = target, + .optimize = optimize, + }); + + const websocket = b.createModule(.{ + .root_source_file = b.path("lib/websocket.zig/src/websocket.zig"), + .target = target, + .optimize = optimize, + .link_libc = true, + }); + + const zig_tls_http = b.createModule(.{ + .root_source_file = b.path("lib/zig-tls12/src/HttpClient.zig"), + .target = target, + .optimize = optimize, + }); + + const zlib_zig = b.createModule(.{ + //.name = "zlib", + .target = target, + .optimize = optimize, + .root_source_file = b.path("zlib.zig"), + .link_libc = true, + }); + + const srcs = &.{ + "lib/zlib/adler32.c", + "lib/zlib/compress.c", + "lib/zlib/crc32.c", + "lib/zlib/deflate.c", + "lib/zlib/gzclose.c", + "lib/zlib/gzlib.c", + "lib/zlib/gzread.c", + "lib/zlib/gzwrite.c", + "lib/zlib/inflate.c", + "lib/zlib/infback.c", + "lib/zlib/inftrees.c", + "lib/zlib/inffast.c", + "lib/zlib/trees.c", + "lib/zlib/uncompr.c", + "lib/zlib/zutil.c", + }; + + zlib_zig.addCSourceFiles(.{ .files = srcs, .flags = &.{"-std=c89"} }); + zlib_zig.addIncludePath(b.path("lib/zlib/")); + + websocket.addImport("zlib", zlib_zig); + websocket.addImport("tls12", zig_tls_http); + + // now install your own executable after it's built correctly + + exe.root_module.addImport("ws", websocket); + exe.root_module.addImport("tls12", zig_tls_http); + exe.root_module.addImport("zlib", zlib_zig); + + // test + test_comp.root_module.addImport("ws", websocket); + test_comp.root_module.addImport("tls12", zig_tls_http); + test_comp.root_module.addImport("zlib", zlib_zig); + + const run_test_comp = b.addRunArtifact(test_comp); + const test_step = b.step("test", "Run unit tests"); + test_step.dependOn(&test_comp.step); + test_step.dependOn(&run_test_comp.step); + + const run_cmd = b.addRunArtifact(exe); + run_cmd.step.dependOn(b.getInstallStep()); + + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_cmd.step); +} diff --git a/install-zlib.sh b/install-zlib.sh new file mode 100644 index 0000000..909ff74 --- /dev/null +++ b/install-zlib.sh @@ -0,0 +1,4 @@ +wget https://github.com/madler/zlib/releases/download/v1.2.13/zlib-1.2.13.tar.gz +tar xvf zlib-1.2.13.tar.gz +rm zlib-1.2.13.tar.gz +mv zlib-1.2.13 zlib diff --git a/lib/websocket.zig b/lib/websocket.zig new file mode 160000 index 0000000..d8561ca --- /dev/null +++ b/lib/websocket.zig @@ 
-0,0 +1 @@ +Subproject commit d8561ca98eca4d904ac9383a7f30b2360bed4d3c diff --git a/lib/zig-tls12 b/lib/zig-tls12 new file mode 160000 index 0000000..f2cbb84 --- /dev/null +++ b/lib/zig-tls12 @@ -0,0 +1 @@ +Subproject commit f2cbb846f8a98cb5e19c8476a8e6cf3b9bbcdb0c diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..8374afe --- /dev/null +++ b/src/main.zig @@ -0,0 +1,474 @@ +const std = @import("std"); +const json = std.json; +const mem = std.mem; +const http = std.http; +const ws = @import("ws"); +const builtin = @import("builtin"); +const HttpClient = @import("tls12"); +const net = std.net; +const crypto = std.crypto; +const tls = std.crypto.tls; +//const TlsClient = @import("tls12").TlsClient; +//const Certificate = @import("tls12").Certificate; +// todo use this to read compressed messages +const zlib = @import("zlib"); + +const Opcode = enum(u4) { + Dispatch = 0, + Heartbeat = 1, + Identify = 2, + PresenceUpdate = 3, + VoiceStateUpdate = 4, + Resume = 6, + Reconnect = 7, + RequestGuildMember = 8, + InvalidSession = 9, + Hello = 10, + HeartbeatACK = 11, +}; + +const ShardSocketCloseCodes = enum(u16) { + Shutdown = 3000, + ZombiedConnection = 3010, +}; + +const BASE_URL = "https://discord.com/api/v10"; + +pub const Intents = packed struct { + guilds: bool = false, + guild_members: bool = false, + guild_bans: bool = false, + guild_emojis: bool = false, + guild_integrations: bool = false, + guild_webhooks: bool = false, + guild_invites: bool = false, + guild_voice_states: bool = false, + + guild_presences: bool = false, + guild_messages: bool = false, + guild_message_reactions: bool = false, + guild_message_typing: bool = false, + direct_messages: bool = false, + direct_message_reactions: bool = false, + direct_message_typing: bool = false, + message_content: bool = false, + + guild_scheduled_events: bool = false, + _pad: u3 = 0, + auto_moderation_configuration: bool = false, + auto_moderation_execution: bool = false, + _pad2: u2 = 0, + + _pad3: u8 = 0, + + pub fn toRaw(self: Intents) u32 { + return @as(u32, @bitCast(self)); + } + + pub fn fromRaw(raw: u32) Intents { + return @as(Intents, @bitCast(raw)); + } + + pub fn jsonStringify(self: Intents, options: std.json.StringifyOptions, writer: anytype) !void { + _ = options; + try writer.print("{}", .{self.toRaw()}); + } +}; + +pub fn main() !void { + const TOKEN = "Bot MTI5ODgzOTgzMDY3OTEzMDE4OA.GNojts.iyblGKK0xTWU57QCG5n3hr2Be1whyylTGr44P0"; + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer if (gpa.deinit() == .leak) { + std.log.warn("Has leaked\n", .{}); + }; + const alloc = gpa.allocator(); + + var handler = try Handler.init(alloc, .{ .token = TOKEN, .intents = Intents.fromRaw(513) }); + //errdefer handler.deinit(); + + try handler.readMessage(); + //try handler.identify(); +} + +const HeartbeatHandler = struct { + gateway: *Handler, + heartbeatInterval: u64, + lastHeartbeatAck: bool, + /// useful for calculating ping + lastBeat: u64 = 0, + + pub fn init(gateway: *Handler, interval: u64) !HeartbeatHandler { + var gateway_mut = gateway.*; + return .{ + .gateway = &gateway_mut, + .heartbeatInterval = interval, + .lastHeartbeatAck = false, + .lastBeat = 0, + }; + } + + pub fn deinit(self: *HeartbeatHandler) void { + _ = self; + } + + pub fn run(self: *HeartbeatHandler) !void { + while (true) { + std.time.sleep(self.heartbeatInterval); + try self.gateway.heartbeat(false); + } + } + + pub fn loop(self: *HeartbeatHandler) !void { + _ = self; + std.debug.print("start loop\n", .{}); + //var self_mut = self.*; 
+ //self.thread = try std.Thread.spawn(.{}, HeartbeatHandler.run, .{&self_mut}); + } +}; + +const FetchReq = struct { + allocator: mem.Allocator, + token: []const u8, + client: HttpClient, + body: std.ArrayList(u8), + + pub fn init(allocator: mem.Allocator, token: []const u8) FetchReq { + const client = HttpClient{ .allocator = allocator }; + return FetchReq{ + .allocator = allocator, + .client = client, + .body = std.ArrayList(u8).init(allocator), + .token = token, + }; + } + + pub fn deinit(self: *FetchReq) void { + self.client.deinit(); + self.body.deinit(); + } + + pub fn makeRequest(self: *FetchReq, method: http.Method, path: []const u8, body: ?[]const u8) !HttpClient.FetchResult { + var fetch_options = HttpClient.FetchOptions{ + .location = HttpClient.FetchOptions.Location{ + .url = path, + }, + .extra_headers = &[_]http.Header{ + http.Header{ .name = "Accept", .value = "application/json" }, + http.Header{ .name = "Content-Type", .value = "application/json" }, + http.Header{ .name = "Authorization", .value = self.token }, + }, + .method = method, + .response_storage = .{ .dynamic = &self.body }, + }; + + if (body != null) { + fetch_options.payload = body; + } + + const res = try self.client.fetch(fetch_options); + return res; + } +}; + +pub const Handler = struct { + /// + /// https://discord.com/developers/docs/topics/gateway#get-gateway + /// + const GatewayInfo = struct { + /// The WSS URL that can be used for connecting to the gateway + url: []const u8, + }; + + /// + /// https://discord.com/developers/docs/events/gateway#session-start-limit-object + /// + const GatewaySessionStartLimit = struct { + /// Total number of session starts the current user is allowed + total: u32, + /// Remaining number of session starts the current user is allowed + remaining: u32, + /// Number of milliseconds after which the limit resets + reset_after: u32, + /// Number of identify requests allowed per 5 seconds + max_concurrency: u32, + }; + + /// + /// https://discord.com/developers/docs/topics/gateway#get-gateway-bot + /// + const GatewayBotInfo = struct { + url: []const u8, + /// + /// The recommended number of shards to use when connecting + /// + /// See https://discord.com/developers/docs/topics/gateway#sharding + /// + shards: u32, + /// + /// Information on the current session start limit + /// + /// See https://discord.com/developers/docs/topics/gateway#session-start-limit-object + /// + session_start_limit: ?GatewaySessionStartLimit, + }; + + const IdentifyProperties = struct { + /// + /// Operating system the shard runs on. + /// + os: []const u8, + /// + /// The "browser" where this shard is running on. + /// + browser: []const u8, + /// + /// The device on which the shard is running. + /// + device: []const u8, + }; + + const _default_properties = IdentifyProperties{ + .os = @tagName(builtin.os.tag), + .browser = "seyfert", + .device = "seyfert", + }; + + client: ws.Client, + token: []const u8, + intents: Intents, + session_id: ?[]const u8, + sequence: ?u16, + heartbeater: ?HeartbeatHandler, + allocator: mem.Allocator, + resume_gateway_url: ?[]const u8 = null, + info: GatewayBotInfo, + + pub fn resume_conn(self: *Handler, out: anytype) !void { + const data = .{ .op = @intFromEnum(Opcode.Resume), .d = .{ + .token = self.token, + .session_id = self.session_id, + .seq = self.sequence, + } }; + + try json.stringify(data, .{}, out); + + try self.client.write(&out); + } + + inline fn gateway_url(self: ?*Handler) []const u8 { + // wtf is this? 
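+        // Fallback order: resume_gateway_url if it has been set, then the url
+        // returned by /gateway/bot (info.url), then Discord's public gateway endpoint.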
+ if (self) |s| { + return s.resume_gateway_url orelse s.info.url; + } + + return "wss://gateway.discord.gg"; + } + + // identifies in order to connect to Discord and get the online status, this shall be done on hello perhaps + fn identify(self: *Handler) !void { + std.debug.print("identifying now...\n", .{}); + const data = .{ + .op = @intFromEnum(Opcode.Identify), + .d = .{ + //.compress = false, + .intents = self.intents.toRaw(), + .properties = Handler._default_properties, + .token = self.token, + }, + }; + + var buf: [1000]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&buf); + var string = std.ArrayList(u8).init(fba.allocator()); + try std.json.stringify(data, .{}, string.writer()); + + std.debug.print("{s}\n", .{string.items}); + // try posting our shitty data + try self.client.write(string.items); + } + + // asks /gateway/bot initializes both the ws client and the http client + pub fn init(allocator: mem.Allocator, args: struct { token: []const u8, intents: Intents }) !Handler { + var req = FetchReq.init(allocator, args.token); + defer req.deinit(); + + const res = try req.makeRequest(.GET, BASE_URL ++ "/gateway/bot", null); + const body = try req.body.toOwnedSlice(); + defer allocator.free(body); + + // check status idk + if (res.status != http.Status.ok) { + @panic("we are cooked"); + } + + const parsed = try json.parseFromSlice(GatewayBotInfo, allocator, body, .{}); + defer parsed.deinit(); + + return .{ + .allocator = allocator, + .token = args.token, + .intents = args.intents, + // maybe there is a better way to do this + .client = try Handler._connect_ws(allocator, parsed.value.url["wss://".len..]), + .session_id = undefined, + .sequence = undefined, + .heartbeater = null, + .info = parsed.value, + }; + } + + inline fn _connect_ws(allocator: mem.Allocator, url: []const u8) !ws.Client { + var conn = try ws.Client.init(allocator, .{ + .tls = true, // important: zig.http doesn't support this, type shit + .port = 443, + .host = url, + //.ca_bundle = @import("tls12").Certificate.Bundle{}, + }); + + try conn.handshake("/?v=10&encoding=json", .{ + .timeout_ms = 1000, + .headers = "host: gateway.discord.gg", + }); + + return conn; + } + + pub fn deinit(self: *Handler) void { + defer self.client.deinit(); + if (self.heartbeater) |*hb| { + hb.deinit(); + } + } + + // listens for messages + pub fn readMessage(self: *Handler) !void { + try self.client.readTimeout(std.time.ms_per_s * 1); + while (true) { + const msg = (try self.client.read()) orelse { + // no message after our 1 second + std.debug.print(".", .{}); + continue; + }; + + // must be called once you're done processing the request + defer self.client.done(msg); + + std.debug.print("received: {s}\n", .{msg.data}); + + const DiscordData = struct { + s: ?u16, //well figure it out + op: Opcode, + d: json.Value, // needs parsing + t: ?[]const u8, + }; + + const raw = try json.parseFromSlice(DiscordData, self.allocator, msg.data, .{}); + defer raw.deinit(); + + const payload = raw.value; + + if (payload.op == Opcode.Dispatch) { + self.sequence = @as(?u16, payload.s); + } + + switch (payload.op) { + Opcode.Dispatch => {}, + Opcode.Hello => { + const HelloPayload = struct { heartbeat_interval: u32, _trace: [][]const u8 }; + const parsed = try json.parseFromValue(HelloPayload, self.allocator, payload.d, .{}); + const helloPayload = parsed.value; + + // PARSE NEW URL IN READY + + if (self.heartbeater == null) { + var self_mut = self.*; + + self.heartbeater = try HeartbeatHandler.init( + // we cooking + &self_mut, 
helloPayload.heartbeat_interval); + } + + var heartbeater = self.heartbeater.?; + + heartbeater.heartbeatInterval = helloPayload.heartbeat_interval; + + try self.heartbeat(false); + try heartbeater.loop(); + try self.identify(); + }, // heartbeat_interval + Opcode.HeartbeatACK => { + if (self.heartbeater) |*hb| { + hb.lastHeartbeatAck = true; + } + }, // keep this shit alive otherwise kill it + Opcode.Reconnect => {}, + Opcode.Resume => { + const WithSequence = struct { + token: []const u8, + session_id: []const u8, + seq: ?u16, + }; + const parsed = try json.parseFromValue(WithSequence, self.allocator, payload.d, .{}); + const payload_new = parsed.value; + + self.sequence = @as(?u16, payload_new.seq); + self.session_id = payload_new.session_id; + }, + else => { + std.debug.print("Unhandled {} -- {s}", .{ payload.op, "none" }); + }, + } + //try client.write(message.data); + } + } + + pub fn heartbeat(self: *Handler, requested: bool) !void { + var heartbeater = self.heartbeater.?; + //std.time.sleep(heartbeater.heartbeatInterval); + + if (!requested) { + if (!heartbeater.lastHeartbeatAck) { + //try self.close(ShardSocketCloseCodes.ZombiedConnection, "Zombied connection"); + heartbeater.deinit(); + return; + } + heartbeater.lastHeartbeatAck = false; + } + const data = .{ .op = @intFromEnum(Opcode.Heartbeat), .d = self.sequence }; + + var buf: [1000]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&buf); + var string = std.ArrayList(u8).init(fba.allocator()); + try std.json.stringify(data, .{}, string.writer()); + + // try posting our shitty data + std.debug.print("sending heartbeat rn\n", .{}); + try self.client.write(string.items); + } + + pub inline fn reconnect(self: *Handler) !void { + _ = self; + //try self.disconnect(); + //_ = try self.connect(); + } + + pub fn connect(self: *Handler) !Handler { + std.time.sleep(std.time.ms_per_s * 5); + self.client = try Handler._connect_ws(self.allocator, self.gateway_url()); + + return self.*; + } + + pub fn disconnect(self: *Handler) !void { + try self.close(ShardSocketCloseCodes.Shutdown, "Shard down request"); + } + + pub fn close(self: *Handler, code: ShardSocketCloseCodes, reason: []const u8) !void { + std.debug.print("cooked closing ws conn...\n", .{}); + // Implement reconnection logic here + try self.client.close(.{ + .code = @intFromEnum(code), //u16 + .reason = reason, //[]const u8 + }); + } +}; diff --git a/zlib.zig b/zlib.zig new file mode 100644 index 0000000..ef04fae --- /dev/null +++ b/zlib.zig @@ -0,0 +1,587 @@ +// taken from https://github.com/ianic/zig-zlib/blob/main/src/main.zig +// +//// reference: https://zlib.net/manual.html#Advanced + +const std = @import("std"); +const builtin = @import("builtin"); +const c = @cImport({ + @cInclude("zlib.h"); + @cInclude("stddef.h"); +}); + +const alignment = @alignOf(c.max_align_t); +const Allocator = std.mem.Allocator; + +pub const Error = error{ + StreamEnd, + NeedDict, + Errno, + StreamError, + DataError, + MemError, + BufError, + VersionError, + OutOfMemory, + Unknown, +}; + +pub fn errorFromInt(val: c_int) Error { + return switch (val) { + c.Z_STREAM_END => error.StreamEnd, + c.Z_NEED_DICT => error.NeedDict, + c.Z_ERRNO => error.Errno, + c.Z_STREAM_ERROR => error.StreamError, + c.Z_DATA_ERROR => error.DataError, + c.Z_MEM_ERROR => error.MemError, + c.Z_BUF_ERROR => error.BufError, + c.Z_VERSION_ERROR => error.VersionError, + else => error.Unknown, + }; +} + +pub fn checkRC(val: c_int) Error!void { + if (val == c.Z_OK) return; + return errorFromInt(val); +} + +// method is 
copied from pfg's https://gist.github.com/pfgithub/65c13d7dc889a4b2ba25131994be0d20 +// we have a header for each allocation that records the length, which we need +// for the allocator. Assuming that there aren't many small allocations this is +// acceptable overhead. +const magic_value = 0x1234; +const ZallocHeader = struct { + magic: usize, + size: usize, + + const size_of_aligned = (std.math.divCeil(usize, @sizeOf(ZallocHeader), alignment) catch unreachable) * alignment; +}; + +comptime { + if (@alignOf(ZallocHeader) > alignment) { + @compileError("header has incorrect alignment"); + } +} + +fn zalloc(private: ?*anyopaque, items: c_uint, size: c_uint) callconv(.C) ?*anyopaque { + if (private == null) + return null; + + const allocator: *Allocator = @ptrCast(@alignCast(private.?)); + var buf = allocator.allocWithOptions(u8, ZallocHeader.size_of_aligned + (items * size), @alignOf(*ZallocHeader), null) catch return null; + const header: *ZallocHeader = @ptrCast(@alignCast(buf.ptr)); + header.* = .{ + .magic = magic_value, + .size = items * size, + }; + + return buf[ZallocHeader.size_of_aligned..].ptr; +} + +fn zfree(private: ?*anyopaque, addr: ?*anyopaque) callconv(.C) void { + if (private == null) + return; + + const allocator: *Allocator = @ptrCast(@alignCast(private.?)); + const header = @as(*ZallocHeader, @ptrFromInt(@intFromPtr(addr.?) - ZallocHeader.size_of_aligned)); + + if (builtin.mode != .ReleaseFast) { + if (header.magic != magic_value) + @panic("magic value is incorrect"); + } + + var buf: []align(alignment) u8 = undefined; + buf.ptr = @as([*]align(alignment) u8, @ptrCast(@alignCast(header))); + buf.len = ZallocHeader.size_of_aligned + header.size; + allocator.free(buf); +} + +pub fn compressorWriter(allocator: Allocator, writer: anytype, options: CompressorOptions) Error!CompressorWriter(@TypeOf(writer)) { + return CompressorWriter(@TypeOf(writer)).init(allocator, writer, options); +} + +pub fn decompressorReader(allocator: Allocator, writer: anytype, options: DecompressorOptions) Error!DecompressorReader(@TypeOf(writer)) { + return DecompressorReader(@TypeOf(writer)).init(allocator, writer, options); +} + +fn zStreamInit(allocator: Allocator) !*c.z_stream { + var stream: *c.z_stream = try allocator.create(c.z_stream); + errdefer allocator.destroy(stream); + + // if the user provides an allocator zlib uses an opaque pointer for + // custom malloc an free callbacks, this requires pinning, so we use + // the allocator to allocate the Allocator struct on the heap + const pinned = try allocator.create(Allocator); + errdefer allocator.destroy(pinned); + + pinned.* = allocator; + stream.@"opaque" = pinned; + stream.zalloc = zalloc; + stream.zfree = zfree; + return stream; +} + +fn zStreamDeinit(allocator: Allocator, stream: *c.z_stream) void { + const pinned: *Allocator = @ptrCast(@alignCast(stream.@"opaque".?)); + allocator.destroy(pinned); + allocator.destroy(stream); +} + +pub const CompressorOptions = struct { + const HeaderOptions = enum { + none, // raw deflate data with no zlib header or trailer + zlib, + gzip, // to write a simple gzip header and trailer around the compressed data instead of a zlib wrapper + ws, // same as none for header, but also removes 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end. + // ref: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1 + }; + compression_level: c_int = c.Z_DEFAULT_COMPRESSION, + + // memLevel=1 uses minimum memory but is slow and reduces compression ratio; memLevel=9 uses maximum memory for optimal speed. 
The default value is 8. + memory_level: c_int = 8, + + strategy: c_int = c.Z_DEFAULT_STRATEGY, + + header: HeaderOptions = .zlib, + window_size: u4 = 15, // in the range 9..15, base two logarithm of the maximum window size (the size of the history buffer). + + const Self = @This(); + + pub fn windowSize(self: Self) i6 { + const ws = @as(i6, if (self.window_size < 9) 9 else self.window_size); + return switch (self.header) { + .zlib => ws, + .none, .ws => -@as(i6, ws), + .gzip => ws + 16, + }; + } +}; + +pub fn CompressorWriter(comptime WriterType: type) type { + return struct { + allocator: Allocator, + stream: *c.z_stream, + inner: WriterType, + + const Self = @This(); + const WriterError = Error || WriterType.Error; + const Writer = std.io.Writer(*Self, WriterError, write); + + pub fn init(allocator: Allocator, inner_writer: WriterType, opt: CompressorOptions) !Self { + const stream = try zStreamInit(allocator); + errdefer zStreamDeinit(allocator, stream); + + try checkRC(c.deflateInit2( + stream, + opt.compression_level, + c.Z_DEFLATED, // only option + opt.windowSize(), + opt.memory_level, + opt.strategy, + )); + + return .{ .allocator = allocator, .stream = stream, .inner = inner_writer }; + } + + pub fn deinit(self: *Self) void { + _ = c.deflateEnd(self.stream); + zStreamDeinit(self.allocator, self.stream); + } + + pub fn flush(self: *Self) !void { + var tmp: [4096]u8 = undefined; + while (true) { + self.stream.next_out = &tmp; + self.stream.avail_out = tmp.len; + const rc = c.deflate(self.stream, c.Z_FINISH); + if (rc != c.Z_STREAM_END) + return errorFromInt(rc); + + if (self.stream.avail_out != 0) { + const n = tmp.len - self.stream.avail_out; + try self.inner.writeAll(tmp[0..n]); + break; + } else try self.inner.writeAll(&tmp); + } + } + + pub fn write(self: *Self, buf: []const u8) WriterError!usize { + var tmp: [4096]u8 = undefined; + + // uncompressed + self.stream.next_in = @as([*]u8, @ptrFromInt(@intFromPtr(buf.ptr))); + self.stream.avail_in = @as(c_uint, @intCast(buf.len)); + + while (true) { + // compressed + self.stream.next_out = &tmp; + self.stream.avail_out = tmp.len; + const rc = c.deflate(self.stream, c.Z_PARTIAL_FLUSH); + if (rc != c.Z_OK) + return errorFromInt(rc); + + if (self.stream.avail_out != 0) { + const n = tmp.len - self.stream.avail_out; + try self.inner.writeAll(tmp[0..n]); + break; + } else try self.inner.writeAll(&tmp); + } + + return buf.len - self.stream.avail_in; + } + + pub fn writer(self: *Self) Writer { + return .{ .context = self }; + } + }; +} + +pub const DecompressorOptions = struct { + const HeaderOptions = enum { + none, // raw deflate data with no zlib header or trailer, + zlib_or_gzip, + ws, // websocket compatibile, append deflate tail to the end + }; + + header: HeaderOptions = .zlib_or_gzip, + window_size: u4 = 15, // in the range 8..15, base two logarithm of the maximum window size (the size of the history buffer). 
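+    // Note: for the .none and .ws headers, windowSize() below returns a negative
+    // value; zlib treats a negative windowBits argument to inflateInit2 as a raw
+    // deflate stream with no zlib or gzip wrapper.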
+ + const Self = @This(); + + pub fn windowSize(self: Self) i5 { + const window_size = if (self.window_size < 8) 15 else self.window_size; + return if (self.header == .none or self.header == .ws) -@as(i5, window_size) else window_size; + } +}; + +pub fn DecompressorReader(comptime ReaderType: type) type { + return struct { + allocator: Allocator, + stream: *c.z_stream, + inner: ReaderType, + tmp: [4096]u8 = undefined, + pos: usize = 0, + + const Self = @This(); + const ReaderError = Error || ReaderType.Error; + const Reader = std.io.Reader(*Self, ReaderError, read); + + pub fn init(allocator: Allocator, inner_reader: ReaderType, options: DecompressorOptions) !Self { + const stream = try zStreamInit(allocator); + errdefer zStreamDeinit(allocator, stream); + + const rc = c.inflateInit2(stream, options.windowSize()); + if (rc != c.Z_OK) return errorFromInt(rc); + + return .{ + .allocator = allocator, + .stream = stream, + .inner = inner_reader, + }; + } + + pub fn deinit(self: *Self) void { + _ = c.inflateEnd(self.stream); + zStreamDeinit(self.allocator, self.stream); + } + + pub fn reset(self: *Self) void { + const rc = c.inflateReset(self.stream); + if (rc != c.Z_OK) return errorFromInt(rc); + } + + pub fn read(self: *Self, buf: []u8) ReaderError!usize { + //std.debug.print("pos: {d} buf.len {d}\n", .{ self.pos, buf.len }); + self.pos += try self.inner.readAll(self.tmp[self.pos..]); + + self.stream.next_in = &self.tmp; + self.stream.avail_in = @as(c_uint, @intCast(self.pos)); + + self.stream.next_out = @as([*]u8, @ptrFromInt(@intFromPtr(buf.ptr))); + self.stream.avail_out = @as(c_uint, @intCast(buf.len)); + + const rc = c.inflate(self.stream, c.Z_SYNC_FLUSH); + if (rc != c.Z_OK and rc != c.Z_STREAM_END) + return errorFromInt(rc); + + if (self.stream.avail_in != 0) { + const done_pos = self.pos - self.stream.avail_in; + std.mem.copyForwards(u8, self.tmp[0..], self.tmp[done_pos..]); + self.pos = self.tmp[done_pos..].len; + } + + return buf.len - self.stream.avail_out; + } + + pub fn reader(self: *Self) Reader { + return .{ .context = self }; + } + }; +} + +pub const Compressor = struct { + allocator: Allocator, + stream: *c.z_stream, + strip_tail: bool = false, + + const Self = @This(); + + pub fn init(allocator: Allocator, opt: CompressorOptions) !Self { + const stream = try zStreamInit(allocator); + errdefer zStreamDeinit(allocator, stream); + try checkRC(c.deflateInit2( + stream, + opt.compression_level, + c.Z_DEFLATED, // only option + opt.windowSize(), + opt.memory_level, + opt.strategy, + )); + return .{ + .allocator = allocator, + .stream = stream, + .strip_tail = opt.header == .ws, + }; + } + + pub fn deinit(self: *Self) void { + _ = c.deflateEnd(self.stream); + zStreamDeinit(self.allocator, self.stream); + } + + pub fn reset(self: *Self) !void { + try checkRC(c.deflateReset(self.stream)); + } + + // Compresses to new allocated buffer. + // Caller owns returned memory. 
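+    // Usage sketch (illustrative; assumes an `allocator` and a `payload` byte slice in scope):
+    //   var cmp = try Compressor.init(allocator, .{ .header = .ws });
+    //   defer cmp.deinit();
+    //   const compressed = try cmp.compressAllAlloc(payload);
+    //   defer allocator.free(compressed);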
+ pub fn compressAllAlloc(self: *Self, uncompressed: []const u8) ![]u8 { + self.stream.next_in = @as([*]u8, @ptrFromInt(@intFromPtr(uncompressed.ptr))); + self.stream.avail_in = @as(c_uint, @intCast(uncompressed.len)); + + var tmp = try self.allocator.alloc(u8, chunk_size); + var len: usize = 0; // used part of the tmp buffer + + var flag = c.Z_PARTIAL_FLUSH; + while (true) { + const out = tmp[len..]; + self.stream.next_out = @as([*]u8, @ptrFromInt(@intFromPtr(out.ptr))); + self.stream.avail_out = @as(c_uint, @intCast(out.len)); + + const rc = c.deflate(self.stream, flag); + if (rc != c.Z_OK and rc != c.Z_STREAM_END) + return errorFromInt(rc); + + len += out.len - self.stream.avail_out; + if (self.stream.avail_out == 0) { // out is full + tmp = try self.allocator.realloc(tmp, tmp.len * 2); + continue; + } + + if (flag == c.Z_SYNC_FLUSH) break; + flag = c.Z_SYNC_FLUSH; + } + if (self.strip_tail and len > 4 and tmp[len - 1] == 0xff and tmp[len - 2] == 0xff and tmp[len - 3] == 0x00 and tmp[len - 4] == 0x00) + len -= 4; + return try self.allocator.realloc(tmp, len); + } +}; + +const chunk_size = 4096; + +const deflate_tail = [_]u8{ 0x00, 0x00, 0xff, 0xff }; + +pub const Decompressor = struct { + allocator: Allocator, + stream: *c.z_stream, + append_tail: bool = false, + + const Self = @This(); + + pub fn init(allocator: Allocator, options: DecompressorOptions) !Self { + const stream = try zStreamInit(allocator); + errdefer zStreamDeinit(allocator, stream); + try checkRC(c.inflateInit2(stream, options.windowSize())); + return .{ + .allocator = allocator, + .stream = stream, + .append_tail = options.header == .ws, + }; + } + + pub fn deinit(self: *Self) void { + _ = c.inflateEnd(self.stream); + zStreamDeinit(self.allocator, self.stream); + } + + pub fn reset(self: *Self) !void { + try checkRC(c.inflateReset(self.stream)); + } + + // Decompresses to new allocated buffer. + // Caller owns returned memory. 
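+    // Usage sketch (illustrative; `compressed` as produced by Compressor.compressAllAlloc):
+    //   var dcp = try Decompressor.init(allocator, .{ .header = .ws });
+    //   defer dcp.deinit();
+    //   const decompressed = try dcp.decompressAllAlloc(compressed);
+    //   defer allocator.free(decompressed);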
+ pub fn decompressAllAlloc(self: *Self, compressed: []const u8) ![]u8 { + self.stream.next_in = @as([*]u8, @ptrFromInt(@intFromPtr(compressed.ptr))); + self.stream.avail_in = @as(c_uint, @intCast(compressed.len)); + + var tail_appended = false; + var tmp = try self.allocator.alloc(u8, chunk_size); + var len: usize = 0; // inflated part of the tmp buffer + while (true) { + const out = tmp[len..]; + self.stream.next_out = @as([*]u8, @ptrFromInt(@intFromPtr(out.ptr))); + self.stream.avail_out = @as(c_uint, @intCast(out.len)); + + const rc = c.inflate(self.stream, c.Z_SYNC_FLUSH); + if (rc != c.Z_OK and rc != c.Z_STREAM_END) { + return errorFromInt(rc); + } + len += out.len - self.stream.avail_out; + if (self.stream.avail_in != 0 and self.stream.avail_out == 0) { // in not empty, out full + tmp = try self.allocator.realloc(tmp, tmp.len * 2); // make more space + continue; + } + + if (self.append_tail and !tail_appended) { + self.stream.next_in = @as([*]u8, @ptrFromInt(@intFromPtr(&deflate_tail))); + self.stream.avail_in = @as(c_uint, @intCast(deflate_tail.len)); + tail_appended = true; + continue; + } + break; + } + return try self.allocator.realloc(tmp, len); + } +}; + +test "compress gzip with zig interface" { + const allocator = std.testing.allocator; + var fifo = std.fifo.LinearFifo(u8, .Dynamic).init(allocator); + defer fifo.deinit(); + + // compress with zlib + const input = @embedFile("rfc1951.txt"); + var cmp = try compressorWriter(allocator, fifo.writer(), .{ .header = .gzip }); + defer cmp.deinit(); + const writer = cmp.writer(); + try writer.writeAll(input); + try cmp.flush(); + + // decompress with zig std lib gzip + var dcmp = try std.compress.gzip.decompress(allocator, fifo.reader()); + defer dcmp.deinit(); + const actual = try dcmp.reader().readAllAlloc(allocator, std.math.maxInt(usize)); + defer allocator.free(actual); + + try std.testing.expectEqualStrings(input, actual); +} + +test "compress/decompress" { + const allocator = std.testing.allocator; + var fifo = std.fifo.LinearFifo(u8, .Dynamic).init(allocator); + defer fifo.deinit(); + + // compress + const input = @embedFile("rfc1951.txt"); + var cmp = try compressorWriter(allocator, fifo.writer(), .{}); + defer cmp.deinit(); + const writer = cmp.writer(); + try writer.writeAll(input); + try cmp.flush(); + + // decompress + var dcmp = try decompressorReader(allocator, fifo.reader(), .{}); + defer dcmp.deinit(); + const actual = try dcmp.reader().readAllAlloc(allocator, std.math.maxInt(usize)); + defer allocator.free(actual); + + try std.testing.expectEqualStrings(input, actual); +} + +test "buffer compress/decompress" { + const allocator = std.testing.allocator; + + const input = @embedFile("rfc1951.txt"); + var cmp = try Compressor.init(allocator, .{ .header = .none }); + defer cmp.deinit(); + const compressed = try cmp.compressAllAlloc(input); + defer allocator.free(compressed); + + var dcmp = try Decompressor.init(allocator, .{ .header = .none }); + defer dcmp.deinit(); + const decompressed = try dcmp.decompressAllAlloc(compressed); + defer allocator.free(decompressed); + + try std.testing.expectEqualSlices(u8, input, decompressed); +} + +test "compress gzip with C interface" { + var input = [_]u8{ 'b', 'l', 'a', 'r', 'g' }; + var output_buf: [4096]u8 = undefined; + + var zs: c.z_stream = undefined; + zs.zalloc = null; + zs.zfree = null; + zs.@"opaque" = null; + zs.avail_in = input.len; + zs.next_in = &input; + zs.avail_out = output_buf.len; + zs.next_out = &output_buf; + + _ = c.deflateInit2(&zs, c.Z_DEFAULT_COMPRESSION, 
c.Z_DEFLATED, 15 | 16, 8, c.Z_DEFAULT_STRATEGY);
+    _ = c.deflate(&zs, c.Z_FINISH);
+    _ = c.deflateEnd(&zs);
+}
+
+// debug helper
+fn showBuf(buf: []const u8) void {
+    std.debug.print("\n", .{});
+    for (buf) |b|
+        std.debug.print("0x{x:0>2}, ", .{b});
+    std.debug.print("\n", .{});
+}
+
+test "Hello compress/decompress websocket compatible" {
+    const allocator = std.testing.allocator;
+    const input = "Hello";
+
+    var cmp = try Compressor.init(allocator, .{ .header = .ws });
+    defer cmp.deinit();
+    const compressed = try cmp.compressAllAlloc(input);
+    defer allocator.free(compressed);
+    try std.testing.expectEqualSlices(u8, &[_]u8{ 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x08, 0x00 }, compressed);
+
+    var dcp = try Decompressor.init(allocator, .{ .header = .ws });
+    defer dcp.deinit();
+    const decompressed = try dcp.decompressAllAlloc(compressed);
+    defer allocator.free(decompressed);
+
+    try std.testing.expectEqualSlices(u8, input, decompressed);
+}
+
+// reference: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.2
+test "Sharing LZ77 Sliding Window" {
+    const allocator = std.testing.allocator;
+    const input = "Hello";
+
+    var cmp = try Compressor.init(allocator, .{ .header = .ws });
+    defer cmp.deinit();
+
+    const c1 = try cmp.compressAllAlloc(input);
+    defer allocator.free(c1);
+    try std.testing.expectEqualSlices(u8, &[_]u8{ 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x08, 0x00 }, c1);
+
+    // compress second message using the same sliding window, should be a little shorter
+    const c2 = try cmp.compressAllAlloc(input);
+    defer allocator.free(c2);
+    try std.testing.expectEqualSlices(u8, &[_]u8{ 0xf2, 0x00, 0x11, 0x00, 0x01, 0x00 }, c2);
+
+    var dcp = try Decompressor.init(allocator, .{ .header = .ws });
+    defer dcp.deinit();
+    const d1 = try dcp.decompressAllAlloc(c1);
+    defer allocator.free(d1);
+    try std.testing.expectEqualSlices(u8, input, d1);
+
+    const d2 = try dcp.decompressAllAlloc(c2);
+    defer allocator.free(d2);
+    try std.testing.expectEqualSlices(u8, input, d2);
+}