diff --git a/lenticular_cloud/app.py b/lenticular_cloud/app.py index f21c093..adfbe6d 100644 --- a/lenticular_cloud/app.py +++ b/lenticular_cloud/app.py @@ -3,16 +3,17 @@ from flask import g, redirect, request from flask.helpers import url_for import time import subprocess +from lenticular_cloud.lenticular_services import lenticular_services from ory_hydra_client import Client import os from pathlib import Path -from ldap3 import Connection, Server, ALL -from . import model -from .pki import Pki +from .pki import pki from .hydra import hydra_service from .translations import init_babel +from .model import db, migrate +from .views import auth_views, frontend_views, init_login_manager, api_views, pki_views, admin_views, oauth2_views def get_git_hash(): @@ -31,13 +32,6 @@ def create_app() -> Flask: app.jinja_env.globals['GIT_HASH'] = get_git_hash() - #app.ldap_orm = Connection(app.config['LDAP_URL'], app.config['LDAP_BIND_DN'], app.config['LDAP_BIND_PW'], auto_bind=True) - server = Server(app.config['LDAP_URL'], get_info=ALL) - app.ldap_conn = Connection(server, app.config['LDAP_BIND_DN'], app.config['LDAP_BIND_PW'], auto_bind=True) # TODO auto_bind read docu - model.ldap_conn = app.ldap_conn - model.base_dn = app.config['LDAP_BASE_DN'] - - from .model import db, migrate db.init_app(app) migration_dir = Path(app.root_path) / 'migrations' migrate.init_app(app, db, directory=str(migration_dir)) @@ -53,9 +47,7 @@ def create_app() -> Flask: # password=app.config['HYDRA_ADMIN_PASSWORD']) hydra_service.set_hydra_client(Client(base_url=app.config['HYDRA_ADMIN_URL'])) - from .views import auth_views, frontend_views, init_login_manager, api_views, pki_views, admin_views, oauth2_views init_login_manager(app) - #oauth2.init_app(app) app.register_blueprint(auth_views) app.register_blueprint(frontend_views) app.register_blueprint(api_views) @@ -68,11 +60,9 @@ def create_app() -> Flask: request_start_time = time.time() g.request_time = lambda: "%.5fs" % (time.time() - request_start_time) - app.lenticular_services = {} - for service_name, service_config in app.config['LENTICULAR_CLOUD_SERVICES'].items(): - app.lenticular_services[service_name] = model.Service.from_config(service_name, service_config) + lenticular_services.init_app(app) - app.pki = Pki(app.config['PKI_PATH'], app.config['DOMAIN']) + pki.init_app(app) return app diff --git a/lenticular_cloud/cli.py b/lenticular_cloud/cli.py index c35dbc5..1e12701 100644 --- a/lenticular_cloud/cli.py +++ b/lenticular_cloud/cli.py @@ -1,5 +1,5 @@ import argparse -from .model import db, User, UserSignUp +from .model import db, User from .app import create_app from werkzeug.middleware.proxy_fix import ProxyFix from flask_migrate import upgrade @@ -19,7 +19,7 @@ def entry_point(): parser_user.set_defaults(func=cli_user) parser_signup = subparsers.add_parser('signup') - parser_signup.add_argument('--signup_id', type=int) + parser_signup.add_argument('--signup_id', type=str) parser_signup.set_defaults(func=cli_signup) parser_run = subparsers.add_parser('run') @@ -55,22 +55,25 @@ def entry_point(): def cli_user(args): - print(User.query.all()) + for user in User.query.all(): + print(f'{user.id} - Enabled: {user.enabled} - Name:`{user.username}`') pass def cli_signup(args): - - print(args.signup_id) - if args.signup_id is not None: - user_data = UserSignUp.query.get(args.signup_id) - user = User.new(user_data) - db.session.add(user) - db.session.delete(user_data) + if args.signup_id is not None: + user = User.query.get(args.signup_id) + if user == None: + print("user not found") + return + user.enabled = True + db.session.commit() else: # list - print(UserSignUp.query.all()) + print('disabled users:') + for user in User.query.filter_by(enabled=False).all(): + print(f'') def cli_run(app, args): diff --git a/lenticular_cloud/lenticular_services.py b/lenticular_cloud/lenticular_services.py new file mode 100644 index 0000000..784416c --- /dev/null +++ b/lenticular_cloud/lenticular_services.py @@ -0,0 +1,15 @@ +from flask import Flask +from .model import Service +import logging + +logger = logging.getLogger(__name__) + + +class LenticularServices(dict): + + def init_app(self, app: Flask) -> None: + for service_name, service_config in app.config['LENTICULAR_CLOUD_SERVICES'].items(): + self[service_name] = Service.from_config(service_name, service_config) + + +lenticular_services = LenticularServices() \ No newline at end of file diff --git a/lenticular_cloud/migrations/env.py b/lenticular_cloud/migrations/env.py index 68feded..be1d297 100644 --- a/lenticular_cloud/migrations/env.py +++ b/lenticular_cloud/migrations/env.py @@ -13,7 +13,9 @@ config = context.config # Interpret the config file for Python logging. # This line sets up loggers basically. -fileConfig(config.config_file_name) +if type(config.config_file_name) == str: + fileConfig(config.config_file_name) + logger = logging.getLogger('alembic.env') # add your model's MetaData object here diff --git a/lenticular_cloud/migrations/versions/0518a8625b50_remove_ldap_add_rest_to_db.py b/lenticular_cloud/migrations/versions/0518a8625b50_remove_ldap_add_rest_to_db.py new file mode 100644 index 0000000..6c12fcf --- /dev/null +++ b/lenticular_cloud/migrations/versions/0518a8625b50_remove_ldap_add_rest_to_db.py @@ -0,0 +1,57 @@ +"""remove ldap, add rest to db + +Revision ID: 0518a8625b50 +Revises: 52a21983d2a8 +Create Date: 2022-06-17 13:15:33.450531 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '0518a8625b50' +down_revision = '52a21983d2a8' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('app_token', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('service_name', sa.String(), nullable=False), + sa.Column('token', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('group', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name') + ) + op.drop_table('user_sign_up') + op.add_column('user', sa.Column('password_hashed', sa.String(), server_default="", nullable=False)) + op.add_column('user', sa.Column('enabled', sa.Boolean(), server_default="false", nullable=True)) + # ### end Alembic commands ### + + op.execute("UPDATE `user` SET enabled= 1;") + #op.execute('UPDATE `user` SET password_hashed = "";') + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('user', 'enabled') + op.drop_column('user', 'password_hashed') + op.create_table('user_sign_up', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('username', sa.VARCHAR(), nullable=False), + sa.Column('password', sa.VARCHAR(), nullable=False), + sa.Column('alternative_email', sa.VARCHAR(), nullable=True), + sa.Column('created_at', sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.drop_table('group') + op.drop_table('app_token') + # ### end Alembic commands ### diff --git a/lenticular_cloud/model.py b/lenticular_cloud/model.py index 72a097d..91a1e9e 100644 --- a/lenticular_cloud/model.py +++ b/lenticular_cloud/model.py @@ -1,11 +1,5 @@ from flask import current_app -from ldap3_orm import AttrDef, EntryBase as _EntryBase, ObjectDef, EntryType -from ldap3_orm import Reader -from ldap3 import Connection, Entry, HASHED_SALTED_SHA256 -from ldap3.utils.conv import escape_filter_chars -from ldap3.utils.hashed import hashed from flask_login import UserMixin -from ldap3.core.exceptions import LDAPSessionTerminatedByServerError from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization from collections.abc import MutableSequence @@ -21,24 +15,17 @@ from datetime import datetime import uuid import pyotp from typing import Optional, Callable - +from cryptography.x509 import Certificate as CertificateObj +from sqlalchemy.ext.asyncio import create_async_engine logger = logging.getLogger(__name__) -ldap_conn = None # type: Connection -base_dn = '' + + + db = SQLAlchemy() # type: SQLAlchemy migrate = Migrate() -class UserSignUp(db.Model): - id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String, nullable=False) - password = db.Column(db.String, nullable=False) - alternative_email = db.Column(db.String) - created_at = db.Column(db.DateTime, nullable=False, - default=datetime.now) - - class SecurityUser(UserMixin): def __init__(self, username): @@ -48,90 +35,6 @@ class SecurityUser(UserMixin): return self._username -class LambdaStr: - - def __init__(self, lam: Callable[[],str]): - self.lam = lam - - def __str__(self) -> str: - return self.lam() - - -class EntryBase(db.Model): - __abstract__ = True # for sqlalchemy - - _type = None # will get replaced by the local type - _ldap_query_object = None # will get replaced by the local type - _base_dn = LambdaStr(lambda: base_dn) - -# def __init__(self, ldap_object=None, **kwargs): -# if ldap_object is None: -# self._ldap_object = self.get_type()(**kwargs) -# else: -# self._ldap_object = ldap_object - dn = '' - base_dn = '' - - def __str__(self) -> str: - return str(self._ldap_object) - - @classmethod - def get_object_def(cls) -> ObjectDef: - return ObjectDef(cls.object_classes, ldap_conn) - - @classmethod - def get_entry_type(cls) -> EntryType: - return EntryType(cls.get_dn(), cls.object_classes, ldap_conn) - - @classmethod - def get_base(cls) -> str: - return cls.base_dn.format(_base_dn=base_dn) - - @classmethod - def get_dn(cls) -> str: - return cls.dn.replace('{base_dn}', cls.get_base()) - - @classmethod - def get_type(cls): - if cls._type is None: - cls._type = EntryType(cls.get_dn(), cls.object_classes, ldap_conn) - return cls._type - - def ldap_commit(self): - self._ldap_object.entry_commit_changes() - - def ldap_add(self): - ret = ldap_conn.add( - self.entry_dn, self.object_classes, self._ldap_object.entry_attributes_as_dict) - if not ret: - raise Exception('ldap error') - - @classmethod - def query_(cls): - if cls._ldap_query_object is None: - cls._ldap_query_object = cls._query(cls) - return cls._ldap_query_object - - class _query(object): - def __init__(self, clazz): - self._class = clazz - - def _mapping(self, ldap_object): - return ldap_object - - def _query(self, ldap_filter: str): - reader = Reader(ldap_conn, self._class.get_object_def(), self._class.get_base(), ldap_filter) - try: - reader.search() - except LDAPSessionTerminatedByServerError: - ldap_conn.bind() - reader.search() - return [self._mapping(entry) for entry in reader] - - def all(self): - return self._query(None) - - class Service(object): def __init__(self, name): @@ -171,7 +74,7 @@ class Service(object): class Certificate(object): - def __init__(self, cn, ca_name: str, cert_data, revoked=False): + def __init__(self, cn, ca_name: str, cert_data: CertificateObj, revoked=False): self._cn = cn self._ca_name = ca_name self._cert_data = cert_data @@ -225,11 +128,13 @@ def generate_uuid(): return str(uuid.uuid4()) -class User(EntryBase): +class User(db.Model): id = db.Column( db.String(length=36), primary_key=True, default=generate_uuid) username = db.Column( db.String, unique=True, nullable=False) + password_hashed = db.Column( + db.String, nullable=False) alternative_email = db.Column( db.String, nullable=True) created_at = db.Column(db.DateTime, nullable=False, @@ -238,15 +143,12 @@ class User(EntryBase): default=datetime.now, onupdate=datetime.now) last_login = db.Column(db.DateTime, nullable=True) + enabled = db.Column(db.Boolean, nullable=False, default=False) + totps = db.relationship('Totp', back_populates='user') webauthn_credentials = db.relationship('WebauthnCredential', back_populates='user', cascade='delete,delete-orphan', passive_deletes=True) - dn = "uid={uid},{base_dn}" - base_dn = "ou=users,{_base_dn}" - object_classes = ["inetOrgPerson"] #, "LenticularUser"] - def __init__(self, **kwargs): - self._ldap_object = None super(db.Model).__init__(**kwargs) @property @@ -256,9 +158,6 @@ class User(EntryBase): def get(self, key): print(f'getitem: {key}') # TODO - def make_writeable(self): - self._ldap_object = self._ldap_object.entry_writable() - @property def groups(self) -> list[str]: if self.username == 'tuxcoder': @@ -266,58 +165,20 @@ class User(EntryBase): else: return [] - @property - def entry_dn(self) -> str: - return self._ldap_object.entry_dn - @property def email(self) -> str: domain = current_app.config['DOMAIN'] return f'{self.username}@{domain}' - return self._ldap_object.mail def change_password(self, password_new: str) -> bool: - self.make_writeable() password_hashed = crypt.crypt(password_new) - self._ldap_object.userPassword = ('{CRYPT}' + password_hashed).encode() - self.ldap_commit() return True - class _query(EntryBase._query): - - def _mapping(self, ldap_object): - user = User.query.filter(User.username == str(ldap_object.uid)).first() - if user is None: - # migration time - user = User() - user.username = str(ldap_object.uid) - db.session.add(user) - db.session.commit() - user._ldap_object = ldap_object - return user - - def by_username(self, username) -> Optional['User']: - result = self._query('(uid={username:s})'.format(username=escape_filter_chars(username))) - if len(result) > 0 and isinstance(result[0], User): - return result[0] - else: - return None - - @staticmethod - def new(user_data: UserSignUp): - user = User() - user.username = user_data.username.lower() - domain = current_app.config['DOMAIN'] - ldap_object = User.get_entry_type()( - uid=user_data.username.lower(), - sn=user_data.username, - cn=user_data.username, - userPassword='{CRYPT}' + user_data.password, - mail=f'{user_data.username}@{domain}') - user._ldap_object = ldap_object - user.ldap_add() - return user - +class AppToken(db.Model): + id = db.Column(db.Integer, primary_key=True) + service_name = db.Column(db.String, nullable=False) + token = db.Column(db.String, nullable=False) + name = db.Column(db.String, nullable=False) class Totp(db.Model): @@ -349,14 +210,7 @@ class WebauthnCredential(db.Model): # pylint: disable=too-few-public-methods user = db.relationship('User', back_populates='webauthn_credentials') -class Group(EntryBase): - __abstract__ = True # for sqlalchemy, disable for now - dn = "cn={cn},{base_dn}" - base_dn = "ou=Users,{_base_dn}" - object_classes = ["top"] - - fullname = AttrDef("cn") - +class Group(db.Model): id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String(), nullable=False, unique=True) diff --git a/lenticular_cloud/pki.py b/lenticular_cloud/pki.py index e595bd7..9923fd7 100644 --- a/lenticular_cloud/pki.py +++ b/lenticular_cloud/pki.py @@ -1,4 +1,4 @@ -from flask import current_app +from flask import Flask from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes @@ -39,12 +39,18 @@ def safe_filename(name): class Pki(object): - def __init__(self, pki_path: str, domain: str): + def __init__(self): + self._pki_path = Path() + self._domain = "" + + + def init_app(self, app: Flask) -> None: ''' pki_path: str base path from the pkis ''' - self._pki_path = Path(pki_path) - self._domain = domain + self._pki_path = Path(os.getcwd()) / app.config['PKI_PATH'] + self._domain = app.config['DOMAIN'] + def _init_ca(self, service: Service): @@ -322,3 +328,4 @@ class Pki(object): backend=default_backend()) return crl +pki = Pki() \ No newline at end of file diff --git a/lenticular_cloud/views/admin.py b/lenticular_cloud/views/admin.py index 3b8cb26..caa3d4a 100644 --- a/lenticular_cloud/views/admin.py +++ b/lenticular_cloud/views/admin.py @@ -11,7 +11,7 @@ from ory_hydra_client.models import OAuth2Client, GenericError from typing import Optional import logging -from ..model import db, User, UserSignUp +from ..model import db, User from .oauth2 import redirect_login, oauth2 from ..form.admin import OAuth2ClientForm from ..hydra import hydra_service @@ -50,28 +50,24 @@ async def users(): @admin_views.route('/registrations', methods=['GET']) def registrations() -> ResponseReturnValue: - users = UserSignUp.query.all() + users = User.query.filter_by(enabled=False).all() return render_template('admin/registrations.html.j2', users=users) @admin_views.route('/registration/', methods=['DELETE']) def registration_delete(registration_id) -> ResponseReturnValue: - user_data = UserSignUp.query.get(registration_id) - if user_data is None: + user = User.query.get(registration_id) + if user is None: return jsonify({}), 404 - db.session.delete(user_data) + db.session.delete(user) db.session.commit() return jsonify({}) @admin_views.route('/registration/', methods=['PUT']) def registration_accept(registration_id) -> ResponseReturnValue: - user_data = UserSignUp.query.get(registration_id) - #create user - user = User.new(user_data) - - db.session.add(user) - db.session.delete(user_data) + user = User.query.get(registration_id) + user.enabled = True db.session.commit() return jsonify({}) diff --git a/lenticular_cloud/views/api.py b/lenticular_cloud/views/api.py index 73cbe42..ed278a6 100644 --- a/lenticular_cloud/views/api.py +++ b/lenticular_cloud/views/api.py @@ -64,7 +64,7 @@ def email_login() -> ResponseReturnValue: logger.error(f'{request}') logger.error(f'{request.headers}') if not request.is_json: - return {}, 400 + return jsonify({}), 400 req_payload = request.get_json() logger.error(f'{req_payload}') password = req_payload["password"] diff --git a/lenticular_cloud/views/auth.py b/lenticular_cloud/views/auth.py index 4074b3e..2f8a6ba 100644 --- a/lenticular_cloud/views/auth.py +++ b/lenticular_cloud/views/auth.py @@ -21,7 +21,7 @@ from ory_hydra_client.api.admin import get_consent_request, accept_consent_reque from ory_hydra_client.models import AcceptLoginRequest, AcceptConsentRequest, ConsentRequestSession, GenericError, ConsentRequestSessionAccessToken, ConsentRequestSessionIdToken from typing import Optional -from ..model import db, User, SecurityUser, UserSignUp +from ..model import db, User, SecurityUser from ..form.auth import ConsentForm, LoginForm, RegistrationForm from ..auth_providers import AUTH_PROVIDER_LIST from ..hydra import hydra_service @@ -118,7 +118,7 @@ async def login() -> ResponseReturnValue: return redirect(resp.redirect_to) form = LoginForm() if form.validate_on_submit(): - user = User.query_().by_username(form.data['name']) + user = User.query.filter_by(username=form.data['name']).first() if user: session['username'] = str(user.username) else: @@ -141,7 +141,7 @@ async def login_auth() -> ResponseReturnValue: if 'username' not in session: return redirect(url_for('auth.login')) auth_forms = {} - user = User.query_().by_username(session['username']) + user = User.query.filter_by(username=session['username']).first() for auth_provider in AUTH_PROVIDER_LIST: form = auth_provider.get_form() if auth_provider.get_name() not in session['auth_providers'] and\ @@ -216,9 +216,9 @@ def sign_up(): def sign_up_submit(): form = RegistrationForm() if form.validate_on_submit(): - user = UserSignUp() + user = User() user.username = form.data['username'] - user.password = crypt.crypt(form.data['password']) + user.password_hashed = crypt.crypt(form.data['password']) user.alternative_email = form.data['alternative_email'] db.session.add(user) db.session.commit() diff --git a/lenticular_cloud/views/frontend.py b/lenticular_cloud/views/frontend.py index bc79508..6348391 100644 --- a/lenticular_cloud/views/frontend.py +++ b/lenticular_cloud/views/frontend.py @@ -31,6 +31,8 @@ from ..auth_providers import LdapAuthProvider from .auth import webauthn from .oauth2 import redirect_login, oauth2 from ..hydra import hydra_service +from ..pki import pki +from ..lenticular_services import lenticular_services frontend_views = Blueprint('frontend', __name__, url_prefix='') logger = logging.getLogger(__name__) @@ -43,8 +45,10 @@ def before_request() -> Optional[ResponseReturnValue]: logger.info('user not logged in redirect') return redirect_login() except MissingTokenError: + logger.info('MissingTokenError redirect user to login') return redirect_login() except InvalidTokenError: + logger.info('InvalidTokenError redirect user to login') return redirect_login() return None @@ -72,20 +76,20 @@ def index() -> ResponseReturnValue: @frontend_views.route('/client_cert') def client_cert() -> ResponseReturnValue: client_certs = {} - for service in current_app.lenticular_services.values(): + for service in lenticular_services.values(): client_certs[str(service.name)] = \ - current_app.pki.get_client_certs(current_user, service) + pki.get_client_certs(current_user, service) return render_template( 'frontend/client_cert.html.j2', - services=current_app.lenticular_services, + services=lenticular_services, client_certs=client_certs) @frontend_views.route('/client_cert//') def get_client_cert(service_name, serial_number) -> ResponseReturnValue: - service = current_app.lenticular_services[service_name] - cert = current_app.pki.get_client_cert( + service = lenticular_services[service_name] + cert = pki.get_client_cert( current_user, service, serial_number) return jsonify({ 'data': { @@ -96,10 +100,10 @@ def get_client_cert(service_name, serial_number) -> ResponseReturnValue: @frontend_views.route( '/client_cert//', methods=['DELETE']) def revoke_client_cert(service_name, serial_number) -> ResponseReturnValue: - service = current_app.lenticular_services[service_name] - cert = current_app.pki.get_client_cert( + service = lenticular_services[service_name] + cert = pki.get_client_cert( current_user, service, serial_number) - current_app.pki.revoke_certificate(cert) + pki.revoke_certificate(cert) return jsonify({}) @@ -107,11 +111,11 @@ def revoke_client_cert(service_name, serial_number) -> ResponseReturnValue: '/client_cert//new', methods=['GET', 'POST']) def client_cert_new(service_name) -> ResponseReturnValue: - service = current_app.lenticular_services[service_name] + service = lenticular_services[service_name] form = ClientCertForm() if form.validate_on_submit(): valid_time = int(form.data['valid_time']) * timedelta(1, 0, 0) - cert = current_app.pki.signing_publickey( + cert = pki.signing_publickey( current_user, service, form.data['publickey'], @@ -120,7 +124,7 @@ def client_cert_new(service_name) -> ResponseReturnValue: 'status': 'ok', 'data': { 'cert': cert.pem(), - 'ca_cert': current_app.pki.get_ca_cert_pem(service) + 'ca_cert': pki.get_ca_cert_pem(service) }}) elif form.is_submitted(): return jsonify({ @@ -252,7 +256,7 @@ def webauthn_register_route() -> ResponseReturnValue: return redirect(url_for('app.webauthn_list_route')) except (KeyError, ValueError) as e: - current_app.logger.exception(e) + logger.exception(e) flash('Error during registration.', 'error') return render_template('frontend/webauthn_register.html', form=form) diff --git a/lenticular_cloud/views/oauth2.py b/lenticular_cloud/views/oauth2.py index 421a72e..e1a9a63 100644 --- a/lenticular_cloud/views/oauth2.py +++ b/lenticular_cloud/views/oauth2.py @@ -1,13 +1,16 @@ from authlib.integrations.flask_client import OAuth from authlib.integrations.base_client.errors import MismatchingStateError -from flask import Flask, Blueprint, session, request, redirect, url_for +from flask import Flask, Blueprint, Response, session, request, redirect, url_for from flask_login import login_user, logout_user, current_user from flask.typing import ResponseReturnValue from flask_login import LoginManager from typing import Optional +import logging from ..model import User, SecurityUser +logger = logging.getLogger(__name__) + def fetch_token(name: str) -> Optional[dict]: token = session.get('token', None) if isinstance(token, dict): @@ -24,7 +27,10 @@ def redirect_login() -> ResponseReturnValue: logout_user() session['next_url'] = request.path redirect_uri = url_for('oauth2.authorized', _external=True) - return oauth2.custom.authorize_redirect(redirect_uri) + response = oauth2.custom.authorize_redirect(redirect_uri) + #if isinstance(response, ResponseReturnValue): + # raise RuntimeError("invalid redirect") + return response @oauth2_views.route('/authorized') @@ -32,29 +38,38 @@ def authorized() -> ResponseReturnValue: try: token = oauth2.custom.authorize_access_token() except MismatchingStateError: + logger.warning("MismatchingStateError redirect user") return redirect(url_for('oauth2.login')) if token is None: return 'bad request', 400 session['token'] = token userinfo = oauth2.custom.get('/userinfo').json() - db_user = User.query.get(str(userinfo["sub"])) - login_user(SecurityUser(db_user.username)) - + logger.info(f"userinfo `{userinfo}`") + user = User.query.get(str(userinfo["sub"])) + if user is None: + return "user not found", 404 + logger.info(f"login user `{user.username}`") + login_user(SecurityUser(user.username)) + logger.info(f"session user `{session}`") next_url = request.args.get('next_url') if next_url is None: next_url = '/' return redirect(next_url) + @oauth2_views.route('login') def login() -> ResponseReturnValue: redirect_uri = url_for('.authorized', _external=True) - return oauth2.custom.authorize_redirect(redirect_uri) + response = oauth2.custom.authorize_redirect(redirect_uri) + #if type(response) != Response: + # raise RuntimeError("invalid redirect") + return response @login_manager.user_loader def user_loader(username) -> Optional[User]: - user = User.query_().by_username(username) + user = User.query.filter_by(username=username).first() if isinstance(user, User): return user else: @@ -65,12 +80,15 @@ def request_loader(_request): pass @login_manager.unauthorized_handler -def unauthorized(): - redirect_login() +def unauthorized() -> Optional[User]: + pass -def init_login_manager(app: Flask): +def init_login_manager(app: Flask) -> None: base_url = app.config['HYDRA_PUBLIC_URL'] + if not isinstance(base_url, str): + raise RuntimeError("HYDRA_PUBLIC_URL not set") + oauth2.register( name="custom", client_id=app.config['OAUTH_ID'], diff --git a/lenticular_cloud/views/pki.py b/lenticular_cloud/views/pki.py index c872d7f..08df7ea 100644 --- a/lenticular_cloud/views/pki.py +++ b/lenticular_cloud/views/pki.py @@ -1,5 +1,7 @@ -from flask import current_app, Blueprint +from flask import Blueprint from cryptography.hazmat.primitives import serialization +from ..lenticular_services import lenticular_services +from ..pki import pki pki_views = Blueprint('pki', __name__, url_prefix='/') @@ -7,7 +9,7 @@ pki_views = Blueprint('pki', __name__, url_prefix='/') @pki_views.route('/.crl') def crl(service_name: str): - service = current_app.lenticular_services[service_name] - crl = current_app.pki.get_crl(service) + service = lenticular_services[service_name] + crl = pki.get_crl(service) return crl.public_bytes(encoding=serialization.Encoding.DER)