add get_ip_addr util function for servers behind proxies

This commit is contained in:
Hay1tsme 2023-03-12 01:00:51 -05:00
parent ea14f105d5
commit 18a95f5213
9 changed files with 33 additions and 20 deletions

View File

@ -12,8 +12,8 @@ from Crypto.Signature import PKCS1_v1_5
from time import strptime from time import strptime
from core.config import CoreConfig from core.config import CoreConfig
from core.data import Data
from core.utils import Utils from core.utils import Utils
from core.data import Data
from core.const import * from core.const import *
@ -69,7 +69,7 @@ class AllnetServlet:
) )
def handle_poweron(self, request: Request, _: Dict): def handle_poweron(self, request: Request, _: Dict):
request_ip = request.getClientAddress().host request_ip = Utils.get_ip_addr(request)
try: try:
req_dict = self.allnet_req_to_dict(request.content.getvalue()) req_dict = self.allnet_req_to_dict(request.content.getvalue())
if req_dict is None: if req_dict is None:
@ -162,7 +162,7 @@ class AllnetServlet:
return self.dict_to_http_form_string([vars(resp)]).encode("utf-8") return self.dict_to_http_form_string([vars(resp)]).encode("utf-8")
def handle_dlorder(self, request: Request, _: Dict): def handle_dlorder(self, request: Request, _: Dict):
request_ip = request.getClientAddress().host request_ip = Utils.get_ip_addr(request)
try: try:
req_dict = self.allnet_req_to_dict(request.content.getvalue()) req_dict = self.allnet_req_to_dict(request.content.getvalue())
if req_dict is None: if req_dict is None:
@ -255,7 +255,7 @@ class AllnetServlet:
return resp_str.encode("utf-8") return resp_str.encode("utf-8")
def handle_naomitest(self, request: Request, _: Dict) -> bytes: def handle_naomitest(self, request: Request, _: Dict) -> bytes:
self.logger.info(f"Ping from {request.getClientAddress().host}") self.logger.info(f"Ping from {Utils.get_ip_addr(request)}")
return b"naomi ok" return b"naomi ok"
def kvp_to_dict(self, kvp: List[str]) -> List[Dict[str, Any]]: def kvp_to_dict(self, kvp: List[str]) -> List[Dict[str, Any]]:

View File

@ -10,9 +10,8 @@ from twisted.python.components import registerAdapter
import jinja2 import jinja2
import bcrypt import bcrypt
from core.config import CoreConfig from core import CoreConfig, Utils
from core.data import Data from core.data import Data
from core.utils import Utils
class IUserSession(Interface): class IUserSession(Interface):
@ -143,7 +142,7 @@ class FE_Gate(FE_Base):
def render_POST(self, request: Request): def render_POST(self, request: Request):
uri = request.uri.decode() uri = request.uri.decode()
ip = request.getClientAddress().host ip = Utils.get_ip_addr(request)
if uri == "/gate/gate.login": if uri == "/gate/gate.login":
access_code: str = request.args[b"access_code"][0].decode() access_code: str = request.args[b"access_code"][0].decode()

View File

@ -6,7 +6,7 @@ from twisted.web.http import Request
from datetime import datetime from datetime import datetime
import pytz import pytz
from core.config import CoreConfig from core import CoreConfig
from core.utils import Utils from core.utils import Utils
@ -52,6 +52,8 @@ class MuchaServlet:
def handle_boardauth(self, request: Request, _: Dict) -> bytes: def handle_boardauth(self, request: Request, _: Dict) -> bytes:
req_dict = self.mucha_preprocess(request.content.getvalue()) req_dict = self.mucha_preprocess(request.content.getvalue())
client_ip = Utils.get_ip_addr(request)
if req_dict is None: if req_dict is None:
self.logger.error( self.logger.error(
f"Error processing mucha request {request.content.getvalue()}" f"Error processing mucha request {request.content.getvalue()}"
@ -61,7 +63,7 @@ class MuchaServlet:
req = MuchaAuthRequest(req_dict) req = MuchaAuthRequest(req_dict)
self.logger.debug(f"Mucha request {vars(req)}") self.logger.debug(f"Mucha request {vars(req)}")
self.logger.info( self.logger.info(
f"Boardauth request from {request.getClientAddress().host} for {req.gameVer}" f"Boardauth request from {client_ip} for {req.gameVer}"
) )
if req.gameCd not in self.mucha_registry: if req.gameCd not in self.mucha_registry:
@ -80,6 +82,8 @@ class MuchaServlet:
def handle_updatecheck(self, request: Request, _: Dict) -> bytes: def handle_updatecheck(self, request: Request, _: Dict) -> bytes:
req_dict = self.mucha_preprocess(request.content.getvalue()) req_dict = self.mucha_preprocess(request.content.getvalue())
client_ip = Utils.get_ip_addr(request)
if req_dict is None: if req_dict is None:
self.logger.error( self.logger.error(
f"Error processing mucha request {request.content.getvalue()}" f"Error processing mucha request {request.content.getvalue()}"
@ -89,7 +93,7 @@ class MuchaServlet:
req = MuchaUpdateRequest(req_dict) req = MuchaUpdateRequest(req_dict)
self.logger.debug(f"Mucha request {vars(req)}") self.logger.debug(f"Mucha request {vars(req)}")
self.logger.info( self.logger.info(
f"Updatecheck request from {request.getClientAddress().host} for {req.gameVer}" f"Updatecheck request from {client_ip} for {req.gameVer}"
) )
if req.gameCd not in self.mucha_registry: if req.gameCd not in self.mucha_registry:

View File

@ -1,5 +1,6 @@
from typing import Dict, Any from typing import Dict, Any
from types import ModuleType from types import ModuleType
from twisted.web.http import Request
import logging import logging
import importlib import importlib
from os import walk from os import walk
@ -21,3 +22,7 @@ class Utils:
logging.getLogger("core").error(f"get_all_titles: {dir} - {e}") logging.getLogger("core").error(f"get_all_titles: {dir} - {e}")
raise raise
return ret return ret
@classmethod
def get_ip_addr(cls, req: Request) -> str:
return req.getAllHeaders()[b"x-forwarded-for"].decode() if b"x-forwarded-for" in req.getAllHeaders() else req.getClientAddress().host

View File

@ -96,9 +96,11 @@ class HttpDispatcher(resource.Resource):
def render_GET(self, request: Request) -> bytes: def render_GET(self, request: Request) -> bytes:
test = self.map_get.match(request.uri.decode()) test = self.map_get.match(request.uri.decode())
client_ip = Utils.get_ip_addr(request)
if test is None: if test is None:
self.logger.debug( self.logger.debug(
f"Unknown GET endpoint {request.uri.decode()} from {request.getClientAddress().host} to port {request.getHost().port}" f"Unknown GET endpoint {request.uri.decode()} from {client_ip} to port {request.getHost().port}"
) )
request.setResponseCode(404) request.setResponseCode(404)
return b"Endpoint not found." return b"Endpoint not found."
@ -107,9 +109,11 @@ class HttpDispatcher(resource.Resource):
def render_POST(self, request: Request) -> bytes: def render_POST(self, request: Request) -> bytes:
test = self.map_post.match(request.uri.decode()) test = self.map_post.match(request.uri.decode())
client_ip = Utils.get_ip_addr(request)
if test is None: if test is None:
self.logger.debug( self.logger.debug(
f"Unknown POST endpoint {request.uri.decode()} from {request.getClientAddress().host} to port {request.getHost().port}" f"Unknown POST endpoint {request.uri.decode()} from {client_ip} to port {request.getHost().port}"
) )
request.setResponseCode(404) request.setResponseCode(404)
return b"Endpoint not found." return b"Endpoint not found."

View File

@ -9,8 +9,7 @@ import logging, coloredlogs
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from typing import List, Optional from typing import List, Optional
from core import CoreConfig from core import CoreConfig, Utils
from core.utils import Utils
class BaseReader: class BaseReader:

View File

@ -11,7 +11,7 @@ from Crypto.Util.Padding import pad
from os import path from os import path
from typing import Tuple from typing import Tuple
from core import CoreConfig from core import CoreConfig, Utils
from titles.chuni.config import ChuniConfig from titles.chuni.config import ChuniConfig
from titles.chuni.const import ChuniConstants from titles.chuni.const import ChuniConstants
from titles.chuni.base import ChuniBase from titles.chuni.base import ChuniBase
@ -111,6 +111,7 @@ class ChuniServlet:
encrtped = False encrtped = False
internal_ver = 0 internal_ver = 0
endpoint = url_split[len(url_split) - 1] endpoint = url_split[len(url_split) - 1]
client_ip = Utils.get_ip_addr(request)
if version < 105: # 1.0 if version < 105: # 1.0
internal_ver = ChuniConstants.VER_CHUNITHM internal_ver = ChuniConstants.VER_CHUNITHM
@ -179,7 +180,7 @@ class ChuniServlet:
req_data = json.loads(unzip) req_data = json.loads(unzip)
self.logger.info( self.logger.info(
f"v{version} {endpoint} request from {request.getClientAddress().host}" f"v{version} {endpoint} request from {client_ip}"
) )
self.logger.debug(req_data) self.logger.debug(req_data)

View File

@ -10,7 +10,7 @@ import inflection
from os import path from os import path
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
from core.config import CoreConfig from core import CoreConfig, Utils
from titles.pokken.config import PokkenConfig from titles.pokken.config import PokkenConfig
from titles.pokken.base import PokkenBase from titles.pokken.base import PokkenBase
from titles.pokken.const import PokkenConstants from titles.pokken.const import PokkenConstants
@ -128,7 +128,7 @@ class PokkenServlet(resource.Resource):
def handle_matching(self, request: Request) -> bytes: def handle_matching(self, request: Request) -> bytes:
content = request.content.getvalue() content = request.content.getvalue()
client_ip = request.getClientAddress().host client_ip = Utils.get_ip_addr(request)
if content is None or content == b"": if content is None or content == b"":
self.logger.info("Empty matching request") self.logger.info("Empty matching request")

View File

@ -8,7 +8,7 @@ from twisted.web.http import Request
from typing import Dict, Tuple from typing import Dict, Tuple
from os import path from os import path
from core.config import CoreConfig from core import CoreConfig, Utils
from titles.wacca.config import WaccaConfig from titles.wacca.config import WaccaConfig
from titles.wacca.config import WaccaConfig from titles.wacca.config import WaccaConfig
from titles.wacca.const import WaccaConstants from titles.wacca.const import WaccaConstants
@ -89,6 +89,7 @@ class WaccaServlet:
request.responseHeaders.addRawHeader(b"X-Wacca-Hash", hash.hex().encode()) request.responseHeaders.addRawHeader(b"X-Wacca-Hash", hash.hex().encode())
return json.dumps(resp).encode() return json.dumps(resp).encode()
client_ip = Utils.get_ip_addr(request)
try: try:
req_json = json.loads(request.content.getvalue()) req_json = json.loads(request.content.getvalue())
version_full = Version(req_json["appVersion"]) version_full = Version(req_json["appVersion"])
@ -140,7 +141,7 @@ class WaccaServlet:
return end(resp.make()) return end(resp.make())
self.logger.info( self.logger.info(
f"v{req_json['appVersion']} {url_path} request from {request.getClientAddress().host} with chipId {req_json['chipId']}" f"v{req_json['appVersion']} {url_path} request from {client_ip} with chipId {req_json['chipId']}"
) )
self.logger.debug(req_json) self.logger.debug(req_json)