created PrivateKey / PublicKey wrapper classes

This commit is contained in:
Oscar Krause 2025-03-18 09:43:44 +01:00
parent 958f23f79d
commit fd46eecfb3
3 changed files with 65 additions and 60 deletions

View File

@ -21,7 +21,7 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
from orm import Origin, Lease, init as db_init, migrate
from util import load_private_key, load_public_key, get_pem, load_file
from util import PrivateKey, PublicKey, load_file
# Load variables
load_dotenv('../version.env')
@ -42,8 +42,8 @@ DLS_PORT = int(env('DLS_PORT', '443'))
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
INSTANCE_KEY_RSA = load_private_key(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
INSTANCE_KEY_PUB = load_public_key(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
INSTANCE_KEY_RSA = PrivateKey(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
INSTANCE_KEY_PUB = PublicKey(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0)))
LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
@ -51,8 +51,8 @@ LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=in
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256)
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
# Logging
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
@ -264,10 +264,10 @@ async def _client_token():
},
"service_instance_public_key_configuration": {
"service_instance_public_key_me": {
"mod": hex(INSTANCE_KEY_PUB.public_numbers().n)[2:],
"exp": int(INSTANCE_KEY_PUB.public_numbers().e),
"mod": hex(INSTANCE_KEY_PUB.raw().public_numbers().n)[2:],
"exp": int(INSTANCE_KEY_PUB.raw().public_numbers().e),
},
"service_instance_public_key_pem": get_pem(INSTANCE_KEY_PUB).decode('utf-8'),
"service_instance_public_key_pem": INSTANCE_KEY_PUB.pem().decode('utf-8'),
"key_retention_mode": "LATEST_ONLY"
},
}

View File

@ -1,8 +1,60 @@
import logging
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey, generate_private_key
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
logging.basicConfig()
class PrivateKey:
def __init__(self, filename: str):
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
self.key = load_pem_private_key(data.strip(), password=None)
def raw(self) -> RSAPrivateKey:
return self.key
def pem(self) -> bytes:
return self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
@staticmethod
def generate(public_exponent: int = 65537, key_size: int = 2048) -> RSAPrivateKey:
log = logging.getLogger(__name__)
log.debug(f'Generating RSA-Key')
return generate_private_key(public_exponent=public_exponent, key_size=key_size)
class PublicKey:
def __init__(self, filename: str):
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
self.key = load_pem_public_key(data.strip())
def raw(self) -> RSAPublicKey:
return self.key
def pem(self) -> bytes:
return self.key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def load_file(filename: str) -> bytes:
log = logging.getLogger(f'{__name__}')
log.debug(f'Loading contents of file "{filename}')
@ -11,53 +63,6 @@ def load_file(filename: str) -> bytes:
return content
def load_private_key(filename: str) -> "RSAPrivateKey":
from cryptography.hazmat.primitives.serialization import load_pem_private_key
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
return load_pem_private_key(data.strip(), password=None)
def load_public_key(filename: str) -> "RSAPublicKey":
from cryptography.hazmat.primitives.serialization import load_pem_public_key
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
return load_pem_public_key(data.strip())
def get_pem(key) -> bytes | None:
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives import serialization
if isinstance(key, RSAPrivateKey):
return key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
if isinstance(key, RSAPublicKey):
return key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def generate_private_key() -> "RSAPrivateKey":
from cryptography.hazmat.primitives.asymmetric import rsa
log = logging.getLogger(__name__)
log.debug(f'Generating RSA-Key')
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
class NV:
__DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json'
__DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions"

View File

@ -16,7 +16,7 @@ sys.path.append('../')
sys.path.append('../app')
from app import main
from util import load_private_key, load_public_key, get_pem
from util import PrivateKey, PublicKey
client = TestClient(main.app)
@ -25,11 +25,11 @@ ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-00000
# INSTANCE_KEY_RSA = generate_key()
# INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key()
INSTANCE_KEY_RSA = load_private_key(str(join(dirname(__file__), '../app/cert/instance.private.pem')))
INSTANCE_KEY_PUB = load_public_key(str(join(dirname(__file__), '../app/cert/instance.public.pem')))
INSTANCE_KEY_RSA = PrivateKey(str(join(dirname(__file__), '../app/cert/instance.private.pem')))
INSTANCE_KEY_PUB = PublicKey(str(join(dirname(__file__), '../app/cert/instance.public.pem')))
jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256)
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
def __bearer_token(origin_ref: str) -> str: