billing: add classes and validation, fix response

This commit is contained in:
Hay1tsme 2023-08-21 01:50:59 -04:00
parent 984949d902
commit d8b0e2ea2a
1 changed files with 161 additions and 28 deletions

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Any, Optional, Tuple, Union from typing import Dict, List, Any, Optional, Tuple, Union, Final
import logging, coloredlogs import logging, coloredlogs
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from twisted.web.http import Request from twisted.web.http import Request
@ -14,12 +14,15 @@ from Crypto.Signature import PKCS1_v1_5
from time import strptime from time import strptime
from os import path from os import path
import urllib.parse import urllib.parse
import math
from core.config import CoreConfig from core.config import CoreConfig
from core.utils import Utils from core.utils import Utils
from core.data import Data from core.data import Data
from core.const import * from core.const import *
BILLING_DT_FORMAT: Final[str] = "%Y%m%d%H%M%S"
class DLIMG_TYPE(Enum): class DLIMG_TYPE(Enum):
app = 0 app = 0
opt = 1 opt = 1
@ -379,45 +382,60 @@ class AllnetServlet:
rsa = RSA.import_key(open(self.config.billing.signing_key, "rb").read()) rsa = RSA.import_key(open(self.config.billing.signing_key, "rb").read())
signer = PKCS1_v1_5.new(rsa) signer = PKCS1_v1_5.new(rsa)
digest = SHA.new() digest = SHA.new()
traces: List[TraceData] = []
try: try:
kc_playlimit = int(req_dict[0]["playlimit"]) for x in range(len(req_dict)):
kc_nearfull = int(req_dict[0]["nearfull"]) if not req_dict[x]:
kc_billigtype = int(req_dict[0]["billingtype"]) continue
kc_playcount = int(req_dict[0]["playcnt"])
kc_serial: str = req_dict[0]["keychipid"] if x == 0:
kc_game: str = req_dict[0]["gameid"] req = BillingInfo(req_dict[x])
kc_date = strptime(req_dict[0]["date"], "%Y%m%d%H%M%S") continue
kc_serial_bytes = kc_serial.encode()
tmp = TraceData(req_dict[x])
if tmp.trace_type == TraceDataType.CHARGE:
tmp = TraceDataCharge(req_dict[x])
elif tmp.trace_type == TraceDataType.EVENT:
tmp = TraceDataEvent(req_dict[x])
elif tmp.trace_type == TraceDataType.CREDIT:
tmp = TraceDataCredit(req_dict[x])
traces.append(tmp)
kc_serial_bytes = req.keychipid.encode()
except KeyError as e: except KeyError as e:
return f"result=5&linelimit=&message={e} field is missing".encode() self.logger.error(f"Billing request failed to parse: {e}")
return f"result=5&linelimit=&message=field is missing or formatting is incorrect\r\n".encode()
machine = self.data.arcade.get_machine(kc_serial) machine = self.data.arcade.get_machine(req.keychipid)
if machine is None and not self.config.server.allow_unregistered_serials: if machine is None and not self.config.server.allow_unregistered_serials:
msg = f"Unrecognised serial {kc_serial} attempted billing checkin from {request_ip} for game {kc_game}." msg = f"Unrecognised serial {req.keychipid} attempted billing checkin from {request_ip} for {req.gameid} v{req.gamever}."
self.data.base.log_event( self.data.base.log_event(
"allnet", "BILLING_CHECKIN_NG_SERIAL", logging.WARN, msg "allnet", "BILLING_CHECKIN_NG_SERIAL", logging.WARN, msg
) )
self.logger.warning(msg) self.logger.warning(msg)
resp = BillingResponse("", "", "", "") return f"result=1&requestno={req.requestno}&message=Keychip Serial bad\r\n".encode()
resp.result = "1"
return urllib.parse.unquote(urllib.parse.urlencode(vars(resp))) + "\r\n"
msg = ( msg = (
f"Billing checkin from {request_ip}: game {kc_game} keychip {kc_serial} playcount " f"Billing checkin from {request_ip}: game {req.gameid} ver {req.gamever} keychip {req.keychipid} playcount "
f"{kc_playcount} billing_type {kc_billigtype} nearfull {kc_nearfull} playlimit {kc_playlimit}" f"{req.playcnt} billing_type {req.billingtype.name} nearfull {req.nearfull} playlimit {req.playlimit}"
) )
self.logger.info(msg) self.logger.info(msg)
self.data.base.log_event("billing", "BILLING_CHECKIN_OK", logging.INFO, msg) self.data.base.log_event("billing", "BILLING_CHECKIN_OK", logging.INFO, msg)
if req.traceleft > 0:
self.logger.warn(f"{req.traceleft} unsent tracelogs")
kc_playlimit = req.playlimit
kc_nearfull = req.nearfull
while kc_playcount > kc_playlimit: while req.playcnt > req.playlimit:
kc_playlimit += 1024 kc_playlimit += 1024
kc_nearfull += 1024 kc_nearfull += 1024
playlimit = kc_playlimit playlimit = kc_playlimit
nearfull = kc_nearfull + (kc_billigtype * 0x00010000) nearfull = kc_nearfull + (req.billingtype.value * 0x00010000)
digest.update(playlimit.to_bytes(4, "little") + kc_serial_bytes) digest.update(playlimit.to_bytes(4, "little") + kc_serial_bytes)
playlimit_sig = signer.sign(digest).hex() playlimit_sig = signer.sign(digest).hex()
@ -428,11 +446,16 @@ class AllnetServlet:
# TODO: playhistory # TODO: playhistory
resp = BillingResponse(playlimit, playlimit_sig, nearfull, nearfull_sig) #resp = BillingResponse(playlimit, playlimit_sig, nearfull, nearfull_sig)
resp = BillingResponse(playlimit, playlimit_sig, nearfull, nearfull_sig, req.requestno, req.protocolver)
resp_str = urllib.parse.unquote(urllib.parse.urlencode(vars(resp))) + "\r\n" resp_str = urllib.parse.unquote(urllib.parse.urlencode(vars(resp))) + "\r\n"
self.logger.debug(f"response {vars(resp)}") self.logger.debug(f"response {vars(resp)}")
if req.traceleft > 0:
self.logger.info(f"Requesting 20 more of {req.traceleft} unsent tracelogs")
return f"result=6&waittime=0&linelimit=20\r\n".encode()
return resp_str.encode("utf-8") return resp_str.encode("utf-8")
def handle_naomitest(self, request: Request, _: Dict) -> bytes: def handle_naomitest(self, request: Request, _: Dict) -> bytes:
@ -565,6 +588,114 @@ class AllnetDownloadOrderResponse:
self.serial = serial self.serial = serial
self.uri = uri self.uri = uri
class TraceDataType(Enum):
CHARGE = 0
EVENT = 1
CREDIT = 2
class BillingType(Enum):
A = 1
B = 0
class float5:
def __init__(self, n: str = "0") -> None:
nf = float(n)
if nf > 999.9 or nf < 0:
raise ValueError('float5 must be between 0.000 and 999.9 inclusive')
return nf
@classmethod
def to_str(cls, f: float):
return f"%.{4 - int(math.log10(f))+1}f" % f
class BillingInfo:
def __init__(self, data: Dict) -> None:
try:
self.keychipid = str(data.get("keychipid", None))
self.functype = int(data.get("functype", None))
self.gameid = str(data.get("gameid", None))
self.gamever = float(data.get("gamever", None))
self.boardid = str(data.get("boardid", None))
self.tenpoip = str(data.get("tenpoip", None))
self.libalibver = float(data.get("libalibver", None))
self.datamax = int(data.get("datamax", None))
self.billingtype = BillingType(int(data.get("billingtype", None)))
self.protocolver = float(data.get("protocolver", None))
self.operatingfix = bool(data.get("operatingfix", None))
self.traceleft = int(data.get("traceleft", None))
self.requestno = int(data.get("requestno", None))
self.datesync = bool(data.get("datesync", None))
self.timezone = str(data.get("timezone", None))
self.date = datetime.strptime(data.get("date", None), BILLING_DT_FORMAT)
self.crcerrcnt = int(data.get("crcerrcnt", None))
self.memrepair = bool(data.get("memrepair", None))
self.playcnt = int(data.get("playcnt", None))
self.playlimit = int(data.get("playlimit", None))
self.nearfull = int(data.get("nearfull", None))
except Exception as e:
raise KeyError(e)
class TraceData:
def __init__(self, data: Dict) -> None:
try:
self.crc_err_flg = bool(data.get("cs", None))
self.record_number = int(data.get("rn", None))
self.seq_number = int(data.get("sn", None))
self.trace_type = TraceDataType(int(data.get("tt", None)))
self.date_sync_flg = bool(data.get("ds", None))
self.date = datetime.strptime(data.get("dt", None), BILLING_DT_FORMAT)
self.keychip = str(data.get("kn", None))
self.lib_ver = float(data.get("alib", None))
except Exception as e:
raise KeyError(e)
class TraceDataCharge(TraceData):
def __init__(self, data: Dict) -> None:
super().__init__(data)
try:
self.game_id = str(data.get("gi", None))
self.game_version = float(data.get("gv", None))
self.board_serial = str(data.get("bn", None))
self.shop_ip = str(data.get("ti", None))
self.play_count = int(data.get("pc", None))
self.play_limit = int(data.get("pl", None))
self.product_code = int(data.get("ic", None))
self.product_count = int(data.get("in", None))
self.func_type = int(data.get("kk", None))
self.player_number = int(data.get("playerno", None))
except Exception as e:
raise KeyError(e)
class TraceDataEvent(TraceData):
def __init__(self, data: Dict) -> None:
super().__init__(data)
try:
self.message = str(data.get("me", None))
except Exception as e:
raise KeyError(e)
class TraceDataCredit(TraceData):
def __init__(self, data: Dict) -> None:
super().__init__(data)
try:
self.chute_type = int(data.get("cct", None))
self.service_type = int(data.get("cst", None))
self.operation_type = int(data.get("cop", None))
self.coin_rate0 = int(data.get("cr0", None))
self.coin_rate1 = int(data.get("cr1", None))
self.bonus_addition = int(data.get("cba", None))
self.credit_rate = int(data.get("ccr", None))
self.credit0 = int(data.get("cc0", None))
self.credit1 = int(data.get("cc1", None))
self.credit2 = int(data.get("cc2", None))
self.credit3 = int(data.get("cc3", None))
self.credit4 = int(data.get("cc4", None))
self.credit5 = int(data.get("cc5", None))
self.credit6 = int(data.get("cc6", None))
self.credit7 = int(data.get("cc7", None))
except Exception as e:
raise KeyError(e)
class BillingResponse: class BillingResponse:
def __init__( def __init__(
@ -573,20 +704,22 @@ class BillingResponse:
playlimit_sig: str = "", playlimit_sig: str = "",
nearfull: str = "", nearfull: str = "",
nearfull_sig: str = "", nearfull_sig: str = "",
request_num: int = 1,
protocol_ver: float = 1.000,
playhistory: str = "000000/0:000000/0:000000/0", playhistory: str = "000000/0:000000/0:000000/0",
) -> None: ) -> None:
self.result = "0" self.result = 0
self.waitime = "100" self.requestno = request_num
self.linelimit = "1" self.traceerase = 1
self.message = "" self.fixinterval = 120
self.fixlogcnt = 100
self.playlimit = playlimit self.playlimit = playlimit
self.playlimitsig = playlimit_sig self.playlimitsig = playlimit_sig
self.protocolver = "1.000" self.playhistory = playhistory
self.nearfull = nearfull self.nearfull = nearfull
self.nearfullsig = nearfull_sig self.nearfullsig = nearfull_sig
self.fixlogincnt = "0" self.linelimit = 100
self.fixinterval = "5" self.protocolver = float5.to_str(protocol_ver)
self.playhistory = playhistory
# playhistory -> YYYYMM/C:... # playhistory -> YYYYMM/C:...
# YYYY -> 4 digit year, MM -> 2 digit month, C -> Playcount during that period # YYYY -> 4 digit year, MM -> 2 digit month, C -> Playcount during that period