eaapi/eaapi/lz77.py

136 lines
4.1 KiB
Python

from .misc import assert_true
# Size, in bytes, of the circular history window used by both the
# compressor and the decompressor.
WINDOW_SIZE = 1 << 12
# Mask that wraps an index into the window (WINDOW_SIZE is a power of two).
WINDOW_MASK = WINDOW_SIZE - 1
# Shortest match that is emitted as a back-reference; encoded lengths are
# stored relative to this value.
THRESHOLD = 3
# A match at least this long is accepted immediately, ending the search.
INPLACE_THRESHOLD = 10
# Exclusive upper bound on the back-reference distances that are searched.
LOOK_RANGE = 1 << 9
# Longest encodable match: a 4-bit length field biased by THRESHOLD.
MAX_LEN = THRESHOLD + 0x0F
# Capacity of the staging buffer for one flag byte's worth of tokens:
# eight tokens of at most two bytes each, plus one spare slot.
MAX_BUFFER = 8 * 2 + 1
def match_current(window: bytes, pos: int, max_len: int, data: bytes, dpos: int) -> int:
    """Count how many bytes of *data* at *dpos* match *window* at *pos*.

    The window is indexed circularly (indices are wrapped with
    ``WINDOW_MASK``). The count never exceeds *max_len*, ``MAX_LEN``, or the
    number of bytes remaining in *data* from *dpos*.
    """
    # Fold the caller's cap and the end-of-input bound into a single limit.
    limit = min(max_len, len(data) - dpos)
    matched = 0
    while matched < limit:
        if window[(pos + matched) & WINDOW_MASK] != data[dpos + matched]:
            break
        # Never report more than the longest length the format allows.
        if matched >= MAX_LEN:
            break
        matched += 1
    return matched
def match_window(window: bytes, pos: int, data: bytes, d_pos: int) -> None | tuple[int, int]:
    """Find a back-reference for the bytes of *data* starting at *d_pos*.

    Candidate distances from ``THRESHOLD`` up to (but excluding)
    ``LOOK_RANGE`` behind the window write position *pos* are tried in
    increasing order. Each candidate's match length is capped at its own
    distance, so a match never runs past the current write position.

    Returns ``(distance, length)`` for the chosen match, or ``None`` when no
    candidate matches at least ``THRESHOLD`` bytes. A match of at least
    ``INPLACE_THRESHOLD`` bytes is returned immediately; otherwise the
    longest match found is returned, preferring the nearest distance on ties.
    """
    best_distance = 0
    best_length = 0
    for distance in range(THRESHOLD, LOOK_RANGE):
        length = match_current(window, (pos - distance) & WINDOW_MASK, distance, data, d_pos)
        if length >= INPLACE_THRESHOLD:
            # Good enough: stop searching rather than scan the whole range.
            return (distance, length)
        # Only replace the candidate when this one is strictly longer;
        # otherwise a later, shorter match would overwrite a better one.
        if length >= THRESHOLD and length > best_length:
            best_distance = distance
            best_length = length
    if best_length >= THRESHOLD:
        return (best_distance, best_length)
    return None
def lz77_compress(data: bytes) -> bytes:
    """Compress *data* into the flag-byte/token stream used by this module.

    The output is a sequence of groups. Each group is one flag byte followed
    by up to eight tokens, with flag bits consumed least-significant first:

    * bit set   -> one literal byte;
    * bit clear -> a two-byte big-endian back-reference holding a 12-bit
      distance and a 4-bit ``length - THRESHOLD``.

    The stream ends with an all-zero back-reference. When the input ends
    exactly on a group boundary, that terminator is written as a separate
    three-byte group (zero flag byte plus two zero bytes).
    """
    total = len(data)
    out = bytearray()
    ring = bytearray(WINDOW_SIZE)
    src = 0        # next unread index in data
    ring_pos = 0   # next write index in the history window
    staged = [0] * MAX_BUFFER
    terminated = False

    while src < total:
        flags = 0
        current_buffer = 0
        for slot in range(8):
            if src >= total:
                # Input ran out mid-group: right-align the flag bits that
                # were produced and stage the zero terminator token.
                terminated = True
                flags >>= 8 - slot
                staged[current_buffer] = 0
                staged[current_buffer + 1] = 0
                current_buffer += 2
                break

            match = match_window(ring, ring_pos, data, src)
            if match is None or match[1] < THRESHOLD:
                # No usable match: emit the byte verbatim.
                literal = data[src]
                staged[current_buffer] = literal
                ring[ring_pos] = literal
                src += 1
                ring_pos += 1
                current_buffer += 1
                is_literal = 1
            else:
                distance, run = match
                staged[current_buffer] = distance >> 4
                staged[current_buffer + 1] = ((distance & 0x0F) << 4) | ((run - THRESHOLD) & 0x0F)
                current_buffer += 2
                is_literal = 0
                # Mirror the matched bytes into the history window.
                for _ in range(run):
                    ring[ring_pos & WINDOW_MASK] = data[src]
                    src += 1
                    ring_pos += 1

            # New bits enter at the top and shift down, so after eight
            # tokens the first token's bit sits in bit 0.
            flags = (flags >> 1) | ((is_literal & 1) << 7)
            ring_pos &= WINDOW_MASK

        assert_true(current_buffer < MAX_BUFFER, f"current buffer {current_buffer} > max buffer {MAX_BUFFER}")

        out.append(flags)
        out.extend(staged[:current_buffer])

    if not terminated:
        # Zero flag byte followed by the zero terminator token.
        out.extend(b"\x00\x00\x00")

    return bytes(out)
def lz77_decompress(data: bytes) -> bytes:
    """Expand a stream produced by :func:`lz77_compress`.

    Each group starts with a flag byte whose bits are read least-significant
    first. A set bit is followed by one literal byte; a clear bit is followed
    by a two-byte big-endian token holding a 12-bit distance and a 4-bit
    ``length - THRESHOLD``. An all-zero token ends the stream. Decoding also
    stops if the input is exhausted at a group boundary.

    Raises:
        IndexError: if the input ends in the middle of a group.
    """
    size = len(data)
    result = bytearray()
    ring = bytearray(WINDOW_SIZE)
    ring_pos = 0
    idx = 0

    while idx < size:
        flags = data[idx]
        idx += 1

        for shift in range(8):
            if (flags >> shift) & 1 == 1:
                # Literal: copy through and record it in the history window.
                literal = data[idx]
                result.append(literal)
                ring[ring_pos] = literal
                ring_pos = (ring_pos + 1) & WINDOW_MASK
                idx += 1
                continue

            token = (data[idx] << 8) | data[idx + 1]
            if token == 0:
                # End-of-stream marker.
                return bytes(result)
            idx += 2

            source = (ring_pos - (token >> 4)) & WINDOW_MASK
            run = (token & 0x0F) + THRESHOLD
            # Copy byte by byte so the window is updated as the run is read.
            for _ in range(run):
                value = ring[source & WINDOW_MASK]
                result.append(value)
                ring[ring_pos] = value
                ring_pos = (ring_pos + 1) & WINDOW_MASK
                source += 1

    return bytes(result)
# Names exported by a star-import of this module.
__all__ = (
    "lz77_compress",
    "lz77_decompress",
)