2024-11-19 11:18:12 +03:00
|
|
|
import json
|
|
|
|
import logging
|
|
|
|
import re
|
|
|
|
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from starlette.requests import Request
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class PatchMalformedJsonMiddleware(BaseHTTPMiddleware):
    """Repair and normalise JSON request bodies before the app handles them.

    When enabled, every request whose media type is ``application/json`` is
    parsed. If parsing fails, :meth:`fix_json` tries to repair the body (it
    inserts the missing opening quote of the first ``mac_address_list``
    entry). Afterwards ``environment.fingerprint.mac_address_list`` is
    truncated to a single entry and the re-serialised JSON replaces the
    request body.

    Requests that cannot be decoded or repaired are passed on unchanged so the
    application can answer them itself instead of this middleware raising.
    """
    # see oscar.krause/fastapi-dls#1

    # Raw string: same pattern value as before, but without the invalid string
    # escape sequences (``\:``, ``\s``, ``\[`` ...) that newer Python versions
    # warn about.
    # Group 1: the ``"mac_address_list": [`` prefix.
    # Group 2: a word character directly after ``[`` (i.e. a missing ``"``).
    REGEX = r'("mac_address_list"\:\s?\[)([\w\d])'

    def __init__(self, app, enabled: bool):
        """Store the wrapped ASGI ``app`` and whether patching is ``enabled``."""
        super().__init__(app)
        self.enabled = enabled

    async def dispatch(self, request: Request, call_next):
        """Patch the JSON body of ``request`` (if applicable) and forward it.

        :param request: incoming request
        :param call_next: callable that forwards the request downstream
        :return: the response produced by ``call_next``
        """
        content_type = request.headers.get('Content-Type')
        if not (self.enabled and self._is_json_content_type(content_type)):
            # Nothing to patch: do not even buffer the body.
            return await call_next(request)

        logger.debug('Using Request-Patch because "PatchMalformedJsonMiddleware" is enabled!')

        raw_body = await request.body()
        try:
            body = raw_body.decode()
        except UnicodeDecodeError:
            logger.warning('Request body is not valid UTF-8, passing it through unchanged.')
            return await call_next(request)

        # try to fix json
        try:
            payload = json.loads(body)
        except json.JSONDecodeError:
            logger.warning('Malformed json received! Try to fix it.')
            body = PatchMalformedJsonMiddleware.fix_json(body)
            logger.debug('Fixed JSON: "%s"', body)
            try:
                payload = json.loads(body)  # ensure json is now valid
            except json.JSONDecodeError:
                # e.g. an empty body or damage that fix_json() does not cover;
                # let the application produce its own error response.
                logger.warning('Could not fix malformed json, passing it through unchanged.')
                return await call_next(request)

        payload = self.__fix_mac_address_list_length(j=payload, size=1)

        # set new body (Starlette caches the request body in ``_body``)
        request._body = json.dumps(payload).encode('utf-8')

        return await call_next(request)

    @staticmethod
    def _is_json_content_type(content_type) -> bool:
        """Return True if the ``Content-Type`` header value denotes JSON.

        Media type parameters (``; charset=utf-8``), surrounding whitespace
        and letter case are ignored. ``None`` (header missing) yields False.
        """
        if not content_type:
            return False
        media_type = content_type.split(';', 1)[0].strip().lower()
        return media_type == 'application/json'

    @staticmethod
    def __fix_mac_address_list_length(j: dict, size: int = 1) -> dict:
        """Truncate ``environment.fingerprint.mac_address_list`` to ``size``.

        ``j`` is modified in place and also returned. Payloads that do not
        have the expected nested dict/list structure are returned untouched.

        :param j: decoded JSON payload
        :param size: maximum number of entries to keep
        :return: the (possibly modified) payload
        """
        # The payload comes straight from the client, so every level has to be
        # type-checked before it is accessed.
        if not isinstance(j, dict):
            return j
        environment = j.get('environment')
        if not isinstance(environment, dict):
            return j
        fingerprint = environment.get('fingerprint')
        if not isinstance(fingerprint, dict):
            return j

        mac_addresses = fingerprint.get('mac_address_list')
        if isinstance(mac_addresses, list) and len(mac_addresses) > size:
            logger.info(f'Transforming "mac_address_list" to length of {size}.')
            fingerprint['mac_address_list'] = mac_addresses[:size]

        return j

    @staticmethod
    def fix_json(s: str) -> str:
        """Return ``s`` with tabs/newlines removed and ``REGEX`` applied.

        The substitution inserts a ``"`` between the opening bracket of
        ``mac_address_list`` and a directly following word character.

        :param s: raw (possibly malformed) JSON text
        :return: the patched text; it is not guaranteed to be valid JSON
        """
        s = s.replace('\t', '')
        s = s.replace('\n', '')
        return re.sub(PatchMalformedJsonMiddleware.REGEX, r'\1"\2', s)
|