diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 330db1f..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "server"] - path = eaapi/server - url = https://gitea.tendokyu.moe/eamuse/server.git diff --git a/eaapi/__init__.py b/eaapi/__init__.py index 6c732e8..2b4c1d1 100644 --- a/eaapi/__init__.py +++ b/eaapi/__init__.py @@ -1,13 +1,18 @@ -from .const import Type, ServicesMode, Compression -from .node import XMLNode -from .encoder import Encoder -from .decoder import Decoder -from .wrapper import wrap, unwrap -from .misc import parse_model - -__all__ = ( - "Type", "ServicesMode", "Compression", - "XMLNode", "Encoder", "Decoder", - "wrap", "unwrap", - "parse_model", -) +from .const import Type, ServicesMode, Compression +from .node import XMLNode +from .encoder import Encoder +from .decoder import Decoder +from .wrapper import wrap, unwrap +from .misc import parse_model + +from .exception import EAAPIException +from . import crypt + +__all__ = ( + "Type", "ServicesMode", "Compression", + "XMLNode", "Encoder", "Decoder", + "wrap", "unwrap", + "parse_model", + "EAAPIException", + "crypt", +) diff --git a/eaapi/cardconv.py b/eaapi/cardconv.py index f0031aa..8194382 100644 --- a/eaapi/cardconv.py +++ b/eaapi/cardconv.py @@ -1,85 +1,85 @@ -import binascii - -from Crypto.Cipher import DES3 - -from .misc import assert_true, pack, unpack -from .exception import InvalidCard -from .keys import CARDCONV_KEY -from .const import CARD_ALPHABET - - -def enc_des(uid): - cipher = DES3.new(CARDCONV_KEY, DES3.MODE_CBC, iv=b'\0' * 8) - return cipher.encrypt(uid) - - -def dec_des(uid): - cipher = DES3.new(CARDCONV_KEY, DES3.MODE_CBC, iv=b'\0' * 8) - return cipher.decrypt(uid) - - -def checksum(data): - chk = sum(data[i] * (i % 3 + 1) for i in range(15)) - - while chk > 31: - chk = (chk >> 5) + (chk & 31) - - return chk - - -def uid_to_konami(uid): - assert_true(len(uid) == 16, "UID must be 16 bytes", InvalidCard) - - if uid.upper().startswith("E004"): - card_type = 1 - elif uid.upper().startswith("0"): - card_type = 2 - else: - raise InvalidCard("Invalid UID prefix") - - kid = binascii.unhexlify(uid) - assert_true(len(kid) == 8, "ID must be 8 bytes", InvalidCard) - - out = bytearray(unpack(enc_des(kid[::-1]), 5)[:13]) + b'\0\0\0' - - out[0] ^= card_type - out[13] = 1 - for i in range(1, 14): - out[i] ^= out[i - 1] - out[14] = card_type - out[15] = checksum(out) - - return "".join(CARD_ALPHABET[i] for i in out) - - -def konami_to_uid(konami_id): - if konami_id[14] == "1": - card_type = 1 - elif konami_id[14] == "2": - card_type = 2 - else: - raise InvalidCard("Invalid ID") - - assert_true(len(konami_id) == 16, "ID must be 16 characters", InvalidCard) - assert_true(all(i in CARD_ALPHABET for i in konami_id), "ID contains invalid characters", InvalidCard) - card = [CARD_ALPHABET.index(i) for i in konami_id] - assert_true(card[11] % 2 == card[12] % 2, "Parity check failed", InvalidCard) - assert_true(card[13] == card[12] ^ 1, "Card invalid", InvalidCard) - assert_true(card[15] == checksum(card), "Checksum failed", InvalidCard) - - for i in range(13, 0, -1): - card[i] ^= card[i - 1] - - card[0] ^= card_type - - card_id = dec_des(pack(card[:13], 5)[:8])[::-1] - card_id = binascii.hexlify(card_id).decode().upper() - - if card_type == 1: - assert_true(card_id[:4] == "E004", "Invalid card type", InvalidCard) - elif card_type == 2: - assert_true(card_id[0] == "0", "Invalid card type", InvalidCard) - return card_id - - -__all__ = ("konami_to_uid", "uid_to_konami") +import binascii + +from Crypto.Cipher import DES3 + +from .misc import assert_true, pack, unpack +from .exception import InvalidCard +from .keys import CARDCONV_KEY +from .const import CARD_ALPHABET + + +def enc_des(uid: bytes) -> bytes: + cipher = DES3.new(CARDCONV_KEY, DES3.MODE_CBC, iv=b'\0' * 8) + return cipher.encrypt(uid) + + +def dec_des(uid: bytes) -> bytes: + cipher = DES3.new(CARDCONV_KEY, DES3.MODE_CBC, iv=b'\0' * 8) + return cipher.decrypt(uid) + + +def checksum(data: bytes) -> int: + chk = sum(data[i] * (i % 3 + 1) for i in range(15)) + + while chk > 31: + chk = (chk >> 5) + (chk & 31) + + return chk + + +def uid_to_konami(uid: str) -> str: + assert_true(len(uid) == 16, "UID must be 16 bytes", InvalidCard) + + if uid.upper().startswith("E004"): + card_type = 1 + elif uid.upper().startswith("0"): + card_type = 2 + else: + raise InvalidCard("Invalid UID prefix") + + kid = binascii.unhexlify(uid) + assert_true(len(kid) == 8, "ID must be 8 bytes", InvalidCard) + + out = bytearray(unpack(enc_des(kid[::-1]), 5)[:13]) + b'\0\0\0' + + out[0] ^= card_type + out[13] = 1 + for i in range(1, 14): + out[i] ^= out[i - 1] + out[14] = card_type + out[15] = checksum(out) + + return "".join(CARD_ALPHABET[i] for i in out) + + +def konami_to_uid(konami_id: str) -> str: + if konami_id[14] == "1": + card_type = 1 + elif konami_id[14] == "2": + card_type = 2 + else: + raise InvalidCard("Invalid ID") + + assert_true(len(konami_id) == 16, "ID must be 16 characters", InvalidCard) + assert_true(all(i in CARD_ALPHABET for i in konami_id), "ID contains invalid characters", InvalidCard) + card = bytearray([CARD_ALPHABET.index(i) for i in konami_id]) + assert_true(card[11] % 2 == card[12] % 2, "Parity check failed", InvalidCard) + assert_true(card[13] == card[12] ^ 1, "Card invalid", InvalidCard) + assert_true(card[15] == checksum(card), "Checksum failed", InvalidCard) + + for i in range(13, 0, -1): + card[i] ^= card[i - 1] + + card[0] ^= card_type + + card_id = dec_des(pack(card[:13], 5)[:8])[::-1] + card_id = binascii.hexlify(card_id).decode().upper() + + if card_type == 1: + assert_true(card_id[:4] == "E004", "Invalid card type", InvalidCard) + elif card_type == 2: + assert_true(card_id[0] == "0", "Invalid card type", InvalidCard) + return card_id + + +__all__ = ("konami_to_uid", "uid_to_konami") diff --git a/eaapi/const.py b/eaapi/const.py index 7f86292..090cc9c 100644 --- a/eaapi/const.py +++ b/eaapi/const.py @@ -1,163 +1,171 @@ -import enum - -from dataclasses import dataclass -from html import unescape -from typing import List, Callable - -from .misc import assert_true - - -CARD_ALPHABET = "0123456789ABCDEFGHJKLMNPRSTUWXYZ" - -NAME_MAX_COMPRESSED = 0x24 -NAME_MAX_DECOMPRESSED = 0x1000 - -ENCODING = { - 0x20: "ascii", - 0x40: "iso-8859-1", - 0x60: "euc-jp", - 0x80: "shift-jis", - 0xA0: "utf-8", -} -DEFAULT_ENCODING = ENCODING[0x80] # Shift-JIS -ENCODING[0x00] = DEFAULT_ENCODING -XML_ENCODING = { - "ASCII": "ascii", - "ISO-8859-1": "iso-8859-1", - "EUC-JP": "euc-jp", - "SHIFT_JIS": "shift-jis", - "SHIFT-JIS": "shift-jis", - "UTF-8": "utf-8", -} -ENCODING_BACK = {v: k for k, v in ENCODING.items()} -XML_ENCODING_BACK = {v: k for k, v in XML_ENCODING.items()} -ATTR = 0x2E -END_NODE = 0xFE -END_DOC = 0xFF -PACK_ALPHABET = "0123456789:ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz" - -CONTENT_COMP_FULL = 0x42 -CONTENT_COMP_SCHEMA = 0x43 -CONTENT_FINGERPRINT = 0x44 # TODO: Identify how exactly this differs from the others -CONTENT_ASCII_FULL = 0x45 -CONTENT_ASCII_SCHEMA = 0x46 - -CONTENT_COMP = (CONTENT_COMP_FULL, CONTENT_COMP_SCHEMA) -CONTENT_FULL = (CONTENT_COMP_FULL, CONTENT_ASCII_FULL) -CONTENT = ( - CONTENT_COMP_FULL, CONTENT_COMP_SCHEMA, CONTENT_FINGERPRINT, - CONTENT_ASCII_FULL, CONTENT_ASCII_SCHEMA -) - -ARRAY_BIT = 0x40 -ARRAY_MASK = ARRAY_BIT - 1 - - -@dataclass -class _Type: - id: int - fmt: str - names: List[str] - c_name: str - convert: Callable - size: int = 1 - no_check: bool = False - - def _parse(self, value): - if self.convert is None: - return () - if self.size == 1: - if isinstance(value, (list, tuple)) and len(value) == 1: - value = value[0] - return self.convert(value) - if not self.no_check: - assert_true(len(value) == self.size, "Invalid node data") - return (*map(self.convert, value),) - - -def parse_ip(ip): - return (*map(int, ip.split(".")),) - - -class Type(enum.Enum): - Void = _Type(0x01, "", ["void"], "void", None) - S8 = _Type(0x02, "b", ["s8"], "int8", int) - U8 = _Type(0x03, "B", ["u8"], "uint8", int) - S16 = _Type(0x04, "h", ["s16"], "int16", int) - U16 = _Type(0x05, "H", ["u16"], "uint16", int) - S32 = _Type(0x06, "i", ["s32"], "int32", int) - U32 = _Type(0x07, "I", ["u32"], "uint32", int) - S64 = _Type(0x08, "q", ["s64"], "int64", int) - U64 = _Type(0x09, "Q", ["u64"], "uint64", int) - Blob = _Type(0x0a, "S", ["bin", "binary"], "char[]", bytes) - Str = _Type(0x0b, "s", ["str", "string"], "char[]", unescape) - IPv4 = _Type(0x0c, "4B", ["ip4"], "uint8[4]", parse_ip, 1, True) - Time = _Type(0x0d, "I", ["time"], "uint32", int) - Float = _Type(0x0e, "f", ["float", "f"], "float", float) - Double = _Type(0x0f, "d", ["double", "d"], "double", float) - - TwoS8 = _Type(0x10, "2b", ["2s8"], "int8[2]", int, 2) - TwoU8 = _Type(0x11, "2B", ["2u8"], "uint8[2]", int, 2) - TwoS16 = _Type(0x12, "2h", ["2s16"], "int16[2]", int, 2) - TwoU16 = _Type(0x13, "2H", ["2u16"], "uint16[2]", int, 2) - TwoS32 = _Type(0x14, "2i", ["2s32"], "int32[2]", int, 2) - TwoU32 = _Type(0x15, "2I", ["2u32"], "uint32[2]", int, 2) - TwoS64 = _Type(0x16, "2q", ["2s64", "vs64"], "int16[2]", int, 2) - TwoU64 = _Type(0x17, "2Q", ["2u64", "vu64"], "uint16[2]", int, 2) - TwoFloat = _Type(0x18, "2f", ["2f"], "float[2]", float, 2) - TwoDouble = _Type(0x19, "2d", ["2d", "vd"], "double[2]", float, 2) - - ThreeS8 = _Type(0x1a, "3b", ["3s8"], "int8[3]", int, 3) - ThreeU8 = _Type(0x1b, "3B", ["3u8"], "uint8[3]", int, 3) - ThreeS16 = _Type(0x1c, "3h", ["3s16"], "int16[3]", int, 3) - ThreeU16 = _Type(0x1d, "3H", ["3u16"], "uint16[3]", int, 3) - ThreeS32 = _Type(0x1e, "3i", ["3s32"], "int32[3]", int, 3) - ThreeU32 = _Type(0x1f, "3I", ["3u32"], "uint32[3]", int, 3) - ThreeS64 = _Type(0x20, "3q", ["3s64"], "int64[3]", int, 3) - ThreeU64 = _Type(0x21, "3Q", ["3u64"], "uint64[3]", int, 3) - ThreeFloat = _Type(0x22, "3f", ["3f"], "float[3]", float, 3) - ThreeDouble = _Type(0x23, "3d", ["3d"], "double[3]", float, 3) - - FourS8 = _Type(0x24, "4b", ["4s8"], "int8[4]", int, 4) - FourU8 = _Type(0x25, "4B", ["4u8"], "uint8[4]", int, 4) - FourS16 = _Type(0x26, "4h", ["4s16"], "int16[4]", int, 4) - FourU16 = _Type(0x27, "4H", ["4u16"], "uint8[4]", int, 4) - FourS32 = _Type(0x28, "4i", ["4s32", "vs32"], "int32[4]", int, 4) - FourU32 = _Type(0x29, "4I", ["4u32", "vs32"], "uint32[4]", int, 4) - FourS64 = _Type(0x2a, "4q", ["4s64"], "int64[4]", int, 4) - FourU64 = _Type(0x2b, "4Q", ["4u64"], "uint64[4]", int, 4) - FourFloat = _Type(0x2c, "4f", ["4f", "vf"], "float[4]", float, 4) - FourDouble = _Type(0x2d, "4d", ["4d"], "double[4]", float, 4) - - Attr = _Type(0x2e, "s", ["attr"], "char[]", None) - Array = _Type(0x2f, "", ["array"], "", None) - - VecS8 = _Type(0x30, "16b", ["vs8"], "int8[16]", int, 16) - VecU8 = _Type(0x31, "16B", ["vu8"], "uint8[16]", int, 16) - VecS16 = _Type(0x32, "8h", ["vs16"], "int8[8]", int, 8) - VecU16 = _Type(0x33, "8H", ["vu16"], "uint8[8]", int, 8) - - Bool = _Type(0x34, "b", ["bool", "b"], "bool", int) - TwoBool = _Type(0x35, "2b", ["2b"], "bool[2]", int, 2) - ThreeBool = _Type(0x36, "3b", ["3b"], "bool[3]", int, 3) - FourBool = _Type(0x37, "4b", ["4b"], "bool[4]", int, 4) - VecBool = _Type(0x38, "16b", ["vb"], "bool[16]", int, 16) - - @classmethod - def from_val(cls, value): - for i in cls: - if i.value.id == value & ARRAY_MASK: - return i - raise ValueError(f"Unknown node type {value}") - - -class ServicesMode(enum.Enum): - Operation = "operation" - Debug = "debug" - Test = "test" - Factory = "factory" - - -class Compression(enum.Enum): - Lz77 = "lz77" - None_ = "none" +import enum + +from dataclasses import dataclass +from html import unescape +from typing import List, Callable + +from .misc import assert_true + + +CARD_ALPHABET = "0123456789ABCDEFGHJKLMNPRSTUWXYZ" + +NAME_MAX_COMPRESSED = 0x24 +NAME_MAX_DECOMPRESSED = 0x1000 + +ENCODING = { + 0x20: "ascii", + 0x40: "iso-8859-1", + 0x60: "euc-jp", + 0x80: "shift-jis", + 0xA0: "utf-8", +} +DEFAULT_ENCODING = ENCODING[0x80] # Shift-JIS +ENCODING[0x00] = DEFAULT_ENCODING +XML_ENCODING = { + "ASCII": "ascii", + "ISO-8859-1": "iso-8859-1", + "EUC-JP": "euc-jp", + "SHIFT_JIS": "shift-jis", + "SHIFT-JIS": "shift-jis", + "UTF-8": "utf-8", +} +ENCODING_BACK = {v: k for k, v in ENCODING.items()} +XML_ENCODING_BACK = {v: k for k, v in XML_ENCODING.items()} +ATTR = 0x2E +END_NODE = 0xFE +END_DOC = 0xFF +PACK_ALPHABET = "0123456789:ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz" + +CONTENT_COMP_FULL = 0x42 +CONTENT_COMP_SCHEMA = 0x43 +CONTENT_FINGERPRINT = 0x44 # TODO: Identify how exactly this differs from the others +CONTENT_ASCII_FULL = 0x45 +CONTENT_ASCII_SCHEMA = 0x46 + +CONTENT_COMP = (CONTENT_COMP_FULL, CONTENT_COMP_SCHEMA) +CONTENT_FULL = (CONTENT_COMP_FULL, CONTENT_ASCII_FULL) +CONTENT = ( + CONTENT_COMP_FULL, CONTENT_COMP_SCHEMA, CONTENT_FINGERPRINT, + CONTENT_ASCII_FULL, CONTENT_ASCII_SCHEMA +) + +ARRAY_BIT = 0x40 +ARRAY_MASK = ARRAY_BIT - 1 + + +@dataclass +class _Type: + id: int + fmt: str + names: List[str] + c_name: str + convert: Callable | None + size: int = 1 + no_check: bool = False + + def _parse(self, value): + if self.convert is None: + return () + if self.size == 1: + if isinstance(value, (list, tuple)) and len(value) == 1: + value = value[0] + return self.convert(value) + if not self.no_check: + assert_true(len(value) == self.size, "Invalid node data") + return (*map(self.convert, value),) + + +def parse_ip(ip: int | str) -> tuple[int, int, int, int]: + if isinstance(ip, int): + return ( + (ip >> 24) & 0xff, + (ip >> 16) & 0xff, + (ip >> 8) & 0xff, + (ip >> 0) & 0xff, + ) + return (*map(int, ip.split(".")),) + + +class Type(enum.Enum): + Void = _Type(0x01, "", ["void"], "void", None) + S8 = _Type(0x02, "b", ["s8"], "int8", int) + U8 = _Type(0x03, "B", ["u8"], "uint8", int) + S16 = _Type(0x04, "h", ["s16"], "int16", int) + U16 = _Type(0x05, "H", ["u16"], "uint16", int) + S32 = _Type(0x06, "i", ["s32"], "int32", int) + U32 = _Type(0x07, "I", ["u32"], "uint32", int) + S64 = _Type(0x08, "q", ["s64"], "int64", int) + U64 = _Type(0x09, "Q", ["u64"], "uint64", int) + Blob = _Type(0x0a, "S", ["bin", "binary"], "char[]", bytes) + Str = _Type(0x0b, "s", ["str", "string"], "char[]", unescape) + IPv4 = _Type(0x0c, "4B", ["ip4"], "uint8[4]", parse_ip, 1, True) + IPv4_Int = _Type(0x0c, "I", ["ip4"], "uint8[4]", parse_ip, 1, True) + Time = _Type(0x0d, "I", ["time"], "uint32", int) + Float = _Type(0x0e, "f", ["float", "f"], "float", float) + Double = _Type(0x0f, "d", ["double", "d"], "double", float) + + TwoS8 = _Type(0x10, "2b", ["2s8"], "int8[2]", int, 2) + TwoU8 = _Type(0x11, "2B", ["2u8"], "uint8[2]", int, 2) + TwoS16 = _Type(0x12, "2h", ["2s16"], "int16[2]", int, 2) + TwoU16 = _Type(0x13, "2H", ["2u16"], "uint16[2]", int, 2) + TwoS32 = _Type(0x14, "2i", ["2s32"], "int32[2]", int, 2) + TwoU32 = _Type(0x15, "2I", ["2u32"], "uint32[2]", int, 2) + TwoS64 = _Type(0x16, "2q", ["2s64", "vs64"], "int16[2]", int, 2) + TwoU64 = _Type(0x17, "2Q", ["2u64", "vu64"], "uint16[2]", int, 2) + TwoFloat = _Type(0x18, "2f", ["2f"], "float[2]", float, 2) + TwoDouble = _Type(0x19, "2d", ["2d", "vd"], "double[2]", float, 2) + + ThreeS8 = _Type(0x1a, "3b", ["3s8"], "int8[3]", int, 3) + ThreeU8 = _Type(0x1b, "3B", ["3u8"], "uint8[3]", int, 3) + ThreeS16 = _Type(0x1c, "3h", ["3s16"], "int16[3]", int, 3) + ThreeU16 = _Type(0x1d, "3H", ["3u16"], "uint16[3]", int, 3) + ThreeS32 = _Type(0x1e, "3i", ["3s32"], "int32[3]", int, 3) + ThreeU32 = _Type(0x1f, "3I", ["3u32"], "uint32[3]", int, 3) + ThreeS64 = _Type(0x20, "3q", ["3s64"], "int64[3]", int, 3) + ThreeU64 = _Type(0x21, "3Q", ["3u64"], "uint64[3]", int, 3) + ThreeFloat = _Type(0x22, "3f", ["3f"], "float[3]", float, 3) + ThreeDouble = _Type(0x23, "3d", ["3d"], "double[3]", float, 3) + + FourS8 = _Type(0x24, "4b", ["4s8"], "int8[4]", int, 4) + FourU8 = _Type(0x25, "4B", ["4u8"], "uint8[4]", int, 4) + FourS16 = _Type(0x26, "4h", ["4s16"], "int16[4]", int, 4) + FourU16 = _Type(0x27, "4H", ["4u16"], "uint8[4]", int, 4) + FourS32 = _Type(0x28, "4i", ["4s32", "vs32"], "int32[4]", int, 4) + FourU32 = _Type(0x29, "4I", ["4u32", "vs32"], "uint32[4]", int, 4) + FourS64 = _Type(0x2a, "4q", ["4s64"], "int64[4]", int, 4) + FourU64 = _Type(0x2b, "4Q", ["4u64"], "uint64[4]", int, 4) + FourFloat = _Type(0x2c, "4f", ["4f", "vf"], "float[4]", float, 4) + FourDouble = _Type(0x2d, "4d", ["4d"], "double[4]", float, 4) + + Attr = _Type(0x2e, "s", ["attr"], "char[]", None) + Array = _Type(0x2f, "", ["array"], "", None) + + VecS8 = _Type(0x30, "16b", ["vs8"], "int8[16]", int, 16) + VecU8 = _Type(0x31, "16B", ["vu8"], "uint8[16]", int, 16) + VecS16 = _Type(0x32, "8h", ["vs16"], "int8[8]", int, 8) + VecU16 = _Type(0x33, "8H", ["vu16"], "uint8[8]", int, 8) + + Bool = _Type(0x34, "b", ["bool", "b"], "bool", int) + TwoBool = _Type(0x35, "2b", ["2b"], "bool[2]", int, 2) + ThreeBool = _Type(0x36, "3b", ["3b"], "bool[3]", int, 3) + FourBool = _Type(0x37, "4b", ["4b"], "bool[4]", int, 4) + VecBool = _Type(0x38, "16b", ["vb"], "bool[16]", int, 16) + + @classmethod + def from_val(cls, value): + for i in cls: + if i.value.id == value & ARRAY_MASK: + return i + raise ValueError(f"Unknown node type {value}") + + +class ServicesMode(enum.Enum): + Operation = "operation" + Debug = "debug" + Test = "test" + Factory = "factory" + + +class Compression(enum.Enum): + Lz77 = "lz77" + None_ = "none" diff --git a/eaapi/crypt.py b/eaapi/crypt.py index e92fac1..55ca2e4 100644 --- a/eaapi/crypt.py +++ b/eaapi/crypt.py @@ -1,49 +1,50 @@ -import binascii -import hashlib -import time -import re - -from Crypto.Cipher import ARC4 - -from .misc import assert_true -from .keys import EA_KEY - - -def new_prng(): - state = 0x41c64e6d - - while True: - x = (state * 0x838c9cda) + 0x6072 - # state = (state * 0x41c64e6d + 0x3039) - # state = (state * 0x41c64e6d + 0x3039) - state = (state * 0xc2a29a69 + 0xd3dc167e) & 0xffffffff - yield (x & 0x7fff0000) | state >> 15 & 0xffff - - -prng = new_prng() - - -def validate_key(info): - match = re.match(r"^(\d)-([0-9a-f]{8})-([0-9a-f]{4})$", info) - assert_true(match, "Invalid eamuse info key") - version = match.group(1) - assert_true(version == "1", f"Unsupported encryption version ({version})") - - seconds = binascii.unhexlify(match.group(2)) # 4 bytes - rng = binascii.unhexlify(match.group(3)) # 2 bytes - return seconds, rng - - -def get_key(prng_=None): - return f"1-{int(time.time()):08x}-{(next(prng_ or prng) & 0xffff):04x}" - - -def ea_symmetric_crypt(data, info): - seconds, rng = validate_key(info) - - key = hashlib.md5(seconds + rng + EA_KEY).digest() - - return ARC4.new(key).encrypt(data) - - -__all__ = ("new_prng", "prng", "validate_key", "get_key", "ea_symmetric_crypt") +import binascii +import hashlib +import time +import re + +from Crypto.Cipher import ARC4 + +from .misc import assert_true +from .keys import EA_KEY + + +def new_prng(): + state = 0x41c64e6d + + while True: + x = (state * 0x838c9cda) + 0x6072 + # state = (state * 0x41c64e6d + 0x3039) + # state = (state * 0x41c64e6d + 0x3039) + state = (state * 0xc2a29a69 + 0xd3dc167e) & 0xffffffff + yield (x & 0x7fff0000) | state >> 15 & 0xffff + + +prng = new_prng() + + +def validate_key(info): + match = re.match(r"^(\d)-([0-9a-f]{8})-([0-9a-f]{4})$", info) + assert_true(match is not None, "Invalid eamuse info key") + assert match is not None + version = match.group(1) + assert_true(version == "1", f"Unsupported encryption version ({version})") + + seconds = binascii.unhexlify(match.group(2)) # 4 bytes + rng = binascii.unhexlify(match.group(3)) # 2 bytes + return seconds, rng + + +def get_key(prng_=None): + return f"1-{int(time.time()):08x}-{(next(prng_ or prng) & 0xffff):04x}" + + +def ea_symmetric_crypt(data, info): + seconds, rng = validate_key(info) + + key = hashlib.md5(seconds + rng + EA_KEY).digest() + + return ARC4.new(key).encrypt(data) + + +__all__ = ("new_prng", "prng", "validate_key", "get_key", "ea_symmetric_crypt") diff --git a/eaapi/decoder.py b/eaapi/decoder.py index 6ce0a52..eaa487a 100644 --- a/eaapi/decoder.py +++ b/eaapi/decoder.py @@ -1,222 +1,241 @@ -import math -import struct -import io - -from html import unescape - -try: - from lxml import etree -except ModuleNotFoundError: - print("W", "lxml not found, XML strings will not be supported") - etree = None - - -from .packer import Packer -from .const import ( - NAME_MAX_COMPRESSED, NAME_MAX_DECOMPRESSED, ATTR, PACK_ALPHABET, END_NODE, END_DOC, ARRAY_BIT, - ENCODING, CONTENT, CONTENT_COMP, CONTENT_FULL, XML_ENCODING, Type -) -from .misc import unpack, py_encoding, assert_true -from .node import XMLNode -from .exception import DecodeError - - -class Decoder: - def __init__(self, packet): - self.stream = io.BytesIO(packet) - self.is_xml_string = packet.startswith(b"<") - self.encoding = None - self.compressed = False - self.has_data = False - self.packer = None - - @classmethod - def decode(cls, packet): - return cls(packet).unpack() - - def read(self, s_format, single=True, align=True): - if s_format == "S": - length = self.read("L") - if self.packer: - self.packer.notify_skipped(length) - return self.stream.read(length) - if s_format == "s": - length = self.read("L") - if self.packer: - self.packer.notify_skipped(length) - raw = self.stream.read(length) - return raw.decode(py_encoding(self.encoding)).rstrip("\0") - - length = struct.calcsize("=" + s_format) - if self.packer and align: - self.stream.seek(self.packer.request_allocation(length)) - data = self.stream.read(length) - assert_true(len(data) == length, "EOF reached", DecodeError) - value = struct.unpack(">" + s_format, data) - return value[0] if single else value - - def _read_node_value(self, node): - fmt = node.type.value.fmt - count = 1 - if node.is_array: - length = struct.calcsize("=" + fmt) - count = self.read("I") // length - values = [] - for _ in range(count): - values.append(self.read(fmt, single=len(fmt) == 1, align=False)) - self.packer.notify_skipped(count * length) - return values - - node.value = self.read(fmt, single=len(fmt) == 1) - - def _read_metadata_name(self): - length = self.read("B") - - if not self.compressed: - if length < 0x80: - assert_true(length >= 0x40, "Invalid name length", DecodeError) - # i.e. length = (length & ~0x40) + 1 - length -= 0x3f - else: - length = (length << 8) | self.read("B") - # i.e. length = (length & ~0x8000) + 0x41 - length -= 0x7fbf - assert_true(length <= NAME_MAX_DECOMPRESSED, "Name length too long", DecodeError) - - name = self.stream.read(length) - assert_true(len(name) == length, "Not enough bytes to read name", DecodeError) - return name.decode(self.encoding) - - out = "" - if length == 0: - return out - - assert_true(length <= NAME_MAX_COMPRESSED, "Name length too long", DecodeError) - - no_bytes = math.ceil((length * 6) / 8) - unpacked = unpack(self.stream.read(no_bytes), 6)[:length] - return "".join(PACK_ALPHABET[i] for i in unpacked) - - def _read_metadata(self, type_): - name = self._read_metadata_name() - node = XMLNode(name, type_, None, encoding=self.encoding) - - while (child := self.read("B")) != END_NODE: - if child == ATTR: - attr = self._read_metadata_name() - assert_true(not attr.startswith("__"), "Invalid binary node name", DecodeError) - # Abuse the array here to maintain order - node.children.append(attr) - else: - node.children.append(self._read_metadata(child)) - is_array = not not (type_ & ARRAY_BIT) - if is_array: - node.value = [] - return node - - def _read_databody(self, node: XMLNode): - self._read_node_value(node) - - children = list(node.children) - node.children = [] - for i in children: - if isinstance(i, XMLNode): - node.children.append(self._read_databody(i)) - else: - node[i] = self.read("s") - - return node - - def _read_magic(self): - magic, contents, enc, enc_comp = struct.unpack(">BBBB", self.stream.read(4)) - - assert_true(magic == 0xA0, "Not a packet", DecodeError) - assert_true(~enc & 0xFF == enc_comp, "Malformed packet header", DecodeError) - assert_true(enc in ENCODING, "Unknown packet encoding", DecodeError) - assert_true(contents in CONTENT, "Invalid packet contents", DecodeError) - self.compressed = contents in CONTENT_COMP - self.has_data = contents in CONTENT_FULL or contents == 0x44 - self.encoding = ENCODING[enc] - - def _read_xml_string(self): - assert_true(etree is not None, "lxml missing", DecodeError) - parser = etree.XMLParser(remove_comments=True) - tree = etree.XML(self.stream.read(), parser) - self.encoding = XML_ENCODING[tree.getroottree().docinfo.encoding.upper()] - self.compressed = False - self.has_data = True - - def walk(node): - attrib = {**node.attrib} - type_str = attrib.pop("__type", "void") - for i in Type: - if type_str in i.value.names: - type_ = i - break - else: - raise ValueError("Invalid node type") - attrib.pop("__size", None) - count = attrib.pop("__count", None) - - is_array = count is not None - count = 1 if count is None else int(count) - - d_type = type_.value - - if d_type.size == 1 and not is_array: - value = d_type._parse(node.text or "") - else: - data = node.text.split(" ") - - value = [] - for i in range(0, len(data), d_type.size): - value.append(d_type._parse(data[i:i+d_type.size])) - if not is_array: - value = value[0] - - xml_node = XMLNode(node.tag, type_, value, encoding=self.encoding) - for i in node.getchildren(): - xml_node.children.append(walk(i)) - - for i in attrib: - xml_node[i] = unescape(attrib[i]) - - return xml_node - - return walk(tree) - - def unpack(self): - try: - return self._unpack() - except struct.error as e: - raise DecodeError(e) - - def _unpack(self): - if self.is_xml_string: - return self._read_xml_string() - - self._read_magic() - - header_len = self.read("I") - start = self.stream.tell() - schema = self._read_metadata(self.read("B")) - assert_true(self.read("B") == END_DOC, "Unterminated schema", DecodeError) - padding = header_len - (self.stream.tell() - start) - assert_true(padding >= 0, "Invalid schema definition", DecodeError) - assert_true(all(i == 0 for i in self.stream.read(padding)), "Invalid schema padding", DecodeError) - - body_len = self.read("I") - start = self.stream.tell() - self.packer = Packer(start) - data = self._read_databody(schema) - self.stream.seek(self.packer.request_allocation(0)) - padding = body_len - (self.stream.tell() - start) - assert_true(padding >= 0, "Data shape not match schema", DecodeError) - assert_true(all(i == 0 for i in self.stream.read(padding)), "Invalid data padding", DecodeError) - - assert_true(self.stream.read(1) == b"", "Trailing data unconsumed", DecodeError) - - return data - - -__all__ = ("Decoder", ) +import math +import struct +import io + +from html import unescape + +try: + from lxml import etree +except ModuleNotFoundError: + print("W", "lxml not found, XML strings will not be supported") + etree = None + + +from .packer import Packer +from .const import ( + NAME_MAX_COMPRESSED, NAME_MAX_DECOMPRESSED, ATTR, PACK_ALPHABET, END_NODE, END_DOC, ARRAY_BIT, + ENCODING, CONTENT, CONTENT_COMP, CONTENT_FULL, XML_ENCODING, DEFAULT_ENCODING, Type +) +from .misc import unpack, py_encoding, assert_true +from .node import XMLNode +from .exception import DecodeError + + +class Decoder: + def __init__(self, packet): + self.stream = io.BytesIO(packet) + self.is_xml_string = packet.startswith(b"<") + self.encoding = None + self.compressed = False + self.has_data = False + self.packer = None + + @classmethod + def decode(cls, packet): + return cls(packet).unpack() + + def read(self, s_format, single=True, align=True): + if s_format == "S": + length = self.read("L") + if self.packer: + self.packer.notify_skipped(length) + return self.stream.read(length) + if s_format == "s": + length = self.read("L") + if self.packer: + self.packer.notify_skipped(length) + raw = self.stream.read(length) + return raw.decode(py_encoding(self.encoding or DEFAULT_ENCODING)).rstrip("\0") + + length = struct.calcsize("=" + s_format) + if self.packer and align: + self.stream.seek(self.packer.request_allocation(length)) + data = self.stream.read(length) + assert_true(len(data) == length, "EOF reached", DecodeError) + value = struct.unpack(">" + s_format, data) + return value[0] if single else value + + def _read_node_value(self, node: XMLNode) -> None: + fmt = node.type.value.fmt + count = 1 + if node.is_array: + length = struct.calcsize("=" + fmt) + nbytes = self.read("I") + assert isinstance(nbytes, int) + count = nbytes // length + values = [] + for _ in range(count): + values.append(self.read(fmt, single=len(fmt) == 1, align=False)) + + assert self.packer is not None + self.packer.notify_skipped(count * length) + node.value = values + else: + node.value = self.read(fmt, single=len(fmt) == 1) + + def _read_metadata_name(self) -> str: + length = self.read("B") + assert isinstance(length, int) + + if not self.compressed: + if length < 0x80: + assert_true(length >= 0x40, "Invalid name length", DecodeError) + # i.e. length = (length & ~0x40) + 1 + length -= 0x3f + else: + extra = self.read("B") + assert isinstance(extra, int) + length = (length << 8) | extra + # i.e. length = (length & ~0x8000) + 0x41 + length -= 0x7fbf + assert_true(length <= NAME_MAX_DECOMPRESSED, "Name length too long", DecodeError) + + name = self.stream.read(length) + assert_true(len(name) == length, "Not enough bytes to read name", DecodeError) + return name.decode(self.encoding or "") + + out = "" + if length == 0: + return out + + assert_true(length <= NAME_MAX_COMPRESSED, "Name length too long", DecodeError) + + no_bytes = math.ceil((length * 6) / 8) + unpacked = unpack(self.stream.read(no_bytes), 6)[:length] + return "".join(PACK_ALPHABET[i] for i in unpacked) + + def _read_metadata(self, type_): + name = self._read_metadata_name() + node = XMLNode(name, type_, None, encoding=self.encoding or DEFAULT_ENCODING) + + while (child := self.read("B")) != END_NODE: + if child == ATTR: + attr = self._read_metadata_name() + assert_true(not attr.startswith("__"), "Invalid binary node name", DecodeError) + # Abuse the array here to maintain order + node.children.append(attr) + else: + node.children.append(self._read_metadata(child)) + + if type_ & ARRAY_BIT: + node.value = [] + return node + + def _read_databody(self, node: XMLNode): + self._read_node_value(node) + + children = list(node.children) + node.children = [] + for i in children: + if isinstance(i, XMLNode): + node.children.append(self._read_databody(i)) + else: + node[i] = self.read("s") + + return node + + def _read_magic(self): + magic, contents, enc, enc_comp = struct.unpack(">BBBB", self.stream.read(4)) + + assert_true(magic == 0xA0, "Not a packet", DecodeError) + assert_true(~enc & 0xFF == enc_comp, "Malformed packet header", DecodeError) + assert_true(enc in ENCODING, "Unknown packet encoding", DecodeError) + assert_true(contents in CONTENT, "Invalid packet contents", DecodeError) + self.compressed = contents in CONTENT_COMP + self.has_data = contents in CONTENT_FULL or contents == 0x44 + self.encoding = ENCODING[enc] + + def _read_xml_string(self): + assert_true(etree is not None, "lxml missing", DecodeError) + assert etree is not None + + parser = etree.XMLParser(remove_comments=True) + tree = etree.XML(self.stream.read(), parser) + self.encoding = XML_ENCODING[tree.getroottree().docinfo.encoding.upper()] + self.compressed = False + self.has_data = True + + def walk(node): + attrib = {**node.attrib} + type_str = attrib.pop("__type", "void") + for i in Type: + if type_str in i.value.names: + type_ = i + break + else: + raise ValueError("Invalid node type") + attrib.pop("__size", None) + count = attrib.pop("__count", None) + + is_array = count is not None + count = 1 if count is None else int(count) + + d_type = type_.value + + if d_type.size == 1 and not is_array: + try: + value = d_type._parse(node.text or "") + except ValueError: + print(f"Failed to parse {node.tag} ({d_type.names[0]}): {repr(node.text)}") + raise + else: + data = node.text.split(" ") + + value = [] + for i in range(0, len(data), d_type.size): + value.append(d_type._parse(data[i:i+d_type.size])) + if not is_array: + value = value[0] + + xml_node = XMLNode(node.tag, type_, value, encoding=self.encoding or DEFAULT_ENCODING) + for i in node.getchildren(): + xml_node.children.append(walk(i)) + + for i in attrib: + xml_node[i] = unescape(attrib[i]) + + return xml_node + + return walk(tree) + + def unpack(self): + try: + return self._unpack() + except struct.error as e: + raise DecodeError(e) + + def _unpack(self): + if self.is_xml_string: + return self._read_xml_string() + + self._read_magic() + + header_len = self.read("I") + assert isinstance(header_len, int) + start = self.stream.tell() + schema = self._read_metadata(self.read("B")) + assert_true(self.read("B") == END_DOC, "Unterminated schema", DecodeError) + padding = header_len - (self.stream.tell() - start) + assert_true(padding >= 0, "Invalid schema definition", DecodeError) + assert_true( + all(i == 0 for i in self.stream.read(padding)), "Invalid schema padding", DecodeError + ) + + body_len = self.read("I") + assert isinstance(body_len, int) + start = self.stream.tell() + self.packer = Packer(start) + data = self._read_databody(schema) + self.stream.seek(self.packer.request_allocation(0)) + padding = body_len - (self.stream.tell() - start) + assert_true(padding >= 0, "Data shape not match schema", DecodeError) + assert_true( + all(i == 0 for i in self.stream.read(padding)), "Invalid data padding", DecodeError + ) + + assert_true(self.stream.read(1) == b"", "Trailing data unconsumed", DecodeError) + + return data + + +__all__ = ("Decoder", ) diff --git a/eaapi/encoder.py b/eaapi/encoder.py index 6a79af6..4cb8acd 100644 --- a/eaapi/encoder.py +++ b/eaapi/encoder.py @@ -1,157 +1,166 @@ -import struct -import io - -from .packer import Packer -from .misc import pack, py_encoding, assert_true -from .const import ( - PACK_ALPHABET, DEFAULT_ENCODING, ENCODING, ENCODING_BACK, NAME_MAX_DECOMPRESSED, ARRAY_BIT, - ATTR, END_NODE, END_DOC, CONTENT_COMP_FULL, CONTENT_COMP_SCHEMA, CONTENT_ASCII_FULL, - CONTENT_ASCII_SCHEMA -) -from .exception import EncodeError - - -class Encoder: - def __init__(self, encoding=DEFAULT_ENCODING): - self.stream = io.BytesIO() - assert_true(encoding in ENCODING_BACK, f"Unknown encoding {encoding}", EncodeError) - self.encoding = ENCODING_BACK[encoding] - self.packer = None - self._compressed = False - - @classmethod - def encode(cls, tree, xml_string=False): - if xml_string: - return tree.to_str(pretty=False).encode(tree.encoding) - encoder = cls(tree.encoding) - encoder.pack(tree) - return encoder.stream.getvalue() - - def align(self, to=4, pad_char=b"\0"): - if to < 2: - return - if (dist := self.stream.tell() % to) == 0: - return - self.stream.write(pad_char * (to - dist)) - - def write(self, s_format, value, single=True): - if s_format == "S": - self.write("L", len(value)) - self.stream.write(value) - self.packer.notify_skipped(len(value)) - return - if s_format == "s": - value = value.encode(py_encoding(ENCODING[self.encoding])) + b"\0" - self.write("L", len(value)) - self.stream.write(value) - self.packer.notify_skipped(len(value)) - return - - length = struct.calcsize("=" + s_format) - - if not isinstance(value, list): - value = [value] - count = len(value) - if count != 1: - self.write("L", count * length) - self.packer.notify_skipped(count * length) - - for x in value: - if self.packer and count == 1: - self.stream.seek(self.packer.request_allocation(length)) - - try: - if single: - self.stream.write(struct.pack(f">{s_format}", x)) - else: - self.stream.write(struct.pack(f">{s_format}", *x)) - except struct.error: - raise ValueError(f"Failed to pack {s_format}: {repr(x)}") - - def _write_node_value(self, type_, value): - fmt = type_.value.fmt - if fmt == "s": - self.write("s", value) - else: - self.write(fmt, value, single=len(fmt) == 1) - - def _write_metadata_name(self, name): - if not self._compressed: - assert_true(len(name) <= NAME_MAX_DECOMPRESSED, "Name length too long", EncodeError) - if len(name) > 64: - self.write("H", len(name) + 0x7fbf) - else: - self.write("B", len(name) + 0x3f) - self.stream.write(name.encode(py_encoding(ENCODING[self.encoding]))) - return - - assert_true(all(i in PACK_ALPHABET for i in name), f"Invalid schema name {name} (invalid chars)", EncodeError) - assert_true(len(name) < 256, f"Invalid schema name {name} (too long)", EncodeError) - self.write("B", len(name)) - if len(name) == 0: - return - - name = bytearray(PACK_ALPHABET.index(i) for i in name) - self.stream.write(pack(name, 6)) - - def _write_metadata(self, node): - self.write("B", node.type.value.id | (ARRAY_BIT if node.is_array else 0x00)) - self._write_metadata_name(node.name) - - for attr in node.attributes: - self.write("B", ATTR) - self._write_metadata_name(attr) - for child in node: - self._write_metadata(child) - self.write("B", END_NODE) - - def _write_databody(self, data): - self._write_node_value(data.type, data.value) - - for attr in data.attributes: - self.align() - self.write("s", data[attr]) - for child in data: - self._write_databody(child) - - def _write_magic(self, has_data=True): - if has_data: - contents = CONTENT_COMP_FULL if self._compressed else CONTENT_ASCII_FULL - else: - contents = CONTENT_COMP_SCHEMA if self._compressed else CONTENT_ASCII_SCHEMA - - enc_comp = ~self.encoding & 0xFF - self.stream.write(struct.pack(">BBBB", 0xA0, contents, self.encoding, enc_comp)) - - def pack(self, node): - try: - return self._pack(node) - except struct.error as e: - return EncodeError(e) - - def _pack(self, node): - self._compressed = node.can_compress # Opportunically compress if we can - self._write_magic() - - schema_start = self.stream.tell() - self.write("I", 0) - self._write_metadata(node) - self.write("B", END_DOC) - self.align() - schema_end = self.stream.tell() - self.stream.seek(schema_start) - self.write("I", schema_end - schema_start - 4) - - self.stream.seek(schema_end) - self.write("I", 0) - self.packer = Packer(self.stream.tell()) - self._write_databody(node) - self.stream.seek(0, io.SEEK_END) - self.align() - node_end = self.stream.tell() - self.stream.seek(schema_end) - self.packer = None - self.write("I", node_end - schema_end - 4) - - -__all__ = ("Encoder", ) +import struct +import io + +from .packer import Packer +from .misc import pack, py_encoding, assert_true +from .const import ( + PACK_ALPHABET, DEFAULT_ENCODING, ENCODING, ENCODING_BACK, NAME_MAX_DECOMPRESSED, ARRAY_BIT, + ATTR, END_NODE, END_DOC, CONTENT_COMP_FULL, CONTENT_COMP_SCHEMA, CONTENT_ASCII_FULL, + CONTENT_ASCII_SCHEMA +) +from .exception import EncodeError +from .node import XMLNode + + +class Encoder: + def __init__(self, encoding=DEFAULT_ENCODING): + self.stream = io.BytesIO() + assert_true(encoding in ENCODING_BACK, f"Unknown encoding {encoding}", EncodeError) + self.encoding = ENCODING_BACK[encoding] + self.packer = None + self._compressed = False + + @classmethod + def encode(cls, tree, xml_string=False): + if xml_string: + return tree.to_str(pretty=False).encode(tree.encoding) + encoder = cls(tree.encoding) + encoder.pack(tree) + return encoder.stream.getvalue() + + def align(self, to=4, pad_char=b"\0"): + if to < 2: + return + if (dist := self.stream.tell() % to) == 0: + return + self.stream.write(pad_char * (to - dist)) + + def write(self, s_format, value, single=True): + if s_format == "S": + assert self.packer is not None + self.write("L", len(value)) + self.stream.write(value) + self.packer.notify_skipped(len(value)) + return + if s_format == "s": + assert self.packer is not None + value = value.encode(py_encoding(ENCODING[self.encoding])) + b"\0" + self.write("L", len(value)) + self.stream.write(value) + self.packer.notify_skipped(len(value)) + return + + length = struct.calcsize("=" + s_format) + + if isinstance(value, list): + assert self.packer is not None + self.write("L", len(value) * length) + self.packer.notify_skipped(len(value) * length) + + for x in value: + try: + if single: + self.stream.write(struct.pack(f">{s_format}", x)) + else: + self.stream.write(struct.pack(f">{s_format}", *x)) + except struct.error: + raise ValueError(f"Failed to pack {s_format}: {repr(x)}") + else: + if self.packer: + self.stream.seek(self.packer.request_allocation(length)) + + try: + if single: + self.stream.write(struct.pack(f">{s_format}", value)) + else: + self.stream.write(struct.pack(f">{s_format}", *value)) + except struct.error: + raise ValueError(f"Failed to pack {s_format}: {repr(value)}") + + def _write_node_value(self, type_, value): + fmt = type_.value.fmt + if fmt == "s": + self.write("s", value) + else: + self.write(fmt, value, single=len(fmt) == 1) + + def _write_metadata_name(self, name): + if not self._compressed: + assert_true(len(name) <= NAME_MAX_DECOMPRESSED, "Name length too long", EncodeError) + if len(name) > 64: + self.write("H", len(name) + 0x7fbf) + else: + self.write("B", len(name) + 0x3f) + self.stream.write(name.encode(py_encoding(ENCODING[self.encoding]))) + return + + assert_true(all(i in PACK_ALPHABET for i in name), f"Invalid schema name {name} (invalid chars)", EncodeError) + assert_true(len(name) < 256, f"Invalid schema name {name} (too long)", EncodeError) + self.write("B", len(name)) + if len(name) == 0: + return + + name = bytearray(PACK_ALPHABET.index(i) for i in name) + self.stream.write(pack(name, 6)) + + def _write_metadata(self, node): + self.write("B", node.type.value.id | (ARRAY_BIT if node.is_array else 0x00)) + self._write_metadata_name(node.name) + + for attr in node.attributes: + self.write("B", ATTR) + self._write_metadata_name(attr) + for child in node: + self._write_metadata(child) + self.write("B", END_NODE) + + def _write_databody(self, data: XMLNode): + self._write_node_value(data.type, data.value) + + for attr in data.attributes: + self.align() + self.write("s", data[attr]) + for child in data: + self._write_databody(child) + + def _write_magic(self, has_data=True): + if has_data: + contents = CONTENT_COMP_FULL if self._compressed else CONTENT_ASCII_FULL + else: + contents = CONTENT_COMP_SCHEMA if self._compressed else CONTENT_ASCII_SCHEMA + + enc_comp = ~self.encoding & 0xFF + self.stream.write(struct.pack(">BBBB", 0xA0, contents, self.encoding, enc_comp)) + + def pack(self, node): + try: + return self._pack(node) + except struct.error as e: + return EncodeError(e) + + def _pack(self, node): + self._compressed = node.can_compress # Opportunically compress if we can + self._write_magic() + + schema_start = self.stream.tell() + self.write("I", 0) + self._write_metadata(node) + self.write("B", END_DOC) + self.align() + schema_end = self.stream.tell() + self.stream.seek(schema_start) + self.write("I", schema_end - schema_start - 4) + + self.stream.seek(schema_end) + self.write("I", 0) + self.packer = Packer(self.stream.tell()) + self._write_databody(node) + self.stream.seek(0, io.SEEK_END) + self.align() + node_end = self.stream.tell() + self.stream.seek(schema_end) + self.packer = None + self.write("I", node_end - schema_end - 4) + + +__all__ = ("Encoder", ) diff --git a/eaapi/exception.py b/eaapi/exception.py index d544a17..e613433 100644 --- a/eaapi/exception.py +++ b/eaapi/exception.py @@ -1,22 +1,34 @@ -class EAAPIException(Exception): - pass - - -class CheckFailed(EAAPIException): - pass - - -class InvalidCard(CheckFailed): - pass - - -class DecodeError(CheckFailed): - pass - - -class EncodeError(CheckFailed): - pass - - -class InvalidModel(EAAPIException): - pass +class EAAPIException(Exception): + pass + + +class CheckFailed(EAAPIException): + pass + + +class InvalidCard(CheckFailed): + pass + + +class DecodeError(CheckFailed): + pass + + +class EncodeError(CheckFailed): + pass + + +class InvalidModel(EAAPIException): + pass + + +class XMLStrutureError(EAAPIException): + pass + + +class NodeNotFound(XMLStrutureError, IndexError): + pass + + +class AttributeNotFound(XMLStrutureError, KeyError): + pass diff --git a/eaapi/lz77.py b/eaapi/lz77.py index d6a3561..0fd2962 100644 --- a/eaapi/lz77.py +++ b/eaapi/lz77.py @@ -1,135 +1,135 @@ -from .misc import assert_true - - -WINDOW_SIZE = 0x1000 -WINDOW_MASK = WINDOW_SIZE - 1 -THRESHOLD = 3 -INPLACE_THRESHOLD = 0xA -LOOK_RANGE = 0x200 -MAX_LEN = 0xF + THRESHOLD -MAX_BUFFER = 0x10 + 1 - - -def match_current(window, pos, max_len, data, dpos): - length = 0 - while ( - dpos + length < len(data) - and length < max_len - and window[(pos + length) & WINDOW_MASK] == data[dpos + length] - and length < MAX_LEN - ): - length += 1 - return length - - -def match_window(window, pos, data, d_pos): - max_pos = 0 - max_len = 0 - for i in range(THRESHOLD, LOOK_RANGE): - length = match_current(window, (pos - i) & WINDOW_MASK, i, data, d_pos) - if length >= INPLACE_THRESHOLD: - return (i, length) - if length >= THRESHOLD: - max_pos = i - max_len = length - if max_len >= THRESHOLD: - return (max_pos, max_len) - return None - - -def lz77_compress(data): - output = bytearray() - window = [0] * WINDOW_SIZE - current_pos = 0 - current_window = 0 - current_buffer = 0 - flag_byte = 0 - bit = 0 - buffer = [0] * MAX_BUFFER - pad = 3 - while current_pos < len(data): - flag_byte = 0 - current_buffer = 0 - for bit_pos in range(8): - if current_pos >= len(data): - pad = 0 - flag_byte = flag_byte >> (8 - bit_pos) - buffer[current_buffer] = 0 - buffer[current_buffer + 1] = 0 - current_buffer += 2 - break - else: - found = match_window(window, current_window, data, current_pos) - if found is not None and found[1] >= THRESHOLD: - pos, length = found - - byte1 = pos >> 4 - byte2 = (((pos & 0x0F) << 4) | ((length - THRESHOLD) & 0x0F)) - buffer[current_buffer] = byte1 - buffer[current_buffer + 1] = byte2 - current_buffer += 2 - bit = 0 - for _ in range(length): - window[current_window & WINDOW_MASK] = data[current_pos] - current_pos += 1 - current_window += 1 - else: - buffer[current_buffer] = data[current_pos] - window[current_window] = data[current_pos] - current_pos += 1 - current_window += 1 - current_buffer += 1 - bit = 1 - - flag_byte = (flag_byte >> 1) | ((bit & 1) << 7) - current_window = current_window & WINDOW_MASK - - assert_true(current_buffer < MAX_BUFFER, f"current buffer {current_buffer} > max buffer {MAX_BUFFER}") - - output.append(flag_byte) - for i in range(current_buffer): - output.append(buffer[i]) - for _ in range(pad): - output.append(0) - - return bytes(output) - - -def lz77_decompress(data): - output = bytearray() - cur_byte = 0 - window = [0] * WINDOW_SIZE - window_cursor = 0 - - while cur_byte < len(data): - flag = data[cur_byte] - cur_byte += 1 - - for i in range(8): - if (flag >> i) & 1 == 1: - output.append(data[cur_byte]) - window[window_cursor] = data[cur_byte] - window_cursor = (window_cursor + 1) & WINDOW_MASK - cur_byte += 1 - else: - w = ((data[cur_byte]) << 8) | (data[cur_byte + 1]) - if w == 0: - return bytes(output) - - cur_byte += 2 - position = ((window_cursor - (w >> 4)) & WINDOW_MASK) - length = (w & 0x0F) + THRESHOLD - - for _ in range(length): - b = window[position & WINDOW_MASK] - output.append(b) - window[window_cursor] = b - window_cursor = (window_cursor + 1) & WINDOW_MASK - position += 1 - - return bytes(output) - - -__all__ = ( - "lz77_compress", "lz77_decompress" -) +from .misc import assert_true + + +WINDOW_SIZE = 0x1000 +WINDOW_MASK = WINDOW_SIZE - 1 +THRESHOLD = 3 +INPLACE_THRESHOLD = 0xA +LOOK_RANGE = 0x200 +MAX_LEN = 0xF + THRESHOLD +MAX_BUFFER = 0x10 + 1 + + +def match_current(window: bytes, pos: int, max_len: int, data: bytes, dpos: int) -> int: + length = 0 + while ( + dpos + length < len(data) + and length < max_len + and window[(pos + length) & WINDOW_MASK] == data[dpos + length] + and length < MAX_LEN + ): + length += 1 + return length + + +def match_window(window: bytes, pos: int, data: bytes, d_pos: int) -> None | tuple[int, int]: + max_pos = 0 + max_len = 0 + for i in range(THRESHOLD, LOOK_RANGE): + length = match_current(window, (pos - i) & WINDOW_MASK, i, data, d_pos) + if length >= INPLACE_THRESHOLD: + return (i, length) + if length >= THRESHOLD: + max_pos = i + max_len = length + if max_len >= THRESHOLD: + return (max_pos, max_len) + return None + + +def lz77_compress(data: bytes) -> bytes: + output = bytearray() + window = bytearray(WINDOW_SIZE) + current_pos = 0 + current_window = 0 + current_buffer = 0 + flag_byte = 0 + bit = 0 + buffer = [0] * MAX_BUFFER + pad = 3 + while current_pos < len(data): + flag_byte = 0 + current_buffer = 0 + for bit_pos in range(8): + if current_pos >= len(data): + pad = 0 + flag_byte = flag_byte >> (8 - bit_pos) + buffer[current_buffer] = 0 + buffer[current_buffer + 1] = 0 + current_buffer += 2 + break + else: + found = match_window(window, current_window, data, current_pos) + if found is not None and found[1] >= THRESHOLD: + pos, length = found + + byte1 = pos >> 4 + byte2 = (((pos & 0x0F) << 4) | ((length - THRESHOLD) & 0x0F)) + buffer[current_buffer] = byte1 + buffer[current_buffer + 1] = byte2 + current_buffer += 2 + bit = 0 + for _ in range(length): + window[current_window & WINDOW_MASK] = data[current_pos] + current_pos += 1 + current_window += 1 + else: + buffer[current_buffer] = data[current_pos] + window[current_window] = data[current_pos] + current_pos += 1 + current_window += 1 + current_buffer += 1 + bit = 1 + + flag_byte = (flag_byte >> 1) | ((bit & 1) << 7) + current_window = current_window & WINDOW_MASK + + assert_true(current_buffer < MAX_BUFFER, f"current buffer {current_buffer} > max buffer {MAX_BUFFER}") + + output.append(flag_byte) + for i in range(current_buffer): + output.append(buffer[i]) + for _ in range(pad): + output.append(0) + + return bytes(output) + + +def lz77_decompress(data: bytes) -> bytes: + output = bytearray() + cur_byte = 0 + window = bytearray(WINDOW_SIZE) + window_cursor = 0 + + while cur_byte < len(data): + flag = data[cur_byte] + cur_byte += 1 + + for i in range(8): + if (flag >> i) & 1 == 1: + output.append(data[cur_byte]) + window[window_cursor] = data[cur_byte] + window_cursor = (window_cursor + 1) & WINDOW_MASK + cur_byte += 1 + else: + w = ((data[cur_byte]) << 8) | (data[cur_byte + 1]) + if w == 0: + return bytes(output) + + cur_byte += 2 + position = ((window_cursor - (w >> 4)) & WINDOW_MASK) + length = (w & 0x0F) + THRESHOLD + + for _ in range(length): + b = window[position & WINDOW_MASK] + output.append(b) + window[window_cursor] = b + window_cursor = (window_cursor + 1) & WINDOW_MASK + position += 1 + + return bytes(output) + + +__all__ = ( + "lz77_compress", "lz77_decompress" +) diff --git a/eaapi/misc.py b/eaapi/misc.py index 66382b0..2bb0b1b 100644 --- a/eaapi/misc.py +++ b/eaapi/misc.py @@ -1,70 +1,73 @@ -import inspect -import re - -from .exception import CheckFailed, InvalidModel - - -def assert_true(check, reason, exc=CheckFailed): - if not check: - line = inspect.stack()[1].code_context - print() - print("\n".join(line)) - raise exc(reason) - - -def py_encoding(name): - if name.startswith("shift-jis"): - return "shift-jis" - return name - - -def parse_model(model): - # e.g. KFC:J:A:A:2019020600 - match = re.match(r"^([A-Z0-9]{3}):([A-Z]):([A-Z]):([A-Z])(?::(\d{10}))?$", model) - if match is None: - raise InvalidModel - gamecode, dest, spec, rev, datecode = match.groups() - return gamecode, dest, spec, rev, datecode - - -def pack(data, width): - assert_true(1 <= width <= 8, "Invalid pack size") - assert_true(all(i < (1 << width) for i in data), "Data too large for packing") - bit_buf = in_buf = 0 - output = bytearray() - for i in data: - bit_buf |= i << (8 - width) - shift = min(8 - in_buf, width) - bit_buf <<= shift - in_buf += shift - if in_buf == 8: - output.append(bit_buf >> 8) - in_buf = width - shift - bit_buf = (bit_buf & 0xff) << in_buf - - if in_buf: - output.append(bit_buf >> in_buf) - - return bytes(output) - - -def unpack(data, width): - assert_true(1 <= width <= 8, "Invalid pack size") - bit_buf = in_buf = 0 - output = bytearray() - for i in data: - bit_buf |= i - bit_buf <<= width - in_buf - in_buf += 8 - while in_buf >= width: - output.append(bit_buf >> 8) - in_buf -= width - bit_buf = (bit_buf & 0xff) << min(width, in_buf) - - if in_buf: - output.append(bit_buf >> (8 + in_buf - width)) - - return bytes(output) - - -__all__ = ("assert_true", "py_encoding", "parse_model", "pack", "unpack") +import inspect +import re + +from typing import Type + +from .exception import CheckFailed, InvalidModel + + +def assert_true(check: bool, reason: str, exc: Type[Exception] = CheckFailed): + if not check: + line = inspect.stack()[1].code_context + if line: + print() + print("\n".join(line)) + raise exc(reason) + + +def py_encoding(name: str) -> str: + if name.startswith("shift-jis"): + return "shift-jis" + return name + + +def parse_model(model: str) -> tuple[str, str, str, str, str]: + # e.g. KFC:J:A:A:2019020600 + match = re.match(r"^([A-Z0-9]{3}):([A-Z]):([A-Z]):([A-Z])(?::(\d{10}))?$", model) + if match is None: + raise InvalidModel + gamecode, dest, spec, rev, datecode = match.groups() + return gamecode, dest, spec, rev, datecode + + +def pack(data, width: int) -> bytes: + assert_true(1 <= width <= 8, "Invalid pack size") + assert_true(all(i < (1 << width) for i in data), "Data too large for packing") + bit_buf = in_buf = 0 + output = bytearray() + for i in data: + bit_buf |= i << (8 - width) + shift = min(8 - in_buf, width) + bit_buf <<= shift + in_buf += shift + if in_buf == 8: + output.append(bit_buf >> 8) + in_buf = width - shift + bit_buf = (bit_buf & 0xff) << in_buf + + if in_buf: + output.append(bit_buf >> in_buf) + + return bytes(output) + + +def unpack(data, width: int) -> bytes: + assert_true(1 <= width <= 8, "Invalid pack size") + bit_buf = in_buf = 0 + output = bytearray() + for i in data: + bit_buf |= i + bit_buf <<= width - in_buf + in_buf += 8 + while in_buf >= width: + output.append(bit_buf >> 8) + in_buf -= width + bit_buf = (bit_buf & 0xff) << min(width, in_buf) + + if in_buf: + output.append(bit_buf >> (8 + in_buf - width)) + + return bytes(output) + + +__all__ = ("assert_true", "py_encoding", "parse_model", "pack", "unpack") diff --git a/eaapi/node.py b/eaapi/node.py index 7c7d9e3..3120f9c 100644 --- a/eaapi/node.py +++ b/eaapi/node.py @@ -1,155 +1,169 @@ -import binascii -import re - -from html import escape - -from .misc import assert_true -from .const import DEFAULT_ENCODING, NAME_MAX_COMPRESSED, XML_ENCODING_BACK, Type - - -class XMLNode: - def __init__(self, name, type_, value, attributes=None, encoding=DEFAULT_ENCODING): - self.name = name - self.type = type_ if isinstance(type_, Type) else Type.from_val(type_) - self.value = value - self.children = [] - self.attributes = {} - if attributes is not None: - for i in attributes: - self.attributes[i] = attributes[i] - self.encoding = encoding or DEFAULT_ENCODING - assert_true(encoding in XML_ENCODING_BACK, "Invalid encoding") - - @classmethod - def void(cls, __name, **attributes): - return cls(__name, Type.Void, (), attributes) - - @property - def is_array(self): - return isinstance(self.value, list) - - @property - def can_compress(self): - return ( - (len(self.name) <= NAME_MAX_COMPRESSED) - and all(i.can_compress for i in self.children) - ) - - def _xpath(self, attr, path): - if path: - child = path.pop(0) - for i in self.children: - if i.name == child: - return i._xpath(attr, path) - raise IndexError - if not attr: - return self - if attr in self.attributes: - return self.attributes[attr] - raise IndexError - - def xpath(self, path): - match = re.match(r"^(?:@([\w:]+)/)?((?:[\w:]+(?:/|$))+)", path) - if match is None: - raise ValueError - attr = match.group(1) - path = match.group(2).split("/") - return self._xpath(attr, path) - - def append(self, __name, __type=Type.Void, __value=(), **attributes): - child = XMLNode(__name, __type, __value, attributes) - self.children.append(child) - return child - - def __len__(self): - return len(self.children) - - def __iter__(self): - for i in self.children: - yield i - - def get(self, name, default=None): - try: - return self[name] - except IndexError: - return default - except KeyError: - return default - - def __getitem__(self, name): - if isinstance(name, int): - return self.children[name] - return self.attributes[name] - - def __setitem__(self, name, value): - self.attributes[name] = value - - def to_str(self, pretty=False): - return ( - f'' - + ("\n" if pretty else "") - + self._to_str(pretty) - ) - - def __str__(self): - return self.to_str(pretty=True) - - def _value_str(self, value): - if isinstance(value, list): - return " ".join(map(self._value_str, value)) - if self.type == Type.Blob: - return binascii.hexlify(value).decode() - if self.type == Type.IPv4: - return f"{value[0]}.{value[1]}.{value[2]}.{value[3]}" - if self.type in (Type.Float, Type.TwoFloat, Type.ThreeFloat): - return f"{value:.6f}" - if self.type == Type.Str: - return escape(str(value)) - - return str(value) - - def _to_str(self, pretty, indent=0): - if not pretty: - indent = 0 - nl = "\n" if pretty else "" - tag = f"{' ' * indent}<{self.name}" - - if self.type != Type.Void: - tag += f" __type=\"{self.type.value.names[0]}\"" - if self.type == Type.Blob: - tag += f" __size=\"{len(self.value)}\"" - if self.is_array: - tag += f" __count=\"{len(self.value)}\"" - - attributes = " ".join(f"{i}=\"{escape(j)}\"" for i, j in self.attributes.items()) - if attributes: - tag += " " + attributes - tag += ">" - if self.value is not None and self.type != Type.Void: - if self.is_array: - tag += " ".join(map(self._value_str, self.value)) - else: - tag += self._value_str(self.value) - elif not self.children: - return tag[:-1] + (" " if pretty else "") + "/>" - - for i in self.children: - if isinstance(i, XMLNode): - tag += nl + i._to_str(pretty, indent + 4) - if self.children: - tag += nl + " " * indent - tag += f"" - return tag - - def __eq__(self, other): - return ( - isinstance(other, XMLNode) - and self.name == other.name - and self.type == other.type - and self.value == other.value - and len(self.children) == len(other.children) - and all(i == j for i, j in zip(self.children, other.children)) - ) - - -__all__ = ("XMLNode", ) +import binascii +import re + +from typing import Generator, Any + +from html import escape + +from .misc import assert_true +from .const import DEFAULT_ENCODING, NAME_MAX_COMPRESSED, XML_ENCODING_BACK, Type +from .exception import XMLStrutureError, NodeNotFound, AttributeNotFound + + +class XMLNode: + def __init__(self, name, type_, value, attributes=None, encoding=DEFAULT_ENCODING): + self.name = name + self.type = type_ if isinstance(type_, Type) else Type.from_val(type_) + self.value: Any = value # TODO: A stricter way to do this. Subclassing? + self.children = [] + self.attributes = {} + if attributes is not None: + for i in attributes: + self.attributes[i] = attributes[i] + self.encoding = encoding or DEFAULT_ENCODING + assert_true(encoding in XML_ENCODING_BACK, "Invalid encoding") + + @classmethod + def void(cls, __name, **attributes): + return cls(__name, Type.Void, (), attributes) + + @property + def is_array(self): + return isinstance(self.value, list) + + @property + def can_compress(self): + return ( + (len(self.name) <= NAME_MAX_COMPRESSED) + and all(i.can_compress for i in self.children) + ) + + def _xpath(self, attr, path): + if path: + child = path.pop(0) + for i in self.children: + if i.name == child: + return i._xpath(attr, path) + raise NodeNotFound + if not attr: + return self + if attr in self.attributes: + return self.attributes[attr] + raise AttributeNotFound + + def xpath(self, path): + match = re.match(r"^(?:@([\w:]+)/)?((?:[\w:]+(?:/|$))+)", path) + if match is None: + raise ValueError + attr = match.group(1) + path = match.group(2).split("/") + return self._xpath(attr, path) + + def append(self, __name, __type=Type.Void, __value=(), **attributes): + child = XMLNode(__name, __type, __value, attributes) + self.children.append(child) + return child + + def __len__(self): + return len(self.children) + + def __iter__(self) -> Generator["XMLNode", None, None]: + for i in self.children: + yield i + + def get(self, name, default=None): + try: + return self[name] + except XMLStrutureError: + return default + + def __getitem__(self, name): + if isinstance(name, int): + try: + return self.children[name] + except IndexError: + raise NodeNotFound + try: + return self.attributes[name] + except KeyError: + raise AttributeNotFound + + def __setitem__(self, name, value): + self.attributes[name] = value + + def to_str(self, pretty=False): + return ( + f'' + + ("\n" if pretty else "") + + self._to_str(pretty) + ) + + def __str__(self): + return self.to_str(pretty=True) + + def _value_str(self, value: Any) -> str: + if isinstance(value, list): + return " ".join(map(self._value_str, value)) + if self.type == Type.Blob: + return binascii.hexlify(value).decode() + if self.type == Type.IPv4 or self.type == Type.IPv4_Int: + if isinstance(value, int): + value = ( + (value >> 24) & 0xff, + (value >> 16) & 0xff, + (value >> 8) & 0xff, + (value >> 0) & 0xff, + ) + return f"{value[0]}.{value[1]}.{value[2]}.{value[3]}" + if self.type in (Type.Float, Type.TwoFloat, Type.ThreeFloat): + return f"{value:.6f}" + if self.type == Type.Str: + return escape(str(value)) + + return str(value) + + def _to_str(self, pretty, indent=0): + if not pretty: + indent = 0 + nl = "\n" if pretty else "" + tag = f"{' ' * indent}<{self.name}" + + if self.type != Type.Void: + tag += f" __type=\"{self.type.value.names[0]}\"" + if self.type == Type.Blob: + tag += f" __size=\"{len(self.value)}\"" + if self.is_array: + tag += f" __count=\"{len(self.value)}\"" + + attributes = " ".join(f"{i}=\"{escape(j)}\"" for i, j in self.attributes.items()) + if attributes: + tag += " " + attributes + tag += ">" + if self.value is not None and self.type != Type.Void: + if self.is_array: + tag += " ".join(map(self._value_str, self.value)) + else: + tag += self._value_str(self.value) + elif not self.children: + return tag[:-1] + (" " if pretty else "") + "/>" + + for i in self.children: + if isinstance(i, XMLNode): + tag += nl + i._to_str(pretty, indent + 4) + if self.children: + tag += nl + " " * indent + tag += f"" + return tag + + def __eq__(self, other): + return ( + isinstance(other, XMLNode) + and self.name == other.name + and self.type == other.type + and self.value == other.value + and len(self.children) == len(other.children) + and all(i == j for i, j in zip(self.children, other.children)) + ) + + +__all__ = ("XMLNode", ) diff --git a/eaapi/packer.py b/eaapi/packer.py index a8b3733..6af372d 100644 --- a/eaapi/packer.py +++ b/eaapi/packer.py @@ -1,41 +1,41 @@ -import math - - -class Packer: - def __init__(self, offset=0): - self._word_cursor = offset - self._short_cursor = offset - self._byte_cursor = offset - self._boundary = offset % 4 - - def _next_block(self): - self._word_cursor += 4 - return self._word_cursor - 4 - - def request_allocation(self, size): - if size == 0: - return self._word_cursor - elif size == 1: - if self._byte_cursor % 4 == self._boundary: - self._byte_cursor = self._next_block() + 1 - else: - self._byte_cursor += 1 - return self._byte_cursor - 1 - elif size == 2: - if self._short_cursor % 4 == self._boundary: - self._short_cursor = self._next_block() + 2 - else: - self._short_cursor += 2 - return self._short_cursor - 2 - else: - old_cursor = self._word_cursor - for _ in range(math.ceil(size / 4)): - self._word_cursor += 4 - return old_cursor - - def notify_skipped(self, no_bytes): - for _ in range(math.ceil(no_bytes / 4)): - self.request_allocation(4) - - -__all__ = ("Packer", ) +import math + + +class Packer: + def __init__(self, offset: int = 0): + self._word_cursor = offset + self._short_cursor = offset + self._byte_cursor = offset + self._boundary = offset % 4 + + def _next_block(self) -> int: + self._word_cursor += 4 + return self._word_cursor - 4 + + def request_allocation(self, size: int) -> int: + if size == 0: + return self._word_cursor + elif size == 1: + if self._byte_cursor % 4 == self._boundary: + self._byte_cursor = self._next_block() + 1 + else: + self._byte_cursor += 1 + return self._byte_cursor - 1 + elif size == 2: + if self._short_cursor % 4 == self._boundary: + self._short_cursor = self._next_block() + 2 + else: + self._short_cursor += 2 + return self._short_cursor - 2 + else: + old_cursor = self._word_cursor + for _ in range(math.ceil(size / 4)): + self._word_cursor += 4 + return old_cursor + + def notify_skipped(self, no_bytes: int) -> None: + for _ in range(math.ceil(no_bytes / 4)): + self.request_allocation(4) + + +__all__ = ("Packer", ) diff --git a/eaapi/server b/eaapi/server deleted file mode 160000 index dec7ca3..0000000 --- a/eaapi/server +++ /dev/null @@ -1 +0,0 @@ -Subproject commit dec7ca3536cf459be21b7284358bf2bea4eb3d14 diff --git a/eaapi/server/.gitignore b/eaapi/server/.gitignore new file mode 100644 index 0000000..e8ab094 --- /dev/null +++ b/eaapi/server/.gitignore @@ -0,0 +1,2 @@ +*.pyc +__pycache__/ diff --git a/eaapi/server/README.md b/eaapi/server/README.md new file mode 100644 index 0000000..5f75744 --- /dev/null +++ b/eaapi/server/README.md @@ -0,0 +1,43 @@ +# eaapi.server + +## Quickstart + +```py +server = EAMServer("http://127.0.0.1:5000") + +@server.handler("message", "get") +def message(ctx): + ctx.resp.append("message", expire="300", status="0") + +server.run("0.0.0.0", 5000) +``` + +```py +EAMServer( + # The URL this server can be access at. Used for services + public_url, + # Add `//` as a prefix when generating service urls (useful when debugging games) + prefix_services: bool = False, + # If both the URL and the query params match, which one gets the final say? + prioritise_params: bool = False, + # Include e-Amusement specific details of why requests failed in the responses + verbose_errors: bool = False, + # The operation mode in services.get's response + services_mode: eaapi.const.ServicesMode = eaapi.const.ServicesMode.Operation, + # The NTP server to use in services.get + ntp_server: str = "ntp://pool.ntp.org/", + # Keepalive server to use in serices.get. We'll use our own if one is not specified + keepalive_server: str = None +) + +@handler( + # Module name to handle. Will curry if method is not provided + module, + # Method name to handle + method=None, + # The datecode prefix to match during routing + dc_prefix=None, + # The service to use. Likely `local` or `local2` when handling game functions + service=None +) +``` diff --git a/eaapi/server/__init__.py b/eaapi/server/__init__.py new file mode 100644 index 0000000..1da17e0 --- /dev/null +++ b/eaapi/server/__init__.py @@ -0,0 +1,13 @@ +from .server import EAMServer +from .context import CallContext +from .model import Model, ModelMatcher, DatecodeMatcher +from .exceptions import EAMHTTPException +from .controller import Controller + +__all__ = ( + "EAMServer", + "CallContext", + "Model", "ModelMatcher", "DatecodeMatcher", + "EAMHTTPException", + "Controller", +) diff --git a/eaapi/server/__main__.py b/eaapi/server/__main__.py new file mode 100644 index 0000000..3651ac6 --- /dev/null +++ b/eaapi/server/__main__.py @@ -0,0 +1,4 @@ +from .server import EAMServer + +app = EAMServer("http://127.0.0.1:5000", verbose_errors=True) +app.run("0.0.0.0", 5000, debug=True) diff --git a/eaapi/server/const.py b/eaapi/server/const.py new file mode 100644 index 0000000..de9826c --- /dev/null +++ b/eaapi/server/const.py @@ -0,0 +1,45 @@ +# Services where the module and service name match +TRIVIAL_SERVICES = [ + "pcbtracker", + "message", + "facility", + "pcbevent", + "cardmng", + "package", + "userdata", + "userid", + "dlstatus", + "eacoin", + # "traceroute", + "apsmanager", + "sidmgr", +] + +# Just chilling here until I figure out where these route +UNMAPPED_SERVICES = { + "???0": "numbering", + "???1": "pkglist", + "???2": "posevent", + "???3": "lobby", + "???4": "lobby2", + "???5": "netlog", # ins.netlog (?) + "???6": "globby", + "???7": "matching", + "???8": "netsci", +} + +MODULE_SERVICES = "services" +RESERVED_MODULES = [ + MODULE_SERVICES, +] + +METHOD_SERVICES_GET = "get" + +SERVICE_SERVICES = "services" +SERVICE_KEEPALIVE = "keepalive" +SERVICE_NTP = "ntp" +RESERVED_SERVICES = [ + SERVICE_SERVICES, + SERVICE_KEEPALIVE, + SERVICE_NTP, +] diff --git a/eaapi/server/context.py b/eaapi/server/context.py new file mode 100644 index 0000000..e0a127f --- /dev/null +++ b/eaapi/server/context.py @@ -0,0 +1,105 @@ +import eaapi + +from . import exceptions as exc +from .model import Model + + +NODE_CALL = "call" +NODE_RESP = "response" + + +class CallContext: + def __init__(self, request, decoder, call, eainfo, compressed): + if call.name != NODE_CALL: + raise exc.CallNodeMissing + self._request = request + self._decoder = decoder + self._call: eaapi.XMLNode = call + self._eainfo: str | None = eainfo + self._compressed: bool = compressed + self._resp: eaapi.XMLNode = eaapi.XMLNode.void(NODE_RESP) + + self._module: str | None = None + self._method: str | None = None + self._url_slash: bool | None = None + + self._model: Model = Model.from_model_str(call.get("model")) + + @property + def module(self): + return self._module + + @property + def method(self): + return self._method + + @property + def url_slash(self): + return self._url_slash + + @property + def request(self): + return self._request + + @property + def was_xml_string(self): + return self._decoder.is_xml_string + + @property + def was_compressed(self): + return self._compressed + + @property + def call(self): + return self._call + + @property + def resp(self): + return self._resp + + @property + def model(self): + return self._model + + @property + def srcid(self): + return self._call.get("srcid") + + @property + def tag(self): + return self._call.get("tag") + + def get_root(self): + return self.call.xpath(self.module) + + def abort(self, status="1"): + return self.resp.append(self.module, status=status) + + def ok(self): + return self.abort("0") + + +class ResponseContext: + def __init__(self, resp, decoder, response, compressed): + if response.name != NODE_RESP: + raise exc.CallNodeMissing + self._resp = resp + self._decoder = decoder + self._response = response + self._compressed = compressed + + @property + def resp(self): + return self._resp + + @property + def decoder(self): + return self._decoder + + @property + def response(self): + return self._response + + @property + def compressed(self): + return self._compressed diff --git a/eaapi/server/controller.py b/eaapi/server/controller.py new file mode 100644 index 0000000..6710ce4 --- /dev/null +++ b/eaapi/server/controller.py @@ -0,0 +1,210 @@ +from typing import Callable +from collections import defaultdict +from abc import ABC + +from .context import CallContext +from .model import ModelMatcher +from .const import RESERVED_MODULES, RESERVED_SERVICES, TRIVIAL_SERVICES, MODULE_SERVICES, METHOD_SERVICES_GET + +import eaapi + + +Handler = Callable[[CallContext], None] + + +class IController(ABC): + _name: str + + def get_handler(self, ctx: CallContext) -> Handler | None: + raise NotImplementedError + + def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]: + raise NotImplementedError + + def serviced_prefixes(self) -> list[str]: + raise NotImplementedError + + +class Controller(IController): + def __init__(self, server, endpoint="", matcher: None | ModelMatcher = None): + from .server import EAMServer + self._server: EAMServer = server + server.controllers.append(self) + + import inspect + caller = inspect.getmodule(inspect.stack()[1][0]) + assert caller is not None + self._name = caller.__name__ + + self._handlers: dict[ + tuple[str | None, str | None], + list[tuple[ModelMatcher, Handler]] + ] = defaultdict(lambda: []) + self._endpoint = endpoint + self._matcher = matcher + + self._services: set[tuple[ModelMatcher, str]] = set() + + self._pre_handler: list[Callable[[CallContext, Handler], Handler]] = [] + + @property + def server(self): + return self._server + + def on_pre(self, callback): + if callback not in self._pre_handler: + self._pre_handler.append(callback) + + def add_dummy_service( + self, + service: str, + matcher: ModelMatcher | None = None, + unsafe_force_bypass_reserved: bool = False + ): + if not unsafe_force_bypass_reserved: + if service in RESERVED_SERVICES: + raise KeyError( + f"{service} is a reserved service provided by default.\n" + "Pass unsafe_force_bypass_reserved=True to override this implementation" + ) + if matcher is None: + matcher = ModelMatcher() + self._services.add((matcher, service)) + + def register_handler( + self, + handler: Handler, + module: str, + method: str, + matcher: ModelMatcher | None = None, + service: str | None = None, + unsafe_force_bypass_reserved: bool = False + ): + if not unsafe_force_bypass_reserved: + if module in RESERVED_MODULES: + raise KeyError( + f"{module} is a reserved module provided by default.\n" + "Pass unsafe_force_bypass_reserved=True to override this implementation" + ) + + if service is not None and service in RESERVED_SERVICES: + raise KeyError( + f"{service} is a reserved service provided by default.\n" + "Pass unsafe_force_bypass_reserved=True to override this implementation" + ) + + if service is None: + if module not in TRIVIAL_SERVICES: + raise ValueError(f"Unable to identify service for {module}") + service = module + + handlers = self._handlers[(module, method)] + for i in handlers: + if matcher is None and i[0] is None: + raise ValueError(f"Duplicate default handler for {module}.{method}") + if matcher == i[0]: + raise ValueError(f"Duplicate handler for {module}.{method} ({matcher})") + matcher_ = matcher or ModelMatcher() + handlers.append((matcher_, handler)) + handlers.sort(key=lambda x: x[0]) + + self._services.add((matcher_, service)) + + def handler( + self, + module: str, + method: str | None = None, + matcher: ModelMatcher | None = None, + service: str | None = None, + unsafe_force_bypass_reserved: bool = False + ): + if method is None: + def h2(method): + return self.handler(module, method, matcher, service) + return h2 + + # Commented out for MachineCallContext bodge + # def decorator(handler: Handler): + def decorator(handler): + self.register_handler(handler, module, method, matcher, service, unsafe_force_bypass_reserved) + return handler + return decorator + + def get_handler(self, ctx): + if self._matcher is not None: + if not self._matcher.matches(ctx.model): + return None + + handlers = self._handlers[(ctx.module, ctx.method)] + if not handlers: + return None + + for matcher, handler in handlers: + if matcher.matches(ctx.model): + for i in self._pre_handler: + handler = i(ctx, handler) + return handler + + return None + + def get_service_routes(self, ctx: CallContext | None): + endpoint = self._server.expand_url(self._endpoint) + + if ctx is None: + return { + service: endpoint + for _, service in self._services + } + + if self._matcher is not None and not self._matcher.matches(ctx.model): + return {} + + return { + service: endpoint + for matcher, service in self._services + if matcher.matches(ctx.model) + } + + def serviced_prefixes(self) -> list[str]: + return [self._endpoint] + + +class ServicesController(IController): + def __init__(self, server, services_mode: eaapi.const.ServicesMode): + from .server import EAMServer + self._server: EAMServer = server + + self.services_mode = services_mode + + self._name = __name__ + "." + self.__class__.__name__ + + def service_routes_for(self, ctx: CallContext): + services = defaultdict(lambda: self._server.public_url) + for service, route in self._server.get_service_routes(ctx): + services[service] = self._server.expand_url(route) + + return services + + def get_handler(self, ctx: CallContext) -> Handler | None: + if ctx.module == MODULE_SERVICES and ctx.method == METHOD_SERVICES_GET: + return self.services_get + return None + + def service_route(self, for_service: str, ctx: CallContext): + routes = self.service_routes_for(ctx) + return routes[for_service] + + def services_get(self, ctx: CallContext): + services = ctx.resp.append( + MODULE_SERVICES, expire="600", mode=self.services_mode.value, status="0" + ) + + routes = self._server.get_service_routes(ctx) + for service in routes: + services.append("item", name=service, url=routes[service]) + + def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]: + return {} + + def serviced_prefixes(self) -> list[str]: + return [""] diff --git a/eaapi/server/demo/dummy.py b/eaapi/server/demo/dummy.py new file mode 100644 index 0000000..9d30f1c --- /dev/null +++ b/eaapi/server/demo/dummy.py @@ -0,0 +1,88 @@ +import time + +from eaapi import Type + +from ..server import EAMServer +from ..model import ModelMatcher + + +server = EAMServer("http://127.0.0.1:5000", verbose_errors=True) + + +@server.handler("pcbtracker", "alive") +def pcbtracker(ctx): + ecflag = ctx.call.xpath("@ecflag/pcbtracker") + + ctx.resp.append( + "pcbtracker", + status="0", expire="1200", + ecenable=ecflag, eclimit="1000", limit="1000", + time=str(round(time.time())) + ) + + +@server.handler("message", "get", matcher=ModelMatcher("KFC")) +def message(ctx): + ctx.resp.append("message", expire="300", status="0") + + +@server.handler("facility", "get") +def facility_get(ctx): + facility = ctx.resp.append("facility", status="0") + location = facility.append("location") + location.append("id", Type.Str, "") + location.append("country", Type.Str, "UK") + location.append("region", Type.Str, "") + location.append("name", Type.Str, "Hello Flask") + location.append("type", Type.U8, 0) + location.append("countryname", Type.Str, "UK-c") + location.append("countryjname", Type.Str, "") + location.append("regionname", Type.Str, "UK-r") + location.append("regionjname", Type.Str, "") + location.append("customercode", Type.Str, "") + location.append("companycode", Type.Str, "") + location.append("latitude", Type.S32, 0) + location.append("longitude", Type.S32, 0) + location.append("accuracy", Type.U8, 0) + + line = facility.append("line") + line.append("id", Type.Str, "") + line.append("class", Type.U8, 0) + + portfw = facility.append("portfw") + portfw.append("globalip", Type.IPv4, (*map(int, ctx.request.remote_addr.split(".")),)) + portfw.append("globalport", Type.U16, ctx.request.environ.get('REMOTE_PORT')) + portfw.append("privateport", Type.U16, ctx.request.environ.get('REMOTE_PORT')) + + public = facility.append("public") + public.append("flag", Type.U8, 1) + public.append("name", Type.Str, "") + public.append("latitude", Type.S32, 0) + public.append("longitude", Type.S32, 0) + + share = facility.append("share") + eacoin = share.append("eacoin") + eacoin.append("notchamount", Type.S32, 0) + eacoin.append("notchcount", Type.S32, 0) + eacoin.append("supplylimit", Type.S32, 100000) + url = share.append("url") + url.append("eapass", Type.Str, "www.ea-pass.konami.net") + url.append("arcadefan", Type.Str, "www.konami.jp/am") + url.append("konaminetdx", Type.Str, "http://am.573.jp") + url.append("konamiid", Type.Str, "http://id.konami.jp") + url.append("eagate", Type.Str, "http://eagate.573.jp") + + +@server.handler("pcbevent", "put") +def pcbevent(ctx): + ctx.resp.append("pcbevent", status="0") + + +server.route_service_to("cardmng", server.service_route("cardmng")) +server.route_service_to("package", server.service_route("package")) + +server.route_service_to("message", "message/KFC", matcher=ModelMatcher("KFC")) + + +if __name__ == "__main__": + server.run("0.0.0.0", 5000, debug=True) diff --git a/eaapi/server/demo/proxy.py b/eaapi/server/demo/proxy.py new file mode 100644 index 0000000..a55f577 --- /dev/null +++ b/eaapi/server/demo/proxy.py @@ -0,0 +1,162 @@ +import base64 +import urllib.parse + +import requests + +import eaapi + +from werkzeug.routing import Rule +from werkzeug.wrappers import Response + +from .. import exceptions as exc +from ..model import ModelMatcher +from ..context import ResponseContext +from ..server import ( + METHOD_SERVICES_GET, MODULE_SERVICES, SERVICE_KEEPALIVE, SERVICE_NTP, EAMServer, HEADER_ENCRYPTION, + HEADER_COMPRESSION, TRIVIAL_SERVICES +) + + +class EAMProxyServer(EAMServer): + def __init__(self, upstream, *args, **kwargs): + super().__init__(*args, **kwargs) + self._upstream_fallback = upstream + + self._taps = [] + self._tapped_services = set() + + if not kwargs.get("disable_routes"): + self._rules.add(Rule("/__tap/////", endpoint="tap_request")) + self._rules.add(Rule("/__tap/", endpoint="tap_request")) + + def tap_url(self, ctx, url): + url = base64.b64encode(url.encode("latin-1")).decode() + return urllib.parse.urljoin(self._public_url, f"__tap/{url}/") + + def _call_upstream(self, ctx, upstream): + base = upstream or self._upstream_fallback + if ctx.url_slash: + url = f"{base}/{ctx.model}/{ctx.module}/{ctx.method}" + else: + url = f"{base}?model={ctx.model}&module={ctx.module}&method={ctx.method}" + + return requests.post(url, headers=ctx.request.headers, data=ctx.request.data) + + def tap(self, module, method=None, matcher=None, service=None): + if method is None: + def h2(method): + return self.handler(module, method, matcher, service) + return h2 + + def decorator(handler): + matcher_ = matcher or ModelMatcher() + self._taps.append((module, method, matcher_, handler)) + + if service is None: + if module in TRIVIAL_SERVICES: + self._tapped_services.add((matcher_, module)) + else: + raise ValueError(f"Unable to identify service for {module}") + else: + self._tapped_services.add((matcher_, service)) + return handler + return decorator + + def _decode_us_request(self, resp): + ea_info = resp.headers.get(HEADER_ENCRYPTION) + compression = resp.headers.get(HEADER_COMPRESSION) + + compressed = False + if compression == eaapi.Compression.Lz77.value: + compressed = True + elif compression != eaapi.Compression.None_.value: + raise exc.UnknownCompression + + payload = eaapi.unwrap(resp.content, ea_info, compressed) + decoder = eaapi.Decoder(payload) + try: + response = decoder.unpack() + except eaapi.exception.EAAPIException: + raise exc.InvalidPacket + + return ResponseContext(resp, decoder, response, compressed) + + def call_upstream(self, ctx, upstream): + resp = self._call_upstream(ctx, upstream) + return self._decode_us_request(resp) + + def _services_handler(self, ctx, upstream=None): + upstream_ctx = self.call_upstream(ctx, upstream) + + super()._services_handler(ctx) + services = ctx.resp.xpath("services") + # Never proxy NTP or keepalive + services.children = [ + i for i in services.children + if i.get("name") not in (SERVICE_NTP, SERVICE_KEEPALIVE) + ] + + added = set(i.get("name") for i in services) + + for i in upstream_ctx.response.xpath("services"): + name = i.get("name") + if name in added: + continue + + url = i.get("url") + # for matcher, service in self._tapped_services: + # if name == service and matcher.matches(ctx.model): + # url = self.tap_url(ctx, url) + if name != "keepalive": + url = self.tap_url(ctx, url) + + services.append("item", name=name, url=url) + + def _handle_request(self, upstream, url_slash, request, model, module, method): + ctx = self._create_ctx(url_slash, request, model, module, method) + + if ctx.module == MODULE_SERVICES: + if ctx.method != METHOD_SERVICES_GET: + raise exc.NoMethodHandler + self._services_handler(ctx, upstream) + return self._encode_response(ctx) + + try: + return super()._handle_request(ctx) + except exc.NoMethodHandler: + try: + us_resp = self._call_upstream(ctx, upstream) + except requests.RequestException: + raise exc.UpstreamFailed + + try: + resp_ctx = self._decode_us_request(us_resp) + for tap in self._taps: + if tap[0] == module and tap[1] == method and tap[2].matches(ctx.model): + tap[3](ctx, resp_ctx) + except Exception: + import traceback + traceback.print_exc() + finally: + # Return the response the upstream sent + response = Response(us_resp.content, us_resp.status_code) + for i in us_resp.headers: + response.headers[i] = us_resp.headers[i] + return response + + def on_xrpc_request(self, request, service=None, model=None, module=None, method=None): + return self.on_tap_request(request, None, service, model, module, method) + + def on_tap_request(self, request, upstream, service=None, model=None, module=None, method=None): + url_slash, service, model, module, method = self.parse_request(request, service, model, module, method) + + if request.method != "POST": + return self.on_xrpc_other(request, service, model, module, method) + + if upstream is not None: + try: + upstream = base64.b64decode(upstream).decode("latin-1") + except Exception: + raise exc.InvalidUpstream + + return self._handle_request(upstream, url_slash, request, model, module, method) diff --git a/eaapi/server/demo/tapper-test.py b/eaapi/server/demo/tapper-test.py new file mode 100644 index 0000000..cdfebd3 --- /dev/null +++ b/eaapi/server/demo/tapper-test.py @@ -0,0 +1,32 @@ +from .proxy import EAMProxyServer +from ..model import ModelMatcher + + +server = EAMProxyServer( + upstream="http://127.0.0.1:8083", + public_url="http://127.0.0.1:5000", + verbose_errors=True +) + + +@server.tap("game", "sv4_save_m", matcher=ModelMatcher("KFC"), service="local2") +def sv4_save_m(ctx_in, ctx_out): + print("SAVE M") + print(ctx_in.call) + print(ctx_out.response) + + game = ctx_in.call.xpath("game") + print("mid:", game.xpath("music_id").value) + print("score:", game.xpath("score").value) + print("clear:", game.xpath("clear_type").value) + + just = game.xpath("just_checker") + print( + just.xpath("before_3").value, just.xpath("before_2").value, just.xpath("before_1").value, + just.xpath("just").value, + just.xpath("after_1").value, just.xpath("after_2").value, just.xpath("after_3").value + ) + + +if __name__ == "__main__": + server.run("0.0.0.0", 5000, debug=True) diff --git a/eaapi/server/exceptions.py b/eaapi/server/exceptions.py new file mode 100644 index 0000000..83315f2 --- /dev/null +++ b/eaapi/server/exceptions.py @@ -0,0 +1,51 @@ +from werkzeug.exceptions import HTTPException + + +class EAMHTTPException(HTTPException): + code = None + eam_description = None + + +class InvalidUpstream(EAMHTTPException): + code = 400 + eam_description = "Upstream URL invalid" + + +class UpstreamFailed(EAMHTTPException): + code = 400 + eam_description = "Upstream request failed" + + +class UnknownCompression(EAMHTTPException): + code = 400 + eam_description = "Unknown compression type" + + +class InvalidPacket(EAMHTTPException): + code = 400 + eam_description = "Invalid XML packet" + + +class InvalidModel(EAMHTTPException): + code = 400 + eam_description = "Invalid model" + + +class ModelMissmatch(EAMHTTPException): + code = 400 + eam_description = "Model missmatched" + + +class ModuleMethodMissing(EAMHTTPException): + code = 400 + eam_description = "Module or method missing" + + +class CallNodeMissing(EAMHTTPException): + code = 400 + eam_description = " node missing" + + +class NoMethodHandler(EAMHTTPException): + code = 404 + eam_description = "No handler found for module/method" diff --git a/eaapi/server/model.py b/eaapi/server/model.py new file mode 100644 index 0000000..43a29fe --- /dev/null +++ b/eaapi/server/model.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass + +import eaapi + + +class Model: + def __init__(self, gamecode: str, dest: str, spec: str, rev: str, datecode: str): + self._gamecode = gamecode + self._dest = dest + self._spec = spec + self._rev = rev + self._datecode = datecode + + @classmethod + def from_model_str(cls, model: str) -> "Model": + return cls(*eaapi.parse_model(model)) + + @property + def gamecode(self): + return self._gamecode + + @property + def dest(self): + return self._dest + + @property + def spec(self): + return self._spec + + @property + def rev(self): + return self._rev + + @property + def datecode(self): + return int(self._datecode) + + @property + def year(self): + return int(self._datecode[:4]) + + @property + def month(self): + return int(self._datecode[4:6]) + + @property + def day(self): + return int(self._datecode[6:8]) + + @property + def minor(self): + return int(self._datecode[8:10]) + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + if not isinstance(other, Model): + return False + return str(other) == str(self) + + def __str__(self): + return f"{self.gamecode}:{self.dest}:{self.spec}:{self.rev}:{self.datecode}" + + def __repr__(self): + return f"" + + +@dataclass +class DatecodeMatcher: + year: int | None = None + month: int | None = None + day: int | None = None + minor: int | None = None + + @classmethod + def from_str(cls, datecode: str): + if len(datecode) != 10 or not datecode.isdigit(): + raise ValueError("Not a valid datecode") + return cls( + int(datecode[0:4]), + int(datecode[4:6]), + int(datecode[6:8]), + int(datecode[8:10]) + ) + + def _num_filters(self): + num = 0 + if self.year is not None: + num += 1 + if self.month is not None: + num += 1 + if self.day is not None: + num += 1 + if self.minor is not None: + num += 1 + return num + + def __lt__(self, other): + if self._num_filters() < other._num_filters(): + return False + if self.minor is None and other.minor is not None: + return False + if self.day is None and other.day is not None: + return False + if self.month is None and other.month is not None: + return False + if self.year is None and other.year is not None: + return False + return True + + def __hash__(self): + return hash(str(self)) + + def __str__(self): + year = self.year if self.year is not None else "----" + month = self.month if self.month is not None else "--" + day = self.day if self.day is not None else "--" + minor = self.minor if self.minor is not None else "--" + return f"{year:04}{month:02}{day:02}{minor:02}" + + def matches(self, model): + if self.year is not None and model.year != self.year: + return False + if self.month is not None and model.month != self.month: + return False + if self.day is not None and model.day != self.day: + return False + if self.minor is not None and model.minor != self.minor: + return False + return True + + +@dataclass +class ModelMatcher: + gamecode: str | None = None + dest: str | None = None + spec: str | None = None + rev: str | None = None + datecode: list[DatecodeMatcher] | DatecodeMatcher | None = None + + def _num_filters(self): + num = 0 + if self.gamecode is not None: + num += 1 + if self.dest is not None: + num += 1 + if self.spec is not None: + num += 1 + if self.rev is not None: + num += 1 + if isinstance(self.datecode, list): + num += sum(i._num_filters() for i in self.datecode) + elif self.datecode is not None: + num += self.datecode._num_filters() + return num + + def __lt__(self, other): + if self._num_filters() < other._num_filters(): + return False + if self.datecode is None and other.datecode is not None: + return False + if self.rev is None and other.rev is not None: + return False + if self.spec is None and other.spec is not None: + return False + if self.dest is None and other.dest is not None: + return False + if self.gamecode is None and other.gamecode is not None: + return False + return True + + def __hash__(self): + return hash(str(self)) + + def __str__(self): + gamecode = self.gamecode if self.gamecode is not None else "---" + dest = self.dest if self.dest is not None else "-" + spec = self.spec if self.spec is not None else "-" + rev = self.rev if self.rev is not None else "-" + datecode = self.datecode if self.datecode is not None else "-" * 10 + if isinstance(self.datecode, list): + datecode = "/".join(str(i) for i in self.datecode) + if not datecode: + datecode = "-" * 10 + return f"{gamecode:3}:{dest}:{spec}:{rev}:{datecode}" + + def matches(self, model): + if self.gamecode is not None and model.gamecode != self.gamecode: + return False + if self.dest is not None and model.dest != self.dest: + return False + if self.spec is not None and model.spec != self.spec: + return False + if self.rev is not None and model.rev != self.rev: + return False + if isinstance(self.datecode, list): + return any(i.matches(model) for i in self.datecode) + if self.datecode is not None: + return self.datecode.matches(model) + return True diff --git a/eaapi/server/server.py b/eaapi/server/server.py new file mode 100644 index 0000000..756dba3 --- /dev/null +++ b/eaapi/server/server.py @@ -0,0 +1,397 @@ +import traceback +import urllib.parse +import urllib +import sys +import os + +from typing import Callable +from collections import defaultdict + +from werkzeug.exceptions import HTTPException, MethodNotAllowed +from werkzeug.wrappers import Request, Response +from werkzeug.routing import Map, Rule + +import eaapi + +from . import exceptions as exc +from .context import CallContext +from .model import Model +from .controller import IController, ServicesController +from .const import SERVICE_NTP, SERVICE_KEEPALIVE + + +Handler = Callable[[CallContext], None] + + +HEADER_ENCRYPTION = "X-Eamuse-Info" +HEADER_COMPRESSION = "X-Compress" + +PINGABLE_IP = "127.0.0.1" + + +class NetworkState: + def __init__(self): + self._pa = PINGABLE_IP # TODO: what does this one mean? + self.router_ip = PINGABLE_IP + self.gateway_ip = PINGABLE_IP + self.center_ip = PINGABLE_IP + + def format_ka(self, base): + return base + "?" + urllib.parse.urlencode({ + "pa": self.pa, + "ia": self.ia, + "ga": self.ga, + "ma": self.ma, + "t1": self.t1, + "t2": self.t2, + }) + + @property + def pa(self) -> str: + return self._pa + + @property + def ia(self) -> str: + return self.router_ip + + @property + def ga(self) -> str: + return self.gateway_ip + + @property + def ma(self) -> str: + return self.center_ip + + # TODO: Identify what these values are. Ping intervals? + @property + def t1(self): + return 2 + + @property + def t2(self): + return 10 + + +class EAMServer: + def __init__( + self, + public_url: str, + prioritise_params: bool = False, + verbose_errors: bool = False, + services_mode: eaapi.const.ServicesMode = eaapi.const.ServicesMode.Operation, + ntp_server: str = "ntp://pool.ntp.org/", + keepalive_server: str | None = None, + no_keepalive_route: bool = False, + disable_routes: bool = False, + no_services_handler: bool = False, + ): + self.network = NetworkState() + + self.verbose_errors = verbose_errors + + self._prioritise_params = prioritise_params + self._public_url = public_url + + self.disable_routes = disable_routes + self._no_keepalive_route = no_keepalive_route + + self.ntp = ntp_server + self.keepalive = keepalive_server or f"{public_url}/keepalive" + + self._prng = eaapi.crypt.new_prng() + + self._setup = [] + self._pre_handlers_check = [] + self._teardown = [] + + self._einfo_ctx: CallContext | None = None + self._einfo_controller: str | None = None + + self.controllers: list[IController] = [] + if not no_services_handler: + self.controllers.append(ServicesController(self, services_mode)) + + def on_setup(self, callback): + if callback not in self._setup: + self._setup.append(callback) + + def on_pre_handlers_check(self, callback): + if callback not in self._pre_handlers_check: + self._pre_handlers_check.append(callback) + + def on_teardown(self, callback): + if callback not in self._teardown: + self._teardown.append(callback) + + def build_rules_map(self) -> Map: + if self.disable_routes: + return Map([]) + + rules = Map([], strict_slashes=False, merge_slashes=False) + + prefixes = {"/"} + for i in self.controllers: + for prefix in i.serviced_prefixes(): + prefix = self.expand_url(prefix) + if not prefix.startswith(self._public_url): + continue + prefix = prefix[len(self._public_url):] + if prefix == "": + prefix = "/" + + prefixes.add(prefix) + + for i in prefixes: + rules.add(Rule(f"{i}///", endpoint="xrpc_request")) + # WSGI flattens the // at the start + if i == "/": + rules.add(Rule("///", endpoint="xrpc_request")) + rules.add(Rule(f"{i}", endpoint="xrpc_request")) + + if not self._no_keepalive_route: + rules.add(Rule("/keepalive", endpoint="keepalive_request")) + + return rules + + def expand_url(self, url: str) -> str: + return urllib.parse.urljoin(self._public_url, url) + + @property + def public_url(self) -> str: + return self._public_url + + def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]: + services: dict[str, str] = defaultdict(lambda: self.public_url) + services[SERVICE_NTP] = self.ntp + services[SERVICE_KEEPALIVE] = self.network.format_ka(self.keepalive) + + for i in self.controllers: + services.update(i.get_service_routes(ctx)) + return services + + def _decode_request(self, request: Request) -> CallContext: + ea_info = request.headers.get(HEADER_ENCRYPTION) + compression = request.headers.get(HEADER_COMPRESSION) + + compressed = False + if compression == eaapi.Compression.Lz77.value: + compressed = True + elif compression != eaapi.Compression.None_.value: + raise exc.UnknownCompression + + payload = eaapi.unwrap(request.data, ea_info, compressed) + decoder = eaapi.Decoder(payload) + try: + call = decoder.unpack() + except eaapi.EAAPIException: + raise exc.InvalidPacket + + return CallContext(request, decoder, call, ea_info, compressed) + + def _encode_response(self, ctx: CallContext) -> Response: + if ctx._eainfo is None: + ea_info = None + else: + ea_info = eaapi.crypt.get_key(self._prng) + + encoded = eaapi.Encoder.encode(ctx.resp, ctx.was_xml_string) + wrapped = eaapi.wrap(encoded, ea_info, ctx.was_compressed) + response = Response(wrapped, 200) + if ea_info: + response.headers[HEADER_ENCRYPTION] = ea_info + response.headers[HEADER_COMPRESSION] = ( + eaapi.Compression.Lz77 if ctx.was_compressed + else eaapi.Compression.None_ + ).value + + return response + + def _create_ctx( + self, + url_slash: bool, + request: Request, + model: Model | None, + module: str, + method: str + ) -> CallContext: + ctx = self._decode_request(request) + ctx._module = module + ctx._method = method + ctx._url_slash = url_slash + self._einfo_ctx = ctx + + if ctx.model != model: + raise exc.ModelMissmatch + return ctx + + def _handle_request(self, ctx: CallContext) -> Response: + for controller in self.controllers: + if (handler := controller.get_handler(ctx)) is not None: + self._einfo_controller = ( + f"{controller._name}" + ) + + handler(ctx) + break + else: + raise exc.NoMethodHandler + + return self._encode_response(ctx) + + def on_xrpc_other( + self, + request: Request, + service: str | None = None, + model: str | None = None, + module: str | None = None, + method: str | None = None + ): + if request.method != "GET" or not self.verbose_errors: + raise MethodNotAllowed + + return Response( + f"XRPC running. model {model}, call {module}.{method} ({service})" + ) + + def keepalive_request(self) -> Response: + return Response(None) + + def parse_request( + self, + request: Request, + service: str | None = None, + model: str | None = None, + module: str | None = None, + method: str | None = None + ): + url_slash = bool(module and module and method) + model_param = request.args.get("model", None) + module_param = request.args.get("module", None) + method_param = request.args.get("method", None) + if "f" in request.args: + module_param, _, method_param = request.args.get("f", "").partition(".") + + if self._prioritise_params: + model = model_param or model + module = module_param or module + method = method_param or method + else: + model = model or model_param + module = module or module_param + method = method or method_param + + if module is None or method is None: + raise exc.ModuleMethodMissing + + if model is None: + model_obj = None + else: + try: + model_obj = Model.from_model_str(model) + except eaapi.exception.InvalidModel: + raise exc.InvalidModel + + return url_slash, service, model_obj, module, method + + def on_xrpc_request( + self, + request: Request, + service: str | None = None, + model: str | None = None, + module: str | None = None, + method: str | None = None + ): + url_slash, service, model_obj, module, method = self.parse_request(request, service, model, module, method) + + if request.method != "POST": + return self.on_xrpc_other(request, service, model, module, method) + + ctx = self._create_ctx(url_slash, request, model_obj, module, method) + for i in self._pre_handlers_check: + i(ctx) + return self._handle_request(ctx) + + def _make_error(self, status: int | None = None, message: str | None = None) -> Response: + response = eaapi.XMLNode.void("response") + if status is not None: + response["status"] = str(status) + + if self.verbose_errors: + if message: + response.append("details", eaapi.Type.Str, message) + + context = response.append("context") + if self._einfo_ctx is not None: + context.append("module", eaapi.Type.Str, self._einfo_ctx.module) + context.append("method", eaapi.Type.Str, self._einfo_ctx.method) + context.append("game", eaapi.Type.Str, str(self._einfo_ctx.model)) + if self._einfo_controller is not None: + context.append("controller", eaapi.Type.Str, self._einfo_controller) + + encoded = eaapi.Encoder.encode(response, False) + wrapped = eaapi.wrap(encoded, None, False) + response = Response(wrapped, status or 500) + response.headers[HEADER_COMPRESSION] = eaapi.Compression.None_.value + + return response + + def _eamhttp_error(self, exc: exc.EAMHTTPException) -> Response: + return self._make_error(exc.code, exc.eam_description) + + def _structure_error(self, e: eaapi.exception.XMLStrutureError) -> Response: + summary = traceback.extract_tb(e.__traceback__) + for frame_summary in summary: + filename = frame_summary.filename + frame_summary.filename = os.path.relpath(filename) + + # The first three entries are within the controller, and the last one is us + summary = summary[3:-1] + tb = "".join(traceback.format_list(traceback.StackSummary.from_list(summary))) + tb += f"{e.__module__}.{e.__class__.__name__}" + + return self._make_error(400, tb) + + def _generic_error(self, exc: Exception) -> Response: + return self._make_error(500, str(exc)) + + def dispatch_request(self, request): + self._einfo_ctx = None + self._einfo_controller = None + + adapter = self.build_rules_map().bind_to_environ(request.environ) + try: + endpoint, values = adapter.match() + return getattr(self, f"on_{endpoint}")(request, **values) + except exc.EAMHTTPException as e: + return self._eamhttp_error(e) + except HTTPException as e: + return e + except eaapi.exception.XMLStrutureError as e: + traceback.print_exc(file=sys.stderr) + return self._structure_error(e) + except Exception as e: + traceback.print_exc(file=sys.stderr) + return self._generic_error(e) + + def wsgi_app(self, environ, start_response): + request = Request(environ) + response = self.dispatch_request(request) + return response(environ, start_response) + + def __call__(self, environ, start_response): + for i in self._setup: + i() + + try: + response = self.wsgi_app(environ, start_response) + for i in self._teardown: + i(None) + return response + except Exception as e: + for i in self._teardown: + i(e) + raise e + + def run(self, host="127.0.0.1", port=5000, debug=False): + from werkzeug.serving import run_simple + run_simple(host, port, self, use_debugger=debug, use_reloader=debug) diff --git a/eaapi/wrapper.py b/eaapi/wrapper.py index ac920df..cf0157c 100644 --- a/eaapi/wrapper.py +++ b/eaapi/wrapper.py @@ -1,31 +1,31 @@ -from .crypt import ea_symmetric_crypt -from .lz77 import lz77_compress, lz77_decompress - - -def wrap(packet, info=None, compressed=True): - if compressed: - packet = lz77_compress(packet) - if info is None: - return packet - return ea_symmetric_crypt(packet, info) - - -def unwrap(packet, info=None, compressed=True): - if info is None: - decrypted = packet - else: - decrypted = ea_symmetric_crypt(packet, info) - - if compressed is None: - try: - decompressed = lz77_decompress(decrypted) - except IndexError: - return decrypted - if decompressed == b"\0\0\0\0\0\0": - # Decompression almost certainly failed - return decrypted - return decompressed - return lz77_decompress(decrypted) if compressed else decrypted - - -__all__ = ("wrap", "unwrap") +from .crypt import ea_symmetric_crypt +from .lz77 import lz77_compress, lz77_decompress + + +def wrap(packet: bytes, info: str | None = None, compressed: bool = True) -> bytes: + if compressed: + packet = lz77_compress(packet) + if info is None: + return packet + return ea_symmetric_crypt(packet, info) + + +def unwrap(packet: bytes, info: str | None = None, compressed: bool = True) -> bytes: + if info is None: + decrypted = packet + else: + decrypted = ea_symmetric_crypt(packet, info) + + if compressed is None: + try: + decompressed = lz77_decompress(decrypted) + except IndexError: + return decrypted + if decompressed == b"\0\0\0\0\0\0": + # Decompression almost certainly failed + return decrypted + return decompressed + return lz77_decompress(decrypted) if compressed else decrypted + + +__all__ = ("wrap", "unwrap")