from flask import current_app
from flask_login import UserMixin
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from collections.abc import MutableSequence
from datetime import datetime
from dateutil import tz
import pyotp
import json
import logging
import crypt
import secrets
import string
from sqlalchemy import null
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass, Mapped, mapped_column, relationship, declarative_base
from flask_sqlalchemy import SQLAlchemy
from flask_sqlalchemy.model import Model, DefaultMeta
from flask_sqlalchemy.extension import _FSAModel
from flask_migrate import Migrate
import uuid
from typing import Iterator, Optional, List, Dict, Tuple, Any, Type, TYPE_CHECKING
from cryptography.x509 import Certificate as CertificateObj
from sqlalchemy.ext.declarative import DeclarativeMeta

logger = logging.getLogger(__name__)

db = SQLAlchemy()
migrate = Migrate()


class BaseModelIntern(MappedAsDataclass, DeclarativeBase):
    """Dataclass-style declarative base.

    Within this module it is only referenced under ``TYPE_CHECKING`` to give
    type checkers a typed view of ``BaseModel``.
    """
    pass


if TYPE_CHECKING:
    class BaseModel(_FSAModel, BaseModelIntern):
        pass
else:
    # At runtime every model derives from Flask-SQLAlchemy's ``db.Model``.
    BaseModel: Type[_FSAModel] = db.Model


class ModelUpdatedMixin:
    """Mixin adding ``created_at`` / ``modified_at`` timestamp columns.

    The defaults are passed as *callables* so SQLAlchemy evaluates them for
    each INSERT/UPDATE. Writing ``datetime.now()`` here would evaluate once at
    import time and stamp every row with the process start time.
    """
    created_at: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.now, nullable=False)
    modified_at: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)


class SecurityUser(UserMixin):
    """Minimal Flask-Login user object identified only by its username."""

    def __init__(self, username):
        self._username = username

    def get_id(self):
        """Return the username as the Flask-Login user id."""
        return self._username


class Service(object):
    """A service entry built from configuration (token/cert/link settings)."""

    def __init__(self, name: str):
        self._name = name
        self._app_token = False
        self._icon: Optional[str] = None
        self._href: Optional[str] = None
        self._client_cert = False
        # Template values for client-certificate fields; placeholders such as
        # {username} / {domain} are presumably filled in by the PKI code.
        self._pki_config = {
            'cn': '{username}',
            'email': '{username}@{domain}'
        }

    @staticmethod
    def from_config(name, config) -> 'Service':
        """Build a ``Service`` from a mapping, applying only the keys present.

        Recognised keys: ``app_token``, ``client_cert`` (coerced to bool),
        ``pki_config`` (merged into the default template), ``icon`` and
        ``href`` (coerced to str).
        """
        service = Service(name)
        if 'app_token' in config:
            service._app_token = bool(config['app_token'])
        if 'client_cert' in config:
            service._client_cert = bool(config['client_cert'])
        if 'pki_config' in config:
            service._pki_config.update(config['pki_config'])
        if 'icon' in config:
            service._icon = str(config['icon'])
        if 'href' in config:
            service._href = str(config['href'])
        return service

    @property
    def name(self) -> str:
        return self._name

    @property
    def client_cert(self) -> bool:
        return self._client_cert

    @property
    def app_token(self) -> bool:
        return self._app_token

    @property
    def pki_config(self) -> dict[str, str]:
        """Return the PKI field template.

        Raises:
            Exception: if client certificates are not enabled for this service.
        """
        if not self._client_cert:
            raise Exception('invalid call')
        return self._pki_config

    @property
    def icon(self) -> Optional[str]:
        return self._icon

    @property
    def href(self) -> Optional[str]:
        return self._href


class Certificate(object):
    """Wrapper around a ``cryptography`` X.509 certificate plus CA/revocation info."""

    def __init__(self, cn: str, ca_name: str, cert_data: CertificateObj, revoked=False):
        # The original also called ``not_valid_after/before.replace(tzinfo=...)``
        # here and discarded the result; ``datetime.replace`` returns a new
        # object, so those calls were no-ops and have been removed. Timezone
        # conversion happens in the properties below.
        self._cn = cn
        self._ca_name = ca_name
        self._cert_data = cert_data
        self._revoked = revoked

    @property
    def cn(self) -> str:
        return self._cn

    @property
    def ca_name(self) -> str:
        return self._ca_name

    @property
    def not_valid_before(self) -> datetime:
        """Start of validity as a naive datetime in local time.

        The certificate's naive value is interpreted as UTC.
        """
        return self._cert_data.not_valid_before.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal()).replace(tzinfo=None)

    @property
    def not_valid_after(self) -> datetime:
        """End of validity as a naive datetime in local time.

        The certificate's naive value is interpreted as UTC.
        """
        return self._cert_data.not_valid_after.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal()).replace(tzinfo=None)

    @property
    def serial_number(self) -> int:
        return self._cert_data.serial_number

    @property
    def serial_number_hex(self) -> str:
        """Serial number as upper-case hex without a ``0x`` prefix."""
        return f'{self._cert_data.serial_number:X}'

    def fingerprint(self, algorithm=hashes.SHA256()) -> bytes:
        """Return the certificate fingerprint (SHA-256 by default)."""
        return self._cert_data.fingerprint(algorithm)

    @property
    def is_valid(self) -> bool:
        """True if the certificate has not expired and is not revoked.

        NOTE(review): ``not_valid_before`` is not checked, so a certificate
        that is not yet valid is still reported as valid — confirm intent.
        """
        return self.not_valid_after > datetime.now() and not self._revoked

    def pem(self) -> str:
        """Return the certificate PEM-encoded as text."""
        return self._cert_data.public_bytes(encoding=serialization.Encoding.PEM).decode()

    @property
    def raw(self):
        """The underlying ``cryptography`` certificate object."""
        return self._cert_data

    def __str__(self):
        return f'Certificate(cn={self._cn}, ca_name={self._ca_name}, not_valid_before={self.not_valid_before}, not_valid_after={self.not_valid_after})'


def generate_uuid():
    """Return a new random UUID4 as a string."""
    return str(uuid.uuid4())


class User(BaseModel, ModelUpdatedMixin):
    """Account record with password hash, app tokens and passkey credentials."""
    id: Mapped[uuid.UUID] = mapped_column(db.Uuid, primary_key=True, default=uuid.uuid4)
    username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False)
    password_hashed: Mapped[str] = mapped_column(db.String, nullable=False)
    alternative_email: Mapped[Optional[str]] = mapped_column(
        db.String, nullable=True)
    last_login: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
    enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
    # NOTE(review): unlike passkey_credentials, no delete cascade is configured.
    # Because AppToken.user_id is non-nullable, deleting a User that still owns
    # tokens will likely fail when SQLAlchemy tries to NULL the FK — confirm intended.
    app_tokens: Mapped[List['AppToken']] = relationship('AppToken', back_populates='user')
    passkey_credentials: Mapped[List['PasskeyCredential']] = relationship('PasskeyCredential', back_populates='user', cascade='delete,delete-orphan', passive_deletes=True)

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    @property
    def is_authenticated(self) -> bool:
        return True  # TODO

    def get(self, key) -> None:
        print(f'getitem: {key}')  # TODO

    @property
    def groups(self) -> list['Group']:
        """Return ``[Group('admin')]`` if the username is in ``ADMINS`` config, else ``[]``.

        The returned Group is a fresh, transient object, not loaded from the DB.
        """
        admins = current_app.config['ADMINS']
        if self.username in admins:
            return [Group(name='admin')]
        else:
            return []

    @property
    def email(self) -> str:
        """Primary address derived as ``<username>@<DOMAIN config>``."""
        domain = current_app.config['DOMAIN']
        return f'{self.username}@{domain}'

    def change_password(self, password_new: str) -> None:
        """Hash *password_new* with ``crypt.crypt`` and store it.

        NOTE(review): the ``crypt`` module is deprecated since Python 3.11
        (PEP 594) and removed in 3.13; plan a replacement hashing scheme.
        """
        self.password_hashed = crypt.crypt(password_new)

    def get_token_by_name(self, name: str) -> Optional['AppToken']:
        """Return this user's first app token named *name*, or None."""
        return next((token for token in self.app_tokens if token.name == name), None)

    def get_token_by_scope(self, scope: str) -> Iterator['AppToken']:
        """Yield every app token of this user whose scopes include *scope*.

        ``AppToken.scopes`` is documented as `,`-separated, but the original
        lookup split on whitespace only. Both separators are accepted so that
        neither storage format silently fails to match.
        """
        for token in self.app_tokens:
            if scope in token.scopes.replace(',', ' ').split():
                yield token  # type: ignore


class AppToken(BaseModel, ModelUpdatedMixin):
    """Named per-user application token restricted to a set of scopes."""
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    scopes: Mapped[str] = mapped_column(nullable=False)  # scope names separated by `,` and/or whitespace
    user_id: Mapped[uuid.UUID] = mapped_column(
        db.Uuid, db.ForeignKey(User.id), nullable=False)
    user: Mapped[User] = relationship(User, back_populates="app_tokens")
    token: Mapped[str] = mapped_column(nullable=False)
    name: Mapped[str] = mapped_column(nullable=False)
    last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True, default=None)

    @staticmethod
    def new(user: User, scopes: str, name: str):
        """Create (but do not add/commit) a token with 12 random alphanumeric chars.

        The token is drawn from ``secrets`` (CSPRNG). It is stored as plain
        text in the ``token`` column.
        """
        alphabet = string.ascii_letters + string.digits
        token = ''.join(secrets.choice(alphabet) for _ in range(12))
        return AppToken(scopes=scopes, token=token, user=user, name=name)


class PasskeyCredential(BaseModel, ModelUpdatedMixin):  # pylint: disable=too-few-public-methods
    """Passkey credential model"""
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Rows are removed by the database when the owning user row is deleted.
    user_id: Mapped[uuid.UUID] = mapped_column(db.Uuid, db.ForeignKey('user.id', ondelete='CASCADE'), nullable=False)
    credential_id: Mapped[bytes] = mapped_column(db.LargeBinary, nullable=False)
    credential_public_key: Mapped[bytes] = mapped_column(db.LargeBinary, nullable=False)
    name: Mapped[str] = mapped_column(db.String(250), nullable=False)
    last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True, default=None)
    sign_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
    user = db.relationship('User', back_populates='passkey_credentials')


class Group(BaseModel, ModelUpdatedMixin):
    """Named group (e.g. ``admin``)."""
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(db.String(), nullable=False, unique=True)