lenticular_cloud2/lenticular_cloud/model.py

258 lines
8.4 KiB
Python

from flask import current_app
from flask_login import UserMixin
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from collections.abc import MutableSequence
from datetime import datetime
from dateutil import tz
import pyotp
import json
import logging
import crypt
import secrets
import string
from sqlalchemy import null
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass, Mapped, mapped_column, relationship, declarative_base
from flask_sqlalchemy import SQLAlchemy
from flask_sqlalchemy.model import Model, DefaultMeta
from flask_sqlalchemy.extension import _FSAModel
from flask_migrate import Migrate
from datetime import datetime
import uuid
from typing import Iterator, Optional, List, Dict, Tuple, Any, Type, TYPE_CHECKING
from cryptography.x509 import Certificate as CertificateObj
from sqlalchemy.ext.declarative import DeclarativeMeta
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Flask-SQLAlchemy extension object, constructed here without an app;
# the models below use ``db.Model`` and the ``db.*`` column types.
db = SQLAlchemy()
# Flask-Migrate extension object, constructed here without an app.
migrate = Migrate()
class BaseModelIntern(MappedAsDataclass, DeclarativeBase):
    """Declarative dataclass base; in this file only used as a base of the
    type-checking-only ``BaseModel`` definition."""
if TYPE_CHECKING:
    # Only seen by static type checkers: gives BaseModel the attributes of
    # both the Flask-SQLAlchemy model and the dataclass declarative base.
    class BaseModel(_FSAModel, BaseModelIntern):
        """Static-analysis stand-in for ``db.Model``."""
else:
    # At runtime the models simply derive from the extension's model class.
    BaseModel: Type[_FSAModel] = db.Model
class ModelUpdatedMixin:
    """Mixin adding ``created_at`` / ``modified_at`` timestamp columns.

    Both columns are non-nullable ``DateTime`` columns filled in on the
    Python side with ``datetime.now`` (naive local time).
    """

    # Pass the callable, not its result: ``datetime.now()`` would be evaluated
    # once when this module is imported, and every row inserted afterwards
    # would share that single import-time timestamp.
    created_at: Mapped[datetime] = mapped_column(
        db.DateTime,
        default=datetime.now,
        nullable=False,
    )
    # ``default`` covers INSERT, ``onupdate`` refreshes the value on UPDATE.
    modified_at: Mapped[datetime] = mapped_column(
        db.DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        nullable=False,
    )
class SecurityUser(UserMixin):
    """Flask-Login user object that is identified solely by a username."""

    def __init__(self, username):
        """Remember ``username`` as the identity of this user."""
        self._username = username

    def get_id(self):
        """Return the username given at construction time."""
        return self._username
class Service(object):
    """Read-only description of a service: name, feature flags, icon, link
    and the templates used for client certificates."""

    def __init__(self, name: str):
        """Create a service called ``name`` with every option at its default."""
        self._name = name
        # Feature flags, both off unless enabled through ``from_config``.
        self._app_token = False
        self._client_cert = False
        # Optional presentation data.
        self._icon: Optional[str] = None
        self._href: Optional[str] = None
        # Default certificate field templates; entries can be overridden.
        self._pki_config = {
            'cn': '{username}',
            'email': '{username}@{domain}'
        }

    @staticmethod
    def from_config(name, config) -> 'Service':
        """Build a ``Service`` from a config mapping.

        Recognised keys are ``app_token`` and ``client_cert`` (coerced with
        ``bool``), ``pki_config`` (merged into the defaults) and ``icon`` /
        ``href`` (coerced with ``str``). Missing keys keep their defaults.
        """
        instance = Service(name)
        for flag in ('app_token', 'client_cert'):
            if flag in config:
                setattr(instance, f'_{flag}', bool(config[flag]))
        if 'pki_config' in config:
            instance._pki_config.update(config['pki_config'])
        for text_option in ('icon', 'href'):
            if text_option in config:
                setattr(instance, f'_{text_option}', str(config[text_option]))
        return instance

    @property
    def name(self) -> str:
        """Name of the service."""
        return self._name

    @property
    def client_cert(self) -> bool:
        """Whether the ``client_cert`` option is enabled."""
        return self._client_cert

    @property
    def app_token(self) -> bool:
        """Whether the ``app_token`` option is enabled."""
        return self._app_token

    @property
    def pki_config(self) -> dict[str,str]:
        """Certificate field templates.

        Raises:
            Exception: if ``client_cert`` is not enabled for this service.
        """
        if self._client_cert:
            return self._pki_config
        raise Exception('invalid call')

    @property
    def icon(self) -> Optional[str]:
        """Configured icon, or ``None`` when not set."""
        return self._icon

    @property
    def href(self) -> Optional[str]:
        """Configured link target, or ``None`` when not set."""
        return self._href
class Certificate(object):
    """Wrapper around an x509 certificate together with its common name,
    the name of its CA and a revocation flag."""

    def __init__(self, cn: str, ca_name: str, cert_data: CertificateObj, revoked=False):
        """Store the certificate object and its metadata.

        Args:
            cn: common name associated with the certificate.
            ca_name: name of the CA the certificate belongs to.
            cert_data: the parsed x509 certificate.
            revoked: whether the certificate has been revoked.
        """
        self._cn = cn
        self._ca_name = ca_name
        self._cert_data = cert_data
        self._revoked = revoked
        # No timezone fix-up is done here: ``datetime.replace`` returns a new
        # object instead of modifying in place, so calling it without using
        # the result has no effect. The conversion happens in the
        # ``not_valid_before`` / ``not_valid_after`` properties.

    @property
    def cn(self) -> str:
        """Common name passed to the constructor."""
        return self._cn

    @property
    def ca_name(self) -> str:
        """CA name passed to the constructor."""
        return self._ca_name

    @property
    def not_valid_before(self) -> datetime:
        """Start of validity as a naive datetime in the local timezone.

        The certificate's value is tagged as UTC, converted to the local
        zone, and then stripped of its tzinfo.
        """
        return self._cert_data.not_valid_before.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal()).replace(tzinfo=None)

    @property
    def not_valid_after(self) -> datetime:
        """End of validity as a naive datetime in the local timezone.

        The certificate's value is tagged as UTC, converted to the local
        zone, and then stripped of its tzinfo.
        """
        return self._cert_data.not_valid_after.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal()).replace(tzinfo=None)

    @property
    def serial_number(self) -> int:
        """Serial number of the certificate as an integer."""
        return self._cert_data.serial_number

    @property
    def serial_number_hex(self) -> str:
        """Serial number formatted as upper-case hexadecimal."""
        return f'{self._cert_data.serial_number:X}'

    def fingerprint(self, algorithm=hashes.SHA256()) -> bytes:
        """Return the certificate fingerprint for ``algorithm`` (SHA-256 by default)."""
        return self._cert_data.fingerprint(algorithm)

    @property
    def is_valid(self) -> bool:
        """True while the certificate is neither expired nor revoked.

        Compares the naive local ``not_valid_after`` with the naive local
        ``datetime.now()``.
        """
        return self.not_valid_after > datetime.now() and not self._revoked

    def pem(self) -> str:
        """Return the certificate PEM-encoded as text."""
        return self._cert_data.public_bytes(encoding=serialization.Encoding.PEM).decode()

    @property
    def raw(self):
        """The underlying certificate object passed to the constructor."""
        return self._cert_data

    def __str__(self):
        return f'Certificate(cn={self._cn}, ca_name={self._ca_name}, not_valid_before={self.not_valid_before}, not_valid_after={self.not_valid_after})'
def generate_uuid():
    """Return a freshly generated random (version 4) UUID as a string."""
    fresh = uuid.uuid4()
    return str(fresh)
class User(BaseModel, ModelUpdatedMixin):
    """User account: login name, password hash, optional contact address,
    plus the related app tokens and passkey credentials."""

    id: Mapped[uuid.UUID] = mapped_column(
        db.Uuid,
        primary_key=True,
        default=uuid.uuid4,
    )
    username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False)
    password_hashed: Mapped[str] = mapped_column(db.String, nullable=False)
    alternative_email: Mapped[Optional[str]] = mapped_column(db.String, nullable=True)
    last_login: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
    # New accounts are disabled unless ``enabled`` is set explicitly.
    enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)

    app_tokens: Mapped[List['AppToken']] = relationship(
        'AppToken',
        back_populates='user',
    )
    passkey_credentials: Mapped[List['PasskeyCredential']] = relationship(
        'PasskeyCredential',
        back_populates='user',
        cascade='delete,delete-orphan',
        passive_deletes=True,
    )

    def __init__(self, **kwargs) -> None:
        """Forward all keyword arguments to the model base constructor."""
        super().__init__(**kwargs)

    @property
    def is_authenticated(self) -> bool:
        """Always True for now."""
        return True # TODO

    def get(self, key) -> None:
        """Placeholder: only prints the requested key and returns None."""
        print(f'getitem: {key}') # TODO

    @property
    def groups(self) -> list['Group']:
        """Groups of this user.

        A user whose username is listed in the app config ``ADMINS`` gets a
        single (newly constructed) ``admin`` group; everyone else gets none.
        """
        if self.username not in current_app.config['ADMINS']:
            return []
        return [Group(name='admin')]

    @property
    def email(self) -> str:
        """Address built from the username and the app config ``DOMAIN``."""
        configured_domain = current_app.config['DOMAIN']
        return '{}@{}'.format(self.username, configured_domain)

    def change_password(self, password_new: str) -> None:
        """Hash ``password_new`` with ``crypt.crypt`` and store the result."""
        # NOTE(review): the stdlib ``crypt`` module is deprecated since
        # Python 3.11 and removed in 3.13 - plan a replacement.
        fresh_hash = crypt.crypt(password_new)
        self.password_hashed = fresh_hash

    def get_token_by_name(self, name: str) -> Optional['AppToken']:
        """Return the first app token called ``name``, or None if there is none."""
        return next(
            (candidate for candidate in self.app_tokens if candidate.name == name),
            None,
        )

    def get_token_by_scope(self, scope: str) -> Iterator['AppToken']:
        """Yield every app token whose whitespace-split ``scopes`` contains ``scope``."""
        for candidate in self.app_tokens:
            if scope not in candidate.scopes.split():
                continue
            yield candidate  # type: ignore
class AppToken(BaseModel, ModelUpdatedMixin):
    """Named token belonging to a user, carrying a scope string."""

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # All scopes of the token stored in one string.
    # NOTE(review): ``User.get_token_by_scope`` splits this value on
    # whitespace - confirm which delimiter writers of this column use.
    scopes: Mapped[str] = mapped_column(nullable=False)
    user_id: Mapped[uuid.UUID] = mapped_column(db.Uuid, db.ForeignKey(User.id), nullable=False)
    user: Mapped[User] = relationship(User, back_populates="app_tokens")
    token: Mapped[str] = mapped_column(nullable=False)
    name: Mapped[str] = mapped_column(nullable=False)
    last_used: Mapped[Optional[datetime]] = mapped_column(
        db.DateTime,
        nullable=True,
        default=None,
    )

    @staticmethod
    def new(user: User, scopes: str, name: str):
        """Create a token for ``user`` with a random 12-character secret.

        The secret is drawn from ASCII letters and digits using ``secrets``.
        The new object is returned; this method does not add it to a session.
        """
        charset = string.ascii_letters + string.digits
        picked = [secrets.choice(charset) for _ in range(12)]
        return AppToken(scopes=scopes, token=''.join(picked), user=user, name=name)
class PasskeyCredential(BaseModel, ModelUpdatedMixin):  # pylint: disable=too-few-public-methods
    """A passkey credential registered for a user."""

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # The foreign key is declared with ON DELETE CASCADE.
    user_id: Mapped[uuid.UUID] = mapped_column(
        db.Uuid,
        db.ForeignKey('user.id', ondelete='CASCADE'),
        nullable=False,
    )
    credential_id: Mapped[bytes] = mapped_column(db.LargeBinary, nullable=False)
    credential_public_key: Mapped[bytes] = mapped_column(db.LargeBinary, nullable=False)
    name: Mapped[str] = mapped_column(db.String(250), nullable=False)
    last_used: Mapped[Optional[datetime]] = mapped_column(
        db.DateTime,
        nullable=True,
        default=None,
    )
    sign_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)

    # Counterpart of ``User.passkey_credentials``.
    user = db.relationship('User', back_populates='passkey_credentials')
class Group(BaseModel, ModelUpdatedMixin):
    """Group identified by a unique name."""

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(
        db.String(),
        nullable=False,
        unique=True,
    )