Source code for

import logging
import traceback
from io import StringIO
from textwrap import dedent
from threading import Lock
from typing import Optional, Dict, Union, Callable, Iterable

import sqlalchemy_utils
from vortex.DeferUtil import isMainThread

from import (
from pytmpdir.directory_ import Directory
from sqlalchemy import create_engine, Integer
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.schema import MetaData, Sequence
from vortex.DeferUtil import deferToThreadWrapWithLogger

logger = logging.getLogger(__name__)

# A zero-argument callable that returns an ORM ``Session``.
DbSessionCreator = Callable[[], Session]

# The result of an ID prefetch: an iterable of pre-allocated integer IDs,
# or ``None`` when nothing was allocated (e.g. a requested count of zero).
DeclarativeIdGen = Optional[Iterable[int]]

# Backwards-compatible alias for the original, misspelled name.
DelcarativeIdGen = DeclarativeIdGen

# A callable taking an object and an integer count and returning the IDs.
DeclarativeIdCreator = Callable[[object, int], DeclarativeIdGen]

class DbConnection:
    """SQLAlchemy database connection, migration and ORM session manager."""

    def __init__(
        self,
        dbConnectString: str,
        metadata: MetaData,
        alembicDir: str,
        dbEngineArgs: Optional[Dict[str, Union[str, int]]] = None,
        enableForeignKeys=False,
        enableCreateAll=True,
    ):
        """SQLAlchemy Database Connection

        This class takes care of migrating the database and establishing the
        database connections and ORM sessions.

        :param dbConnectString: The connection string for the DB. It is passed
            to ``sqlalchemy.create_engine`` and written into the generated
            alembic ini file.
        :param metadata: The instance of the metadata for this connection.
            This is schema qualified, e.g. ``MetaData(schema="schema_name")``.
        :param alembicDir: The absolute location of the alembic directory
            (the versions dir lives under this).
        :param dbEngineArgs: Keyword arguments passed to
            ``sqlalchemy.create_engine``. When ``None`` or empty, the defaults
            defined below are used.
        :param enableForeignKeys: After migrating, log any foreign keys that
            have no index (see ``checkForeignKeys``).
        :param enableCreateAll: If the database has no ``alembic_version``
            table, allow the migration to use ``metadata.create_all()`` and
            stamp the result as the alembic head.
        """
        self._dbConnectString = dbConnectString
        self._metadata = metadata
        self._alembicDir = alembicDir

        # Created lazily by the ``dbEngine`` / ``__ormSessionCreator`` properties.
        self._dbEngine = None
        self._ScopedSession = None

        # A fresh dict per instance, so instances never share the defaults.
        self._dbEngineArgs = (
            dbEngineArgs
            if dbEngineArgs
            else {
                "client_encoding": "utf8",
                "echo": False,
                "executemany_mode": "batch",
                "max_overflow": 50,
                "pool_recycle": 3540,
                "pool_size": 5,
                "pool_timeout": 60,
            }
        )

        # Serialises ID-sequence prefetches made through this connection.
        self._sequenceMutex = Lock()
        self._enableForeignKeys = enableForeignKeys
        self._enableCreateAll = enableCreateAll

    def closeAllSessions(self):
        """Close All Sessions

        Close all ORM sessions connected to this DB engine.
        """
        self.ormSessionCreator()  # Ensure we have a session maker and session
        # NOTE(review): ``Session.close_all`` is documented by SQLAlchemy as
        # closing every session in memory (and is deprecated since 1.3 in
        # favour of ``close_all_sessions``) -- confirm that scope is intended.
        self._ScopedSession.close_all()

    @property
    def ormSessionCreator(self) -> DbSessionCreator:
        """ORM Session Creator

        :return: A callable that returns a SQLAlchemy session scoped to the
            calling thread. Calling it on the main thread logs a warning, and
            the current stack at debug level.
        """
        sessionCreator = self.__ormSessionCreator

        def call():
            if isMainThread():
                logger.warning(
                    "This plugin has called"
                    " SQLAlchemy code in the main thread,"
                    " this is very bad"
                )
                for line in traceback.format_stack():
                    logger.debug(line.strip().replace("\n", ""))

            return sessionCreator()

        return call

    @property
    def __ormSessionCreator(self) -> DbSessionCreator:
        """Get Orm Session

        :return: A SQLAlchemy session scoped for the caller's thread.
        """
        assert self._dbConnectString
        if self._ScopedSession:
            return self._ScopedSession

        self._ScopedSession = scoped_session(sessionmaker(bind=self.dbEngine))

        return self._ScopedSession

    @property
    def dbEngine(self) -> Engine:
        """Get DB Engine

        This is not thread safe, use the ormSessionCreator to execute SQL
        statements instead, e.g. ``self.ormSessionCreator().execute(...)``.

        :return: the DB Engine used to connect to the database.
        """
        if self._dbEngine is None:
            self._dbEngine = create_engine(
                self._dbConnectString, **self._dbEngineArgs
            )

        return self._dbEngine

    def migrate(self) -> None:
        """Migrate

        Perform a database migration, upgrading to the latest schema level.

        If the ``alembic_version`` table is missing and ``enableCreateAll``
        is set, the schema is created with ``metadata.create_all()`` and
        stamped as head; otherwise alembic upgrades to head.
        """
        # Build the scoped session factory (which also creates the engine).
        # Done as a plain statement rather than inside an ``assert`` so it
        # still happens when Python runs with -O.
        self.__ormSessionCreator
        engine = self.dbEngine

        # The context manager closes the probe connection even on error.
        with engine.connect() as connection:
            isDbInitialised = engine.dialect.has_table(
                connection, "alembic_version", schema=self._metadata.schema
            )

        if isDbInitialised or not self._enableCreateAll:
            self._doMigration(engine)
        else:
            self._doCreateAll(engine)

        if self._enableForeignKeys:
            self.checkForeignKeys(engine)

    def checkForeignKeys(self, engine: Engine) -> None:
        """Check Foreign Keys

        Log any foreign keys that don't have indexes assigned to them.
        This is a performance issue.
        """
        missing = sqlalchemy_utils.functions.non_indexed_foreign_keys(
            self._metadata, engine=engine
        )

        for keys in missing.values():
            for key in keys:
                logger.warning("Missing index on ForeignKey %s", key.columns)

    @deferToThreadWrapWithLogger(logger)
    def prefetchDeclarativeIds(self, Declarative, count) -> DelcarativeIdGen:
        """Prefetch Declarative IDs

        This function prefetches a chunk of IDs from a database sequence.
        Doing this allows us to preallocate the IDs before an insert, which
        significantly speeds up:

        * Orm inserts, especially those using inheritance
        * When we need the ID to assign it to a related object that we're
          also inserting.

        :param Declarative: The SQLAlchemy declarative class.
            (The class that inherits from DeclarativeBase)
        :param count: The number of IDs to prefetch
        :return: An iterable that dispenses the new IDs
        """
        return _commonPrefetchDeclarativeIds(
            self.dbEngine, self._sequenceMutex, Declarative, count
        )

    def _runAlembicCommand(self, command, *args):
        """Run an alembic ``command`` against a freshly generated ini file.

        :param command: An ``alembic.command`` function (e.g. ``upgrade``).
        :param args: Extra positional arguments passed to ``command``.
        """
        configFile = self._writeAlembicIni()

        from alembic.config import Config

        # NOTE(review): the argument to ``Config`` was lost when this source
        # was extracted; the ini file's path (``configFile.name``) is the
        # reconstruction -- confirm against the upstream file.
        alembic_cfg = Config(configFile.name)
        command(alembic_cfg, *args)

    def _doCreateAll(self, engine):
        """Create every table from the metadata, then stamp alembic as head."""
        ensureSchemaExists(engine, self._metadata.schema)
        self._metadata.create_all(engine)

        from alembic import command

        self._runAlembicCommand(command.stamp, "head")

    def _writeAlembicIni(self):
        """Write an alembic ini file to a temp directory.

        :return: A reader for the written temp file.
        """
        cfg = """
        [alembic]
        script_location = %(alembicDir)s
        sourceless = true
        sqlalchemy.url = %(url)s

        [alembic:exclude]
        tables = spatial_ref_sys

        [logging]
        default_level = INFO
        """
        cfg = dedent(cfg)
        # Alembic reads the ini through ConfigParser interpolation, where a
        # bare "%" (e.g. a URL-encoded password) is a syntax error; "%%"
        # reads back as a single "%".
        cfg %= {
            "alembicDir": self._alembicDir.replace("%", "%%"),
            "url": self._dbConnectString.replace("%", "%%"),
        }

        tmpDir = Directory()
        iniFile = tmpDir.createTempFile()

        # NOTE(review): this ``open`` call was garbled in the extracted source;
        # ``open(write=True, asBin=False)`` is the reconstruction.
        with iniFile.open(write=True, asBin=False) as f:
            f.write(cfg)

        return iniFile.namedTempFileReader()

    def _doMigration(self, engine):
        """Upgrade the database to the alembic head revision."""
        ensureSchemaExists(engine, self._metadata.schema)

        from alembic import command

        self._runAlembicCommand(command.upgrade, "head")
def convertToCoreSqlaInsert(ormObj, Declarative):
    """Convert an ORM object into a dict suitable for a Core ``insert``.

    Each field listed by ``Declarative.tupleFieldNames()`` is read from
    ``ormObj``. A field whose value is ``None`` falls back to its column's
    ``server_default`` argument (or stays ``None`` if there is none); the
    textual defaults ``"false"`` / ``"true"`` become Python booleans.

    :param ormObj: The ORM instance to read the values from.
    :param Declarative: The declarative class of ``ormObj``; it must provide
        a ``tupleFieldNames()`` method.
    :return: A dict mapping each field name to its insert value.
    """
    insertDict = {}
    for fieldName in Declarative.tupleFieldNames():
        value = getattr(ormObj, fieldName)

        if value is None:
            Col = getattr(Declarative, fieldName)
            if isinstance(Col, InstrumentedAttribute) and Col.server_default:
                value = Col.server_default.arg
                # Boolean server defaults are declared as text. Only converting
                # here leaves real string fields holding "true"/"false" intact.
                if value == "false":
                    value = False
                elif value == "true":
                    value = True

        insertDict[fieldName] = value

    return insertDict
def pgCopyInsert(rawConn, table, inserts):
    """Bulk insert rows into a PostgreSQL table using ``COPY ... FROM STDIN``.

    Rows are serialised in COPY's text format: tab separated, ``\\N`` for
    NULL, with backslash, tab, newline and carriage return escaped.

    :param rawConn: A raw DBAPI connection whose cursor provides
        ``copy_expert`` (as psycopg2 cursors do).
    :param table: The SQLAlchemy ``Table`` to insert into; column order
        follows ``table.c``.
    :param inserts: An iterable of mappings keyed by column name.
    """
    colTypes = [c.type for c in table.c]
    # ``c.name`` is the bare column name; ``str(c)`` may contain quoting.
    columns = [c.name for c in table.c]

    def convert(index, val):
        if val is None:
            return "\\N"

        if isinstance(colTypes[index], Integer):
            # Drop any fractional part, e.g. 12.0 -> "12".
            return str(val).split(".")[0]

        return (
            str(val)
            .replace("\\", "\\\\")
            .replace("\t", "\\t")
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )

    def quoteIdent(name):
        # SQL identifier quoting: wrap in double quotes, double any inside.
        return '"%s"' % str(name).replace('"', '""')

    f = StringIO()
    for insert in inserts:
        f.write("\t".join(convert(i, insert[c]) for i, c in enumerate(columns)))
        f.write("\n")

    # COPY reads from the current position; without this nothing is copied.
    f.seek(0)

    target = quoteIdent(table.name)
    if table.schema:
        target = "%s.%s" % (quoteIdent(table.schema), target)

    # psycopg2 >= 2.9 quotes copy_from's table/column arguments itself, which
    # breaks schema-qualified names, so issue the COPY statement explicitly.
    # Text format defaults to a tab delimiter and \N for NULL.
    sql = "COPY %s (%s) FROM STDIN" % (
        target,
        ", ".join(quoteIdent(c) for c in columns),
    )

    cursor = rawConn.cursor()
    try:
        cursor.copy_expert(sql, f)
    finally:
        cursor.close()
        f.close()
def _commonPrefetchDeclarativeIds(
    engine, mutex, Declarative, count
) -> Optional[Iterable[int]]:
    """Common Prefetch Declarative IDs

    This function is used by the worker and server.

    Reserves ``count`` IDs from the ``<tablename>_id_seq`` sequence by
    advancing it past the reserved range.

    :param engine: The SQLAlchemy engine to run the sequence SQL on.
    :param mutex: A lock (context manager) held while the sequence is advanced.
    :param Declarative: The declarative class; its ``__tablename__`` and
        ``metadata.schema`` name the sequence.
    :param count: The number of IDs to reserve.
    :return: An iterator over the reserved IDs, or ``None`` if ``count``
        is falsy.
    :raises NotImplementedError: If the dialect is neither PostgreSQL nor MSSQL.
    """
    if not count:
        logger.debug("Count was zero, no range returned")
        return None

    sequence = Sequence(
        "%s_id_seq" % Declarative.__tablename__,
        schema=Declarative.metadata.schema,
    )

    # Quoted, schema-qualified sequence name. The schema is omitted when the
    # metadata has none, rather than emitting a literal "None".
    seqName = '"%s"' % sequence.name
    if sequence.schema:
        seqName = '"%s".%s' % (sequence.schema, seqName)

    conn = engine.connect()
    try:
        transaction = conn.begin()
        with mutex:
            if isPostGreSQLDialect(engine):
                sql = "SELECT setval('%(seq)s', (select nextval('%(seq)s') + %(add)s), true)"
                sql %= {"seq": seqName, "add": count}
                nextStartId = conn.execute(sql).fetchone()[0]
                startId = nextStartId - count

            elif isMssqlDialect(engine):
                startId = (
                    conn.execute("SELECT NEXT VALUE FOR %s" % seqName).fetchone()[0]
                    + 1
                )
                nextStartId = startId + count
                conn.execute(
                    "alter sequence %s restart with %s" % (seqName, nextStartId)
                )

            else:
                raise NotImplementedError()

            transaction.commit()
    finally:
        # Always return the connection, even if begin() or the SQL fails.
        conn.close()

    return iter(range(startId, nextStartId))