diff --git a/app/main.py b/app/main.py index 60f8c60..30d2ed1 100644 --- a/app/main.py +++ b/app/main.py @@ -21,7 +21,7 @@ from starlette.middleware.cors import CORSMiddleware from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse from orm import Origin, Lease, init as db_init, migrate -from util import load_private_key, load_public_key, get_pem, load_file +from util import PrivateKey, PublicKey, load_file # Load variables load_dotenv('../version.env') @@ -42,8 +42,8 @@ DLS_PORT = int(env('DLS_PORT', '443')) SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000')) INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001')) ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001')) -INSTANCE_KEY_RSA = load_private_key(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem')))) -INSTANCE_KEY_PUB = load_public_key(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem')))) +INSTANCE_KEY_RSA = PrivateKey(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem')))) +INSTANCE_KEY_PUB = PublicKey(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem')))) TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0))) LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0))) LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15)) @@ -51,8 +51,8 @@ LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=in CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12) CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}'] -jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256) -jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256) +jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256) +jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256) # Logging LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO @@ -264,10 +264,10 @@ async def _client_token(): }, "service_instance_public_key_configuration": { "service_instance_public_key_me": { - "mod": hex(INSTANCE_KEY_PUB.public_numbers().n)[2:], - "exp": int(INSTANCE_KEY_PUB.public_numbers().e), + "mod": hex(INSTANCE_KEY_PUB.raw().public_numbers().n)[2:], + "exp": int(INSTANCE_KEY_PUB.raw().public_numbers().e), }, - "service_instance_public_key_pem": get_pem(INSTANCE_KEY_PUB).decode('utf-8'), + "service_instance_public_key_pem": INSTANCE_KEY_PUB.pem().decode('utf-8'), "key_retention_mode": "LATEST_ONLY" }, } diff --git a/app/util.py b/app/util.py index f2b1be4..dd07f50 100644 --- a/app/util.py +++ b/app/util.py @@ -1,8 +1,60 @@ import logging +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey, generate_private_key +from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key + logging.basicConfig() +class PrivateKey: + + def __init__(self, filename: str): + log = logging.getLogger(__name__) + log.debug(f'Importing RSA-Key from "{filename}"') + + with open(filename, 'rb') as f: + data = f.read() + + self.key = load_pem_private_key(data.strip(), password=None) + + def raw(self) -> RSAPrivateKey: + return self.key + + def pem(self) -> bytes: + return self.key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption() + ) + + @staticmethod + def generate(public_exponent: int = 65537, key_size: int = 2048) -> RSAPrivateKey: + log = logging.getLogger(__name__) + log.debug(f'Generating RSA-Key') + return generate_private_key(public_exponent=public_exponent, key_size=key_size) + + +class PublicKey: + + def __init__(self, filename: str): + log = logging.getLogger(__name__) + log.debug(f'Importing RSA-Key from "{filename}"') + + with open(filename, 'rb') as f: + data = f.read() + + self.key = load_pem_public_key(data.strip()) + + def raw(self) -> RSAPublicKey: + return self.key + + def pem(self) -> bytes: + return self.key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + def load_file(filename: str) -> bytes: log = logging.getLogger(f'{__name__}') log.debug(f'Loading contents of file "{filename}') @@ -11,53 +63,6 @@ def load_file(filename: str) -> bytes: return content -def load_private_key(filename: str) -> "RSAPrivateKey": - from cryptography.hazmat.primitives.serialization import load_pem_private_key - - log = logging.getLogger(__name__) - log.debug(f'Importing RSA-Key from "{filename}"') - - with open(filename, 'rb') as f: - data = f.read() - return load_pem_private_key(data.strip(), password=None) - - -def load_public_key(filename: str) -> "RSAPublicKey": - from cryptography.hazmat.primitives.serialization import load_pem_public_key - - log = logging.getLogger(__name__) - log.debug(f'Importing RSA-Key from "{filename}"') - - with open(filename, 'rb') as f: - data = f.read() - return load_pem_public_key(data.strip()) - - -def get_pem(key) -> bytes | None: - from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey - from cryptography.hazmat.primitives import serialization - - if isinstance(key, RSAPrivateKey): - return key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - ) - if isinstance(key, RSAPublicKey): - return key.public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo - ) - - -def generate_private_key() -> "RSAPrivateKey": - from cryptography.hazmat.primitives.asymmetric import rsa - - log = logging.getLogger(__name__) - log.debug(f'Generating RSA-Key') - return rsa.generate_private_key(public_exponent=65537, key_size=2048) - - class NV: __DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json' __DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions" diff --git a/test/main.py b/test/main.py index b63601e..a56a84a 100644 --- a/test/main.py +++ b/test/main.py @@ -16,7 +16,7 @@ sys.path.append('../') sys.path.append('../app') from app import main -from util import load_private_key, load_public_key, get_pem +from util import PrivateKey, PublicKey client = TestClient(main.app) @@ -25,11 +25,11 @@ ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-00000 # INSTANCE_KEY_RSA = generate_key() # INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key() -INSTANCE_KEY_RSA = load_private_key(str(join(dirname(__file__), '../app/cert/instance.private.pem'))) -INSTANCE_KEY_PUB = load_public_key(str(join(dirname(__file__), '../app/cert/instance.public.pem'))) +INSTANCE_KEY_RSA = PrivateKey(str(join(dirname(__file__), '../app/cert/instance.private.pem'))) +INSTANCE_KEY_PUB = PublicKey(str(join(dirname(__file__), '../app/cert/instance.public.pem'))) -jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256) -jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256) +jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256) +jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256) def __bearer_token(origin_ref: str) -> str: