diff options
Diffstat (limited to 'bitbake/lib/hashserv/sqlalchemy.py')
-rw-r--r-- | bitbake/lib/hashserv/sqlalchemy.py | 598 |
1 files changed, 598 insertions, 0 deletions
#! /usr/bin/env python3
#
# Copyright (C) 2023 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
"""SQLAlchemy (asyncio) database backend for hashserv.

Defines the table models, a DatabaseEngine that creates the engine and the
schema (including the unihashes V2 -> V3 upgrade), and a Database class that
wraps a single connection and implements the individual queries.
"""

import logging
from datetime import datetime
from . import User

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
from sqlalchemy import (
    MetaData,
    Column,
    Table,
    Text,
    Integer,
    UniqueConstraint,
    DateTime,
    Index,
    select,
    insert,
    exists,
    literal,
    and_,
    delete,
    update,
    func,
    inspect,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.exc import IntegrityError
from sqlalchemy.dialects.postgresql import insert as postgres_insert

Base = declarative_base()


class UnihashesV3(Base):
    """Maps a (method, taskhash) pair to a unihash, tagged with a GC mark."""

    __tablename__ = "unihashes_v3"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)
    # Compared against the "gc-mark" config value by the gc_* queries
    gc_mark = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v4", "method", "taskhash"),
        Index("unihash_lookup_v1", "unihash"),
    )


class OuthashesV2(Base):
    """Output hash records, unique per (method, taskhash, outhash)."""

    __tablename__ = "outhashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    created = Column(DateTime)
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        UniqueConstraint("method", "taskhash", "outhash"),
        Index("outhash_lookup_v3", "method", "outhash"),
    )


class Users(Base):
    """User accounts, unique per username."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    token = Column(Text, nullable=False)
    # Space separated list of permission names (see set_user_perms)
    permissions = Column(Text)

    __table_args__ = (UniqueConstraint("username"),)


class Config(Base):
    """Name/value configuration store (holds the "gc-mark" entry)."""

    __tablename__ = "config"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=False)
    value = Column(Text)
    __table_args__ = (
        UniqueConstraint("name"),
        Index("config_lookup", "name"),
    )


#
# Old table versions
#
DeprecatedBase = declarative_base()


class UnihashesV2(DeprecatedBase):
    """Previous unihash table layout; migrated to UnihashesV3 on startup."""

    __tablename__ = "unihashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )


class DatabaseEngine(object):
    """Holds the database URL and the async engine created from it."""

    def __init__(self, url, username=None, password=None):
        """Parse url and optionally override its username and password."""
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the engine and tables, upgrading old tables if present."""

        def check_table_exists(conn, name):
            return inspect(conn).has_table(name)

        # Render the URL explicitly with the password masked; str(URL) on
        # SQLAlchemy 1.4 includes the clear text password
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        if self.url.drivername == 'postgresql+psycopg':
            # Psycopg 3 (psycopg) driver can handle async connection pooling
            self.engine = create_async_engine(self.url, max_overflow=-1)
        else:
            self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

            if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
                self.logger.info("Upgrading Unihashes V2 -> V3...")
                # Copy every V2 row into V3 with an empty GC mark
                statement = insert(UnihashesV3).from_select(
                    ["id", "method", "unihash", "taskhash", "gc_mark"],
                    select(
                        UnihashesV2.id,
                        UnihashesV2.method,
                        UnihashesV2.unihash,
                        UnihashesV2.taskhash,
                        literal("").label("gc_mark"),
                    ),
                )
                self.logger.debug("%s", statement)
                await conn.execute(statement)

                await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
                self.logger.info("Upgrade complete")

    def connect(self, logger):
        """Return a new (not yet connected) Database using this engine."""
        return Database(self.engine, logger)


def map_row(row):
    """Convert a result row to a plain dict; None is passed through."""
    if row is None:
        return None
    return dict(**row._mapping)


def map_user(row):
    """Convert a users row to a User; None is passed through.

    The permissions column is nullable, so a NULL value is treated as an
    empty permission set instead of raising.
    """
    if row is None:
        return None
    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )


def _make_condition_statement(table, condition):
    """Build equality clauses for the columns of table named in condition.

    Keys of condition that are not columns of table, and keys whose value
    is None, are ignored. Returns a (possibly empty) list of clauses.
    """
    return [
        (c == condition[c.key])
        for c in table.__table__.columns
        if c.key in condition and condition[c.key] is not None
    ]


class Database(object):
    """A single database connection; usable as an async context manager."""

    def __init__(self, engine, logger):
        self.engine = engine
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the connection. Safe to call when already closed."""
        if self.db is not None:
            await self.db.close()
            self.db = None

    async def _execute(self, statement):
        """Log statement at debug level and execute it on the connection."""
        self.logger.debug("%s", statement)
        return await self.db.execute(statement)

    async def _set_config(self, name, value):
        """Set config entry name to value, inserting it if it is missing."""
        while True:
            result = await self._execute(
                update(Config).where(Config.name == name).values(value=value)
            )

            if result.rowcount == 0:
                self.logger.debug("Config '%s' not found. Adding it", name)
                try:
                    await self._execute(insert(Config).values(name=name, value=value))
                except IntegrityError:
                    # Race. Try again
                    continue

            break

    def _get_config_subquery(self, name, default=None):
        """Return a scalar subquery for config entry name.

        When default is not None the subquery is wrapped in COALESCE so a
        missing or NULL entry evaluates to default.
        """
        if default is not None:
            return func.coalesce(
                select(Config.value).where(Config.name == name).scalar_subquery(),
                default,
            )
        return select(Config.value).where(Config.name == name).scalar_subquery()

    async def _get_config(self, name):
        """Return the value of config entry name, or None if not present."""
        result = await self._execute(select(Config.value).where(Config.name == name))
        row = result.first()
        if row is None:
            return None
        return row.value

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row (plus unihash) for method/taskhash."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2,
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.taskhash == taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row (plus unihash) for method/outhash."""
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def unihash_exists(self, unihash):
        """Return True if any row in the unihash table has this unihash."""
        async with self.db.begin():
            result = await self._execute(
                select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
            )

            return result.first() is not None

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for method/outhash, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2)
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Return taskhash/unihash of the oldest *other* task with outhash.

        Rows whose taskhash equals the given taskhash are excluded.
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2.taskhash.label("taskhash"),
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                    OuthashesV2.taskhash != taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return the unihash/method/taskhash row for method/taskhash."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    UnihashesV3.unihash,
                    UnihashesV3.method,
                    UnihashesV3.taskhash,
                ).where(
                    UnihashesV3.method == method,
                    UnihashesV3.taskhash == taskhash,
                )
            )
            return map_row(result.first())

    async def remove(self, condition):
        """Delete unihash and outhash rows matching condition.

        Each table is only touched when condition names at least one of its
        columns. Returns the total number of rows deleted.
        """

        async def do_remove(table):
            where = _make_condition_statement(table, condition)
            if where:
                async with self.db.begin():
                    result = await self._execute(delete(table).where(*where))
                return result.rowcount

            return 0

        count = 0
        count += await do_remove(UnihashesV3)
        count += await do_remove(OuthashesV2)

        return count

    async def get_current_gc_mark(self):
        """Return the current "gc-mark" config value, or None."""
        async with self.db.begin():
            return await self._get_config("gc-mark")

    async def gc_status(self):
        """Return (rows matching the mark, rows not matching, current mark)."""
        async with self.db.begin():
            gc_mark_subquery = self._get_config_subquery("gc-mark", "")

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark == gc_mark_subquery)
            )
            keep_rows = result.scalar()

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark != gc_mark_subquery)
            )
            remove_rows = result.scalar()

            return (keep_rows, remove_rows, await self._get_config("gc-mark"))

    async def gc_mark(self, mark, condition):
        """Store mark as the current GC mark and tag rows matching condition.

        Returns the number of rows updated (0 when condition selects no
        unihash columns; the mark is still stored in that case).
        """
        async with self.db.begin():
            await self._set_config("gc-mark", mark)

            where = _make_condition_statement(UnihashesV3, condition)
            if not where:
                return 0

            result = await self._execute(
                update(UnihashesV3)
                .values(gc_mark=self._get_config_subquery("gc-mark", ""))
                .where(*where)
            )
            return result.rowcount

    async def gc_sweep(self):
        """Delete unihash rows not carrying the current mark; clear the mark.

        Returns the number of rows deleted.
        """
        async with self.db.begin():
            result = await self._execute(
                delete(UnihashesV3).where(
                    # A sneaky conditional that provides some errant use
                    # protection: If the config mark is NULL, this will not
                    # match any rows because No default is specified in the
                    # select statement
                    UnihashesV3.gc_mark
                    != self._get_config_subquery("gc-mark")
                )
            )
            await self._set_config("gc-mark", None)

            return result.rowcount

    async def clean_unused(self, oldest):
        """Delete outhash rows created before oldest that have no unihash.

        Returns the number of rows deleted.
        """
        async with self.db.begin():
            result = await self._execute(
                delete(OuthashesV2).where(
                    OuthashesV2.created < oldest,
                    ~(
                        select(UnihashesV3.id)
                        .where(
                            UnihashesV3.method == OuthashesV2.method,
                            UnihashesV3.taskhash == OuthashesV2.taskhash,
                        )
                        .limit(1)
                        .exists()
                    ),
                )
            )
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row tagged with the current GC mark.

        Returns True if a row was inserted, False if (method, taskhash) was
        already present.
        """
        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(UnihashesV3)
                .values(
                    method=method,
                    taskhash=taskhash,
                    unihash=unihash,
                    gc_mark=self._get_config_subquery("gc-mark", ""),
                )
                .on_conflict_do_nothing(index_elements=("method", "taskhash"))
            )
        else:
            statement = insert(UnihashesV3).values(
                method=method,
                taskhash=taskhash,
                unihash=unihash,
                gc_mark=self._get_config_subquery("gc-mark", ""),
            )

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from the known columns in data.

        Keys of data that are not outhash columns are dropped, and a
        non-datetime "created" value is parsed as an ISO format string.
        Returns True if a row was inserted, False if it already existed.
        """
        outhash_columns = {c.key for c in OuthashesV2.__table__.columns}

        data = {k: v for k, v in data.items() if k in outhash_columns}

        if "created" in data and not isinstance(data["created"], datetime):
            data["created"] = datetime.fromisoformat(data["created"])

        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(OuthashesV2)
                .values(**data)
                .on_conflict_do_nothing(
                    index_elements=("method", "taskhash", "outhash")
                )
            )
        else:
            statement = insert(OuthashesV2).values(**data)

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            self.logger.debug(
                "%s, %s already in outhash database", data["method"], data["outhash"]
            )
            return False

    async def _get_user(self, username):
        """Return the raw users row for username, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                    Users.token,
                ).where(
                    Users.username == username,
                )
            )
            return result.first()

    async def lookup_user_token(self, username):
        """Return (User, token) for username, or (None, None) if unknown."""
        row = await self._get_user(username)
        if not row:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the User for username, or None if unknown."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Update the token of username; True if a row was updated."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(
                    Users.username == username,
                )
                .values(
                    token=token,
                )
            )
            return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Update the permissions of username; True if a row was updated."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(Users.username == username)
                .values(permissions=" ".join(permissions))
            )
            return result.rowcount != 0

    async def get_all_users(self):
        """Return a list of User objects for every row in the users table."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                )
            )
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user; False if the insert hits an integrity error."""
        try:
            async with self.db.begin():
                await self._execute(
                    insert(Users).values(
                        username=username,
                        permissions=" ".join(permissions),
                        token=token,
                    )
                )
            return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete username; True if a row was removed."""
        async with self.db.begin():
            result = await self._execute(
                delete(Users).where(Users.username == username)
            )
            return result.rowcount != 0

    async def get_usage(self):
        """Return {table name: {"rows": row count}} for every current table."""
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                result = await self._execute(
                    statement=select(func.count()).select_from(table)
                )
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the names of all Text columns of the hash tables."""
        columns = {
            c.key
            for table in (UnihashesV3, OuthashesV2)
            for c in table.__table__.columns
            if isinstance(c.type, Text)
        }

        return list(columns)