mirror of
				https://git.collinwebdesigns.de/oscar.krause/fastapi-dls.git
				synced 2025-10-26 18:05:28 +03:00 
			
		
		
		
	Merge branch 'db' into 'main'
DB - store settings in database See merge request oscar.krause/fastapi-dls!49
This commit is contained in:
		
						commit
						ffc9f91c2e
					
				
							
								
								
									
										158
									
								
								app/main.py
									
									
									
									
									
								
							
							
						
						
									
										158
									
								
								app/main.py
									
									
									
									
									
								
							@ -1,8 +1,9 @@
 | 
			
		||||
import logging
 | 
			
		||||
import sys
 | 
			
		||||
from base64 import b64encode as b64enc
 | 
			
		||||
from calendar import timegm
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from datetime import datetime, timedelta, UTC
 | 
			
		||||
from datetime import datetime, UTC
 | 
			
		||||
from hashlib import sha256
 | 
			
		||||
from json import loads as json_loads
 | 
			
		||||
from os import getenv as env
 | 
			
		||||
@ -13,15 +14,14 @@ from dateutil.relativedelta import relativedelta
 | 
			
		||||
from dotenv import load_dotenv
 | 
			
		||||
from fastapi import FastAPI
 | 
			
		||||
from fastapi.requests import Request
 | 
			
		||||
from jose import jws, jwk, jwt, JWTError
 | 
			
		||||
from jose import jws, jwt, JWTError
 | 
			
		||||
from jose.constants import ALGORITHMS
 | 
			
		||||
from sqlalchemy import create_engine
 | 
			
		||||
from sqlalchemy.orm import sessionmaker
 | 
			
		||||
from starlette.middleware.cors import CORSMiddleware
 | 
			
		||||
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
 | 
			
		||||
 | 
			
		||||
from orm import Origin, Lease, init as db_init, migrate
 | 
			
		||||
from util import PrivateKey, PublicKey, load_file
 | 
			
		||||
from orm import Origin, Lease, init as db_init, migrate, Instance, Site
 | 
			
		||||
 | 
			
		||||
# Load variables
 | 
			
		||||
load_dotenv('../version.env')
 | 
			
		||||
@ -39,20 +39,9 @@ db_init(db), migrate(db)
 | 
			
		||||
# Load DLS variables (all prefixed with "INSTANCE_*" is used as "SERVICE_INSTANCE_*" or "SI_*" in official dls service)
 | 
			
		||||
DLS_URL = str(env('DLS_URL', 'localhost'))
 | 
			
		||||
DLS_PORT = int(env('DLS_PORT', '443'))
 | 
			
		||||
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
 | 
			
		||||
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
 | 
			
		||||
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
 | 
			
		||||
INSTANCE_KEY_RSA = PrivateKey.from_file(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
 | 
			
		||||
INSTANCE_KEY_PUB = PublicKey.from_file(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
 | 
			
		||||
TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0)))
 | 
			
		||||
LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
 | 
			
		||||
LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
 | 
			
		||||
LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
 | 
			
		||||
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
 | 
			
		||||
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
 | 
			
		||||
 | 
			
		||||
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
 | 
			
		||||
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
 | 
			
		||||
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))  # todo
 | 
			
		||||
 | 
			
		||||
# Logging
 | 
			
		||||
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
 | 
			
		||||
@ -60,25 +49,33 @@ logging.basicConfig(format='[{levelname:^7}] [{module:^15}] {message}', style='{
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
logger.setLevel(LOG_LEVEL)
 | 
			
		||||
logging.getLogger('util').setLevel(LOG_LEVEL)
 | 
			
		||||
logging.getLogger('NV').setLevel(LOG_LEVEL)
 | 
			
		||||
logging.getLogger('DriverMatrix').setLevel(LOG_LEVEL)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# FastAPI
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def lifespan(_: FastAPI):
 | 
			
		||||
    # on startup
 | 
			
		||||
    default_instance = Instance.get_default_instance(db)
 | 
			
		||||
 | 
			
		||||
    lease_renewal_period = default_instance.lease_renewal_period
 | 
			
		||||
    lease_renewal_delta = default_instance.get_lease_renewal_delta()
 | 
			
		||||
    client_token_expire_delta = default_instance.get_client_token_expire_delta()
 | 
			
		||||
 | 
			
		||||
    logger.info(f'''
 | 
			
		||||
    
 | 
			
		||||
    Using timezone: {str(TZ)}. Make sure this is correct and match your clients!
 | 
			
		||||
    
 | 
			
		||||
    Your clients renew their license every {str(Lease.calculate_renewal(LEASE_RENEWAL_PERIOD, LEASE_RENEWAL_DELTA))}.
 | 
			
		||||
    If the renewal fails, the license is {str(LEASE_RENEWAL_DELTA)} valid.
 | 
			
		||||
    Your clients will renew their license every {str(Lease.calculate_renewal(lease_renewal_period, lease_renewal_delta))}.
 | 
			
		||||
    If the renewal fails, the license is valid for {str(lease_renewal_delta)}.
 | 
			
		||||
    
 | 
			
		||||
    Your client-token file (.tok) is valid for {str(CLIENT_TOKEN_EXPIRE_DELTA)}.
 | 
			
		||||
    Your client-token file (.tok) is valid for {str(client_token_expire_delta)}.
 | 
			
		||||
    ''')
 | 
			
		||||
 | 
			
		||||
    logger.info(f'Debug is {"enabled" if DEBUG else "disabled"}.')
 | 
			
		||||
 | 
			
		||||
    validate_settings()
 | 
			
		||||
 | 
			
		||||
    yield
 | 
			
		||||
 | 
			
		||||
    # on shutdown
 | 
			
		||||
@ -99,12 +96,24 @@ app.add_middleware(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Helper
 | 
			
		||||
def __get_token(request: Request) -> dict:
 | 
			
		||||
def __get_token(request: Request, jwt_decode_key: "jose.jwt") -> dict:
 | 
			
		||||
    authorization_header = request.headers.get('authorization')
 | 
			
		||||
    token = authorization_header.split(' ')[1]
 | 
			
		||||
    return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def validate_settings():
 | 
			
		||||
    session = sessionmaker(bind=db)()
 | 
			
		||||
 | 
			
		||||
    lease_expire_delta_min, lease_expire_delta_max = 86_400, 7_776_000
 | 
			
		||||
    for instance in session.query(Instance).all():
 | 
			
		||||
        lease_expire_delta = instance.lease_expire_delta
 | 
			
		||||
        if lease_expire_delta < 86_400 or lease_expire_delta > 7_776_000:
 | 
			
		||||
            logging.warning(f'> [ instance ]: {instance.instance_ref}: "lease_expire_delta" should be between {lease_expire_delta_min} and {lease_expire_delta_max}')
 | 
			
		||||
 | 
			
		||||
    session.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Endpoints
 | 
			
		||||
 | 
			
		||||
@app.get('/', summary='Index')
 | 
			
		||||
@ -124,18 +133,20 @@ async def _health():
 | 
			
		||||
 | 
			
		||||
@app.get('/-/config', summary='* Config', description='returns environment variables.')
 | 
			
		||||
async def _config():
 | 
			
		||||
    default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
 | 
			
		||||
 | 
			
		||||
    return JSONr({
 | 
			
		||||
        'VERSION': str(VERSION),
 | 
			
		||||
        'COMMIT': str(COMMIT),
 | 
			
		||||
        'DEBUG': str(DEBUG),
 | 
			
		||||
        'DLS_URL': str(DLS_URL),
 | 
			
		||||
        'DLS_PORT': str(DLS_PORT),
 | 
			
		||||
        'SITE_KEY_XID': str(SITE_KEY_XID),
 | 
			
		||||
        'INSTANCE_REF': str(INSTANCE_REF),
 | 
			
		||||
        'SITE_KEY_XID': str(default_site.site_key),
 | 
			
		||||
        'INSTANCE_REF': str(default_instance.instance_ref),
 | 
			
		||||
        'ALLOTMENT_REF': [str(ALLOTMENT_REF)],
 | 
			
		||||
        'TOKEN_EXPIRE_DELTA': str(TOKEN_EXPIRE_DELTA),
 | 
			
		||||
        'LEASE_EXPIRE_DELTA': str(LEASE_EXPIRE_DELTA),
 | 
			
		||||
        'LEASE_RENEWAL_PERIOD': str(LEASE_RENEWAL_PERIOD),
 | 
			
		||||
        'TOKEN_EXPIRE_DELTA': str(default_instance.get_token_expire_delta()),
 | 
			
		||||
        'LEASE_EXPIRE_DELTA': str(default_instance.get_lease_expire_delta()),
 | 
			
		||||
        'LEASE_RENEWAL_PERIOD': str(default_instance.lease_renewal_period),
 | 
			
		||||
        'CORS_ORIGINS': str(CORS_ORIGINS),
 | 
			
		||||
        'TZ': str(TZ),
 | 
			
		||||
    })
 | 
			
		||||
@ -144,6 +155,7 @@ async def _config():
 | 
			
		||||
@app.get('/-/readme', summary='* Readme')
 | 
			
		||||
async def _readme():
 | 
			
		||||
    from markdown import markdown
 | 
			
		||||
    from util import load_file
 | 
			
		||||
    content = load_file(join(dirname(__file__), '../README.md')).decode('utf-8')
 | 
			
		||||
    return HTMLr(markdown(text=content, extensions=['tables', 'fenced_code', 'md_in_html', 'nl2br', 'toc']))
 | 
			
		||||
 | 
			
		||||
@ -193,8 +205,7 @@ async def _origins(request: Request, leases: bool = False):
 | 
			
		||||
    for origin in session.query(Origin).all():
 | 
			
		||||
        x = origin.serialize()
 | 
			
		||||
        if leases:
 | 
			
		||||
            serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA)
 | 
			
		||||
            x['leases'] = list(map(lambda _: _.serialize(**serialize), Lease.find_by_origin_ref(db, origin.origin_ref)))
 | 
			
		||||
            x['leases'] = list(map(lambda _: _.serialize(), Lease.find_by_origin_ref(db, origin.origin_ref)))
 | 
			
		||||
        response.append(x)
 | 
			
		||||
    session.close()
 | 
			
		||||
    return JSONr(response)
 | 
			
		||||
@ -211,8 +222,7 @@ async def _leases(request: Request, origin: bool = False):
 | 
			
		||||
    session = sessionmaker(bind=db)()
 | 
			
		||||
    response = []
 | 
			
		||||
    for lease in session.query(Lease).all():
 | 
			
		||||
        serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA)
 | 
			
		||||
        x = lease.serialize(**serialize)
 | 
			
		||||
        x = lease.serialize()
 | 
			
		||||
        if origin:
 | 
			
		||||
            lease_origin = session.query(Origin).filter(Origin.origin_ref == lease.origin_ref).first()
 | 
			
		||||
            if lease_origin is not None:
 | 
			
		||||
@ -239,7 +249,13 @@ async def _lease_delete(request: Request, lease_ref: str):
 | 
			
		||||
@app.get('/-/client-token', summary='* Client-Token', description='creates a new messenger token for this service instance')
 | 
			
		||||
async def _client_token():
 | 
			
		||||
    cur_time = datetime.now(UTC)
 | 
			
		||||
    exp_time = cur_time + CLIENT_TOKEN_EXPIRE_DELTA
 | 
			
		||||
 | 
			
		||||
    default_instance = Instance.get_default_instance(db)
 | 
			
		||||
    public_key = default_instance.get_public_key()
 | 
			
		||||
    # todo: implemented request parameter to support different instances
 | 
			
		||||
    jwt_encode_key = default_instance.get_jwt_encode_key()
 | 
			
		||||
 | 
			
		||||
    exp_time = cur_time + default_instance.get_client_token_expire_delta()
 | 
			
		||||
 | 
			
		||||
    payload = {
 | 
			
		||||
        "jti": str(uuid4()),
 | 
			
		||||
@ -252,7 +268,7 @@ async def _client_token():
 | 
			
		||||
        "scope_ref_list": [ALLOTMENT_REF],
 | 
			
		||||
        "fulfillment_class_ref_list": [],
 | 
			
		||||
        "service_instance_configuration": {
 | 
			
		||||
            "nls_service_instance_ref": INSTANCE_REF,
 | 
			
		||||
            "nls_service_instance_ref": default_instance.instance_ref,
 | 
			
		||||
            "svc_port_set_list": [
 | 
			
		||||
                {
 | 
			
		||||
                    "idx": 0,
 | 
			
		||||
@ -264,10 +280,10 @@ async def _client_token():
 | 
			
		||||
        },
 | 
			
		||||
        "service_instance_public_key_configuration": {
 | 
			
		||||
            "service_instance_public_key_me": {
 | 
			
		||||
                "mod": hex(INSTANCE_KEY_PUB.raw().public_numbers().n)[2:],
 | 
			
		||||
                "exp": int(INSTANCE_KEY_PUB.raw().public_numbers().e),
 | 
			
		||||
                "mod": hex(public_key.raw().public_numbers().n)[2:],
 | 
			
		||||
                "exp": int(public_key.raw().public_numbers().e),
 | 
			
		||||
            },
 | 
			
		||||
            "service_instance_public_key_pem": INSTANCE_KEY_PUB.pem().decode('utf-8'),
 | 
			
		||||
            "service_instance_public_key_pem": public_key.pem().decode('utf-8'),
 | 
			
		||||
            "key_retention_mode": "LATEST_ONLY"
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
@ -349,13 +365,16 @@ async def auth_v1_code(request: Request):
 | 
			
		||||
    delta = relativedelta(minutes=15)
 | 
			
		||||
    expires = cur_time + delta
 | 
			
		||||
 | 
			
		||||
    default_site = Site.get_default_site(db)
 | 
			
		||||
    jwt_encode_key = Instance.get_default_instance(db).get_jwt_encode_key()
 | 
			
		||||
 | 
			
		||||
    payload = {
 | 
			
		||||
        'iat': timegm(cur_time.timetuple()),
 | 
			
		||||
        'exp': timegm(expires.timetuple()),
 | 
			
		||||
        'challenge': j.get('code_challenge'),
 | 
			
		||||
        'origin_ref': j.get('origin_ref'),
 | 
			
		||||
        'key_ref': SITE_KEY_XID,
 | 
			
		||||
        'kid': SITE_KEY_XID
 | 
			
		||||
        'key_ref': default_site.site_key,
 | 
			
		||||
        'kid': default_site.site_key,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
 | 
			
		||||
@ -375,6 +394,9 @@ async def auth_v1_code(request: Request):
 | 
			
		||||
async def auth_v1_token(request: Request):
 | 
			
		||||
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.now(UTC)
 | 
			
		||||
 | 
			
		||||
    default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
 | 
			
		||||
    jwt_encode_key, jwt_decode_key = default_instance.get_jwt_encode_key(), default_instance.get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        payload = jwt.decode(token=j.get('auth_code'), key=jwt_decode_key, algorithms=ALGORITHMS.RS256)
 | 
			
		||||
    except JWTError as e:
 | 
			
		||||
@ -388,7 +410,7 @@ async def auth_v1_token(request: Request):
 | 
			
		||||
    if payload.get('challenge') != challenge:
 | 
			
		||||
        return JSONr(status_code=401, content={'status': 401, 'detail': 'expected challenge did not match verifier'})
 | 
			
		||||
 | 
			
		||||
    access_expires_on = cur_time + TOKEN_EXPIRE_DELTA
 | 
			
		||||
    access_expires_on = cur_time + default_instance.get_token_expire_delta()
 | 
			
		||||
 | 
			
		||||
    new_payload = {
 | 
			
		||||
        'iat': timegm(cur_time.timetuple()),
 | 
			
		||||
@ -397,8 +419,8 @@ async def auth_v1_token(request: Request):
 | 
			
		||||
        'aud': 'https://cls.nvidia.org',
 | 
			
		||||
        'exp': timegm(access_expires_on.timetuple()),
 | 
			
		||||
        'origin_ref': origin_ref,
 | 
			
		||||
        'key_ref': SITE_KEY_XID,
 | 
			
		||||
        'kid': SITE_KEY_XID,
 | 
			
		||||
        'key_ref': default_site.site_key,
 | 
			
		||||
        'kid': default_site.site_key,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
 | 
			
		||||
@ -415,10 +437,13 @@ async def auth_v1_token(request: Request):
 | 
			
		||||
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
 | 
			
		||||
@app.post('/leasing/v1/lessor', description='request multiple leases (borrow) for current origin')
 | 
			
		||||
async def leasing_v1_lessor(request: Request):
 | 
			
		||||
    j, token, cur_time = json_loads((await request.body()).decode('utf-8')), __get_token(request), datetime.now(UTC)
 | 
			
		||||
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.now(UTC)
 | 
			
		||||
 | 
			
		||||
    default_instance = Instance.get_default_instance(db)
 | 
			
		||||
    jwt_decode_key = default_instance.get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        token = __get_token(request)
 | 
			
		||||
        token = __get_token(request, jwt_decode_key)
 | 
			
		||||
    except JWTError:
 | 
			
		||||
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
 | 
			
		||||
 | 
			
		||||
@ -432,7 +457,7 @@ async def leasing_v1_lessor(request: Request):
 | 
			
		||||
        #     return JSONr(status_code=500, detail=f'no service instances found for scopes: ["{scope_ref}"]')
 | 
			
		||||
 | 
			
		||||
        lease_ref = str(uuid4())
 | 
			
		||||
        expires = cur_time + LEASE_EXPIRE_DELTA
 | 
			
		||||
        expires = cur_time + default_instance.get_lease_expire_delta()
 | 
			
		||||
        lease_result_list.append({
 | 
			
		||||
            "ordinal": 0,
 | 
			
		||||
            # https://docs.nvidia.com/license-system/latest/nvidia-license-system-user-guide/index.html
 | 
			
		||||
@ -440,13 +465,13 @@ async def leasing_v1_lessor(request: Request):
 | 
			
		||||
                "ref": lease_ref,
 | 
			
		||||
                "created": cur_time.isoformat(),
 | 
			
		||||
                "expires": expires.isoformat(),
 | 
			
		||||
                "recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
 | 
			
		||||
                "recommended_lease_renewal": default_instance.lease_renewal_period,
 | 
			
		||||
                "offline_lease": "true",
 | 
			
		||||
                "license_type": "CONCURRENT_COUNTED_SINGLE"
 | 
			
		||||
            }
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
        data = Lease(origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
 | 
			
		||||
        data = Lease(instance_ref=default_instance.instance_ref, origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
 | 
			
		||||
        Lease.create_or_update(db, data)
 | 
			
		||||
 | 
			
		||||
    response = {
 | 
			
		||||
@ -463,7 +488,14 @@ async def leasing_v1_lessor(request: Request):
 | 
			
		||||
# venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
 | 
			
		||||
@app.get('/leasing/v1/lessor/leases', description='get active leases for current origin')
 | 
			
		||||
async def leasing_v1_lessor_lease(request: Request):
 | 
			
		||||
    token, cur_time = __get_token(request), datetime.now(UTC)
 | 
			
		||||
    cur_time = datetime.now(UTC)
 | 
			
		||||
 | 
			
		||||
    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        token = __get_token(request, jwt_decode_key)
 | 
			
		||||
    except JWTError:
 | 
			
		||||
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
 | 
			
		||||
 | 
			
		||||
    origin_ref = token.get('origin_ref')
 | 
			
		||||
 | 
			
		||||
@ -483,7 +515,15 @@ async def leasing_v1_lessor_lease(request: Request):
 | 
			
		||||
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
 | 
			
		||||
@app.put('/leasing/v1/lease/{lease_ref}', description='renew a lease')
 | 
			
		||||
async def leasing_v1_lease_renew(request: Request, lease_ref: str):
 | 
			
		||||
    token, cur_time = __get_token(request), datetime.now(UTC)
 | 
			
		||||
    cur_time = datetime.now(UTC)
 | 
			
		||||
 | 
			
		||||
    default_instance = Instance.get_default_instance(db)
 | 
			
		||||
    jwt_decode_key = default_instance.get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        token = __get_token(request, jwt_decode_key)
 | 
			
		||||
    except JWTError:
 | 
			
		||||
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
 | 
			
		||||
 | 
			
		||||
    origin_ref = token.get('origin_ref')
 | 
			
		||||
    logger.info(f'> [  renew   ]: {origin_ref}: renew {lease_ref}')
 | 
			
		||||
@ -492,11 +532,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
 | 
			
		||||
    if entity is None:
 | 
			
		||||
        return JSONr(status_code=404, content={'status': 404, 'detail': 'requested lease not available'})
 | 
			
		||||
 | 
			
		||||
    expires = cur_time + LEASE_EXPIRE_DELTA
 | 
			
		||||
    expires = cur_time + default_instance.get_lease_expire_delta()
 | 
			
		||||
    response = {
 | 
			
		||||
        "lease_ref": lease_ref,
 | 
			
		||||
        "expires": expires.isoformat(),
 | 
			
		||||
        "recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
 | 
			
		||||
        "recommended_lease_renewal": default_instance.lease_renewal_period,
 | 
			
		||||
        "offline_lease": True,
 | 
			
		||||
        "prompts": None,
 | 
			
		||||
        "sync_timestamp": cur_time.isoformat(),
 | 
			
		||||
@ -510,7 +550,14 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
 | 
			
		||||
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_single_controller.py
 | 
			
		||||
@app.delete('/leasing/v1/lease/{lease_ref}', description='release (return) a lease')
 | 
			
		||||
async def leasing_v1_lease_delete(request: Request, lease_ref: str):
 | 
			
		||||
    token, cur_time = __get_token(request), datetime.now(UTC)
 | 
			
		||||
    cur_time = datetime.now(UTC)
 | 
			
		||||
 | 
			
		||||
    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        token = __get_token(request, jwt_decode_key)
 | 
			
		||||
    except JWTError:
 | 
			
		||||
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
 | 
			
		||||
 | 
			
		||||
    origin_ref = token.get('origin_ref')
 | 
			
		||||
    logger.info(f'> [  return  ]: {origin_ref}: return {lease_ref}')
 | 
			
		||||
@ -536,7 +583,14 @@ async def leasing_v1_lease_delete(request: Request, lease_ref: str):
 | 
			
		||||
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
 | 
			
		||||
@app.delete('/leasing/v1/lessor/leases', description='release all leases')
 | 
			
		||||
async def leasing_v1_lessor_lease_remove(request: Request):
 | 
			
		||||
    token, cur_time = __get_token(request), datetime.now(UTC)
 | 
			
		||||
    cur_time = datetime.now(UTC)
 | 
			
		||||
 | 
			
		||||
    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        token = __get_token(request, jwt_decode_key)
 | 
			
		||||
    except JWTError:
 | 
			
		||||
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
 | 
			
		||||
 | 
			
		||||
    origin_ref = token.get('origin_ref')
 | 
			
		||||
 | 
			
		||||
@ -558,6 +612,8 @@ async def leasing_v1_lessor_lease_remove(request: Request):
 | 
			
		||||
async def leasing_v1_lessor_shutdown(request: Request):
 | 
			
		||||
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.now(UTC)
 | 
			
		||||
 | 
			
		||||
    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
    token = j.get('token')
 | 
			
		||||
    token = jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
 | 
			
		||||
    origin_ref = token.get('origin_ref')
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										251
									
								
								app/orm.py
									
									
									
									
									
								
							
							
						
						
									
										251
									
								
								app/orm.py
									
									
									
									
									
								
							@ -1,20 +1,143 @@
 | 
			
		||||
import logging
 | 
			
		||||
from datetime import datetime, timedelta, timezone, UTC
 | 
			
		||||
from os import getenv as env
 | 
			
		||||
from os.path import join, dirname, isfile
 | 
			
		||||
 | 
			
		||||
from dateutil.relativedelta import relativedelta
 | 
			
		||||
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text
 | 
			
		||||
from jose import jwk
 | 
			
		||||
from jose.constants import ALGORITHMS
 | 
			
		||||
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text, BLOB, INT, FLOAT
 | 
			
		||||
from sqlalchemy.engine import Engine
 | 
			
		||||
from sqlalchemy.orm import sessionmaker, declarative_base
 | 
			
		||||
from sqlalchemy.orm import sessionmaker, declarative_base, Session, relationship
 | 
			
		||||
from sqlalchemy.schema import CreateTable
 | 
			
		||||
 | 
			
		||||
from util import DriverMatrix
 | 
			
		||||
from util import DriverMatrix, PrivateKey, PublicKey, DriverMatrix
 | 
			
		||||
 | 
			
		||||
logging.basicConfig()
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
logger.setLevel(logging.INFO)
 | 
			
		||||
 | 
			
		||||
Base = declarative_base()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Site(Base):
 | 
			
		||||
    __tablename__ = "site"
 | 
			
		||||
 | 
			
		||||
    INITIAL_SITE_KEY_XID = '10000000-0000-0000-0000-000000000000'
 | 
			
		||||
    INITIAL_SITE_NAME = 'default-site'
 | 
			
		||||
 | 
			
		||||
    site_key = Column(CHAR(length=36), primary_key=True, unique=True, index=True)  # uuid4, SITE_KEY_XID
 | 
			
		||||
    name = Column(VARCHAR(length=256), nullable=False)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f'SITE_KEY_XID: {self.site_key}'
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_statement(engine: Engine):
 | 
			
		||||
        return CreateTable(Site.__table__).compile(engine)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_default_site(engine: Engine) -> "Site":
 | 
			
		||||
        session = sessionmaker(bind=engine)()
 | 
			
		||||
        entity = session.query(Site).filter(Site.site_key == Site.INITIAL_SITE_KEY_XID).first()
 | 
			
		||||
        session.close()
 | 
			
		||||
        return entity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Instance(Base):
 | 
			
		||||
    __tablename__ = "instance"
 | 
			
		||||
 | 
			
		||||
    DEFAULT_INSTANCE_REF = '10000000-0000-0000-0000-000000000001'
 | 
			
		||||
    DEFAULT_TOKEN_EXPIRE_DELTA = 86_400  # 1 day
 | 
			
		||||
    DEFAULT_LEASE_EXPIRE_DELTA = 7_776_000  # 90 days
 | 
			
		||||
    DEFAULT_LEASE_RENEWAL_PERIOD = 0.15
 | 
			
		||||
    DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA = 378_432_000  # 12 years
 | 
			
		||||
    # 1 day = 86400 (min. in production setup, max 90 days), 1 hour = 3600
 | 
			
		||||
 | 
			
		||||
    instance_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True)  # uuid4, INSTANCE_REF
 | 
			
		||||
    site_key = Column(CHAR(length=36), ForeignKey(Site.site_key, ondelete='CASCADE'), nullable=False, index=True)  # uuid4
 | 
			
		||||
    private_key = Column(BLOB(length=2048), nullable=False)
 | 
			
		||||
    public_key = Column(BLOB(length=512), nullable=False)
 | 
			
		||||
    token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_TOKEN_EXPIRE_DELTA, comment='in seconds')
 | 
			
		||||
    lease_expire_delta = Column(INT(), nullable=False, default=DEFAULT_LEASE_EXPIRE_DELTA, comment='in seconds')
 | 
			
		||||
    lease_renewal_period = Column(FLOAT(precision=2), nullable=False, default=DEFAULT_LEASE_RENEWAL_PERIOD)
 | 
			
		||||
    client_token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA, comment='in seconds')
 | 
			
		||||
 | 
			
		||||
    __origin = relationship(Site, foreign_keys=[site_key])
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f'INSTANCE_REF: {self.instance_ref} (SITE_KEY_XID: {self.site_key})'
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_statement(engine: Engine):
 | 
			
		||||
        return CreateTable(Instance.__table__).compile(engine)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_or_update(engine: Engine, instance: "Instance"):
 | 
			
		||||
        session = sessionmaker(bind=engine)()
 | 
			
		||||
        entity = session.query(Instance).filter(Instance.instance_ref == instance.instance_ref).first()
 | 
			
		||||
        if entity is None:
 | 
			
		||||
            session.add(instance)
 | 
			
		||||
        else:
 | 
			
		||||
            x = dict(
 | 
			
		||||
                site_key=instance.site_key,
 | 
			
		||||
                private_key=instance.private_key,
 | 
			
		||||
                public_key=instance.public_key,
 | 
			
		||||
                token_expire_delta=instance.token_expire_delta,
 | 
			
		||||
                lease_expire_delta=instance.lease_expire_delta,
 | 
			
		||||
                lease_renewal_period=instance.lease_renewal_period,
 | 
			
		||||
                client_token_expire_delta=instance.client_token_expire_delta,
 | 
			
		||||
            )
 | 
			
		||||
            session.execute(update(Instance).where(Instance.instance_ref == instance.instance_ref).values(**x))
 | 
			
		||||
        session.commit()
 | 
			
		||||
        session.flush()
 | 
			
		||||
        session.close()
 | 
			
		||||
 | 
			
		||||
    # todo: validate on startup that "lease_expire_delta" is between 1 day and 90 days
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_default_instance(engine: Engine) -> "Instance":
 | 
			
		||||
        session = sessionmaker(bind=engine)()
 | 
			
		||||
        site = Site.get_default_site(engine)
 | 
			
		||||
        entity = session.query(Instance).filter(Instance.site_key == site.site_key).first()
 | 
			
		||||
        session.close()
 | 
			
		||||
        return entity
 | 
			
		||||
 | 
			
		||||
    def get_token_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
 | 
			
		||||
        return relativedelta(seconds=self.token_expire_delta)
 | 
			
		||||
 | 
			
		||||
    def get_lease_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
 | 
			
		||||
        return relativedelta(seconds=self.lease_expire_delta)
 | 
			
		||||
 | 
			
		||||
    def get_lease_renewal_delta(self) -> "datetime.timedelta":
 | 
			
		||||
        return timedelta(seconds=self.lease_expire_delta)
 | 
			
		||||
 | 
			
		||||
    def get_client_token_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
 | 
			
		||||
        return relativedelta(seconds=self.client_token_expire_delta)
 | 
			
		||||
 | 
			
		||||
    def __get_private_key(self) -> "PrivateKey":
 | 
			
		||||
        return PrivateKey(self.private_key)
 | 
			
		||||
 | 
			
		||||
    def get_public_key(self) -> "PublicKey":
 | 
			
		||||
        return PublicKey(self.public_key)
 | 
			
		||||
 | 
			
		||||
    def get_jwt_encode_key(self) -> "jose.jkw":
 | 
			
		||||
        return jwk.construct(self.__get_private_key().pem().decode('utf-8'), algorithm=ALGORITHMS.RS256)
 | 
			
		||||
 | 
			
		||||
    def get_jwt_decode_key(self) -> "jose.jwt":
 | 
			
		||||
        return jwk.construct(self.get_public_key().pem().decode('utf-8'), algorithm=ALGORITHMS.RS256)
 | 
			
		||||
 | 
			
		||||
    def get_private_key_str(self, encoding: str = 'utf-8') -> str:
 | 
			
		||||
        return self.private_key.decode(encoding)
 | 
			
		||||
 | 
			
		||||
    def get_public_key_str(self, encoding: str = 'utf-8') -> str:
 | 
			
		||||
        return self.private_key.decode(encoding)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Origin(Base):
 | 
			
		||||
    __tablename__ = "origin"
 | 
			
		||||
 | 
			
		||||
    origin_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True)  # uuid4
 | 
			
		||||
 | 
			
		||||
    # service_instance_xid = Column(CHAR(length=36), nullable=False, index=True)  # uuid4 # not necessary, we only support one service_instance_xid ('INSTANCE_REF')
 | 
			
		||||
    hostname = Column(VARCHAR(length=256), nullable=True)
 | 
			
		||||
    guest_driver_version = Column(VARCHAR(length=10), nullable=True)
 | 
			
		||||
@ -39,7 +162,6 @@ class Origin(Base):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_statement(engine: Engine):
 | 
			
		||||
        from sqlalchemy.schema import CreateTable
 | 
			
		||||
        return CreateTable(Origin.__table__).compile(engine)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -85,18 +207,24 @@ class Origin(Base):
 | 
			
		||||
class Lease(Base):
 | 
			
		||||
    __tablename__ = "lease"
 | 
			
		||||
 | 
			
		||||
    instance_ref = Column(CHAR(length=36), ForeignKey(Instance.instance_ref, ondelete='CASCADE'), nullable=False, index=True)  # uuid4
 | 
			
		||||
    lease_ref = Column(CHAR(length=36), primary_key=True, nullable=False, index=True)  # uuid4
 | 
			
		||||
 | 
			
		||||
    origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref, ondelete='CASCADE'), nullable=False, index=True)  # uuid4
 | 
			
		||||
    # scope_ref = Column(CHAR(length=36), nullable=False, index=True)  # uuid4 # not necessary, we only support one scope_ref ('ALLOTMENT_REF')
 | 
			
		||||
    lease_created = Column(DATETIME(), nullable=False)
 | 
			
		||||
    lease_expires = Column(DATETIME(), nullable=False)
 | 
			
		||||
    lease_updated = Column(DATETIME(), nullable=False)
 | 
			
		||||
 | 
			
		||||
    __instance = relationship(Instance, foreign_keys=[instance_ref])
 | 
			
		||||
    __origin = relationship(Origin, foreign_keys=[origin_ref])
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f'Lease(origin_ref={self.origin_ref}, lease_ref={self.lease_ref}, expires={self.lease_expires})'
 | 
			
		||||
 | 
			
		||||
    def serialize(self, renewal_period: float, renewal_delta: timedelta) -> dict:
 | 
			
		||||
    def serialize(self) -> dict:
 | 
			
		||||
        renewal_period = self.__instance.lease_renewal_period
 | 
			
		||||
        renewal_delta = self.__instance.get_lease_renewal_delta
 | 
			
		||||
 | 
			
		||||
        lease_renewal = int(Lease.calculate_renewal(renewal_period, renewal_delta).total_seconds())
 | 
			
		||||
        lease_renewal = self.lease_updated + relativedelta(seconds=lease_renewal)
 | 
			
		||||
 | 
			
		||||
@ -112,7 +240,6 @@ class Lease(Base):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_statement(engine: Engine):
 | 
			
		||||
        from sqlalchemy.schema import CreateTable
 | 
			
		||||
        return CreateTable(Lease.__table__).compile(engine)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -206,38 +333,104 @@ class Lease(Base):
 | 
			
		||||
        return renew
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_default_site(session: Session):
 | 
			
		||||
    private_key = PrivateKey.generate()
 | 
			
		||||
    public_key = private_key.public_key()
 | 
			
		||||
 | 
			
		||||
    site = Site(
 | 
			
		||||
        site_key=Site.INITIAL_SITE_KEY_XID,
 | 
			
		||||
        name=Site.INITIAL_SITE_NAME
 | 
			
		||||
    )
 | 
			
		||||
    session.add(site)
 | 
			
		||||
    session.commit()
 | 
			
		||||
 | 
			
		||||
    instance = Instance(
 | 
			
		||||
        instance_ref=Instance.DEFAULT_INSTANCE_REF,
 | 
			
		||||
        site_key=site.site_key,
 | 
			
		||||
        private_key=private_key.pem(),
 | 
			
		||||
        public_key=public_key.pem(),
 | 
			
		||||
    )
 | 
			
		||||
    session.add(instance)
 | 
			
		||||
    session.commit()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init(engine: Engine):
 | 
			
		||||
    tables = [Origin, Lease]
 | 
			
		||||
    tables = [Site, Instance, Origin, Lease]
 | 
			
		||||
    db = inspect(engine)
 | 
			
		||||
    session = sessionmaker(bind=engine)()
 | 
			
		||||
    for table in tables:
 | 
			
		||||
        if not db.dialect.has_table(engine.connect(), table.__tablename__):
 | 
			
		||||
        exists = db.dialect.has_table(engine.connect(), table.__tablename__)
 | 
			
		||||
        logger.info(f'> Table "{table.__tablename__:<16}" exists: {exists}')
 | 
			
		||||
        if not exists:
 | 
			
		||||
            session.execute(text(str(table.create_statement(engine))))
 | 
			
		||||
            session.commit()
 | 
			
		||||
 | 
			
		||||
    # create default site
 | 
			
		||||
    cnt = session.query(Site).count()
 | 
			
		||||
    if cnt == 0:
 | 
			
		||||
        init_default_site(session)
 | 
			
		||||
 | 
			
		||||
    session.flush()
 | 
			
		||||
    session.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def migrate(engine: Engine):
 | 
			
		||||
    db = inspect(engine)
 | 
			
		||||
 | 
			
		||||
    def upgrade_1_0_to_1_1():
 | 
			
		||||
        x = db.dialect.get_columns(engine.connect(), Lease.__tablename__)
 | 
			
		||||
        x = next(_ for _ in x if _['name'] == 'origin_ref')
 | 
			
		||||
        if x['primary_key'] > 0:
 | 
			
		||||
            print('Found old database schema with "origin_ref" as primary-key in "lease" table. Dropping table!')
 | 
			
		||||
            print('  Your leases are recreated on next renewal!')
 | 
			
		||||
            print('  If an error message appears on the client, you can ignore it.')
 | 
			
		||||
            Lease.__table__.drop(bind=engine)
 | 
			
		||||
            init(engine)
 | 
			
		||||
    # todo: add update guide to use 1.LATEST to 2.0
 | 
			
		||||
    def upgrade_1_x_to_2_0():
 | 
			
		||||
        site = Site.get_default_site(engine)
 | 
			
		||||
        logger.info(site)
 | 
			
		||||
        instance = Instance.get_default_instance(engine)
 | 
			
		||||
        logger.info(instance)
 | 
			
		||||
 | 
			
		||||
    # def upgrade_1_2_to_1_3():
 | 
			
		||||
    #    x = db.dialect.get_columns(engine.connect(), Lease.__tablename__)
 | 
			
		||||
    #    x = next((_ for _ in x if _['name'] == 'scope_ref'), None)
 | 
			
		||||
    #    if x is None:
 | 
			
		||||
    #        Lease.scope_ref.compile()
 | 
			
		||||
    #        column_name = Lease.scope_ref.name
 | 
			
		||||
    #        column_type = Lease.scope_ref.type.compile(engine.dialect)
 | 
			
		||||
    #        engine.execute(f'ALTER TABLE "{Lease.__tablename__}" ADD COLUMN "{column_name}" {column_type}')
 | 
			
		||||
        # SITE_KEY_XID
 | 
			
		||||
        if site_key := env('SITE_KEY_XID', None) is not None:
 | 
			
		||||
            site.site_key = str(site_key)
 | 
			
		||||
 | 
			
		||||
    upgrade_1_0_to_1_1()
 | 
			
		||||
    # upgrade_1_2_to_1_3()
 | 
			
		||||
        # INSTANCE_REF
 | 
			
		||||
        if instance_ref := env('INSTANCE_REF', None) is not None:
 | 
			
		||||
            instance.instance_ref = str(instance_ref)
 | 
			
		||||
 | 
			
		||||
        # ALLOTMENT_REF
 | 
			
		||||
        if allotment_ref := env('ALLOTMENT_REF', None) is not None:
 | 
			
		||||
            pass  # todo
 | 
			
		||||
 | 
			
		||||
        # INSTANCE_KEY_RSA, INSTANCE_KEY_PUB
 | 
			
		||||
        default_instance_private_key_path = str(join(dirname(__file__), 'cert/instance.private.pem'))
 | 
			
		||||
        instance_private_key = env('INSTANCE_KEY_RSA', None)
 | 
			
		||||
        if instance_private_key is not None:
 | 
			
		||||
            instance.private_key = PrivateKey(instance_private_key.encode('utf-8'))
 | 
			
		||||
        elif isfile(default_instance_private_key_path):
 | 
			
		||||
            instance.private_key = PrivateKey.from_file(default_instance_private_key_path)
 | 
			
		||||
        default_instance_public_key_path = str(join(dirname(__file__), 'cert/instance.public.pem'))
 | 
			
		||||
        instance_public_key = env('INSTANCE_KEY_PUB', None)
 | 
			
		||||
        if instance_public_key is not None:
 | 
			
		||||
            instance.public_key = PublicKey(instance_public_key.encode('utf-8'))
 | 
			
		||||
        elif isfile(default_instance_public_key_path):
 | 
			
		||||
            instance.public_key = PublicKey.from_file(default_instance_public_key_path)
 | 
			
		||||
 | 
			
		||||
        # TOKEN_EXPIRE_DELTA
 | 
			
		||||
        token_expire_delta = env('TOKEN_EXPIRE_DAYS', None)
 | 
			
		||||
        if token_expire_delta not in (None, 0):
 | 
			
		||||
            instance.token_expire_delta = token_expire_delta * 86_400
 | 
			
		||||
        token_expire_delta = env('TOKEN_EXPIRE_HOURS', None)
 | 
			
		||||
        if token_expire_delta not in (None, 0):
 | 
			
		||||
            instance.token_expire_delta = token_expire_delta * 3_600
 | 
			
		||||
 | 
			
		||||
        # LEASE_EXPIRE_DELTA, LEASE_RENEWAL_DELTA
 | 
			
		||||
        lease_expire_delta = env('LEASE_EXPIRE_DAYS', None)
 | 
			
		||||
        if lease_expire_delta not in (None, 0):
 | 
			
		||||
            instance.lease_expire_delta = lease_expire_delta * 86_400
 | 
			
		||||
        lease_expire_delta = env('LEASE_EXPIRE_HOURS', None)
 | 
			
		||||
        if lease_expire_delta not in (None, 0):
 | 
			
		||||
            instance.lease_expire_delta = lease_expire_delta * 3_600
 | 
			
		||||
 | 
			
		||||
        # LEASE_RENEWAL_PERIOD
 | 
			
		||||
        lease_renewal_period = env('LEASE_RENEWAL_PERIOD', None)
 | 
			
		||||
        if lease_renewal_period is not None:
 | 
			
		||||
            instance.lease_renewal_period = lease_renewal_period
 | 
			
		||||
 | 
			
		||||
        # todo: update site, instance
 | 
			
		||||
 | 
			
		||||
    upgrade_1_x_to_2_0()
 | 
			
		||||
 | 
			
		||||
@ -104,7 +104,7 @@ class DriverMatrix:
 | 
			
		||||
            self.log.debug(f'Successfully loaded "{DriverMatrix.__DRIVER_MATRIX_FILENAME}".')
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            DriverMatrix.__DRIVER_MATRIX = {}  # init empty dict to not try open file everytime, just when restarting app
 | 
			
		||||
            # self.log.warning(f'Failed to load "{NV.__DRIVER_MATRIX_FILENAME}": {e}')
 | 
			
		||||
            # self.log.warning(f'Failed to load "{DriverMatrix.__DRIVER_MATRIX_FILENAME}": {e}')
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def find(version: str) -> dict | None:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										30
									
								
								test/main.py
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								test/main.py
									
									
									
									
									
								
							@ -3,12 +3,13 @@ from base64 import b64encode as b64enc
 | 
			
		||||
from calendar import timegm
 | 
			
		||||
from datetime import datetime, UTC
 | 
			
		||||
from hashlib import sha256
 | 
			
		||||
from os.path import dirname, join
 | 
			
		||||
from os import getenv as env
 | 
			
		||||
from uuid import uuid4, UUID
 | 
			
		||||
 | 
			
		||||
from dateutil.relativedelta import relativedelta
 | 
			
		||||
from jose import jwt, jwk
 | 
			
		||||
from jose import jwt
 | 
			
		||||
from jose.constants import ALGORITHMS
 | 
			
		||||
from sqlalchemy import create_engine
 | 
			
		||||
from starlette.testclient import TestClient
 | 
			
		||||
 | 
			
		||||
# add relative path to use packages as they were in the app/ dir
 | 
			
		||||
@ -16,20 +17,23 @@ sys.path.append('../')
 | 
			
		||||
sys.path.append('../app')
 | 
			
		||||
 | 
			
		||||
from app import main
 | 
			
		||||
from util import PrivateKey, PublicKey
 | 
			
		||||
from orm import init as db_init, migrate, Site, Instance
 | 
			
		||||
 | 
			
		||||
client = TestClient(main.app)
 | 
			
		||||
 | 
			
		||||
ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-000000000001', 'HelloWorld'
 | 
			
		||||
 | 
			
		||||
# INSTANCE_KEY_RSA = generate_key()
 | 
			
		||||
# INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key()
 | 
			
		||||
# fastapi setup
 | 
			
		||||
client = TestClient(main.app)
 | 
			
		||||
 | 
			
		||||
INSTANCE_KEY_RSA = PrivateKey.from_file(str(join(dirname(__file__), '../app/cert/instance.private.pem')))
 | 
			
		||||
INSTANCE_KEY_PUB = PublicKey.from_file(str(join(dirname(__file__), '../app/cert/instance.public.pem')))
 | 
			
		||||
# database setup
 | 
			
		||||
db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite')))
 | 
			
		||||
db_init(db), migrate(db)
 | 
			
		||||
 | 
			
		||||
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
 | 
			
		||||
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
 | 
			
		||||
# test vars
 | 
			
		||||
DEFAULT_SITE, DEFAULT_INSTANCE = Site.get_default_site(db), Instance.get_default_instance(db)
 | 
			
		||||
 | 
			
		||||
SITE_KEY = DEFAULT_SITE.site_key
 | 
			
		||||
jwt_encode_key, jwt_decode_key = DEFAULT_INSTANCE.get_jwt_encode_key(), DEFAULT_INSTANCE.get_jwt_decode_key()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __bearer_token(origin_ref: str) -> str:
 | 
			
		||||
@ -38,6 +42,12 @@ def __bearer_token(origin_ref: str) -> str:
 | 
			
		||||
    return token
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_initial_default_site_and_instance():
 | 
			
		||||
    default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
 | 
			
		||||
    assert default_site.site_key == Site.INITIAL_SITE_KEY_XID
 | 
			
		||||
    assert default_instance.instance_ref == Instance.DEFAULT_INSTANCE_REF
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_index():
 | 
			
		||||
    response = client.get('/')
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user