mirror of https://codeberg.org/dribdat/dribdat.git
112 lines
3.5 KiB
Python
112 lines
3.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Database module, including the SQLAlchemy database object and DB-related utilities."""
|
|
from .compat import basestring
|
|
from .extensions import db
|
|
from sqlalchemy.types import Integer, TypeDecorator
|
|
from decimal import Decimal
|
|
|
|
# Short aliases for frequently used SQLAlchemy helpers, so model modules can
# write ``Column(...)`` / ``relationship(...)`` instead of ``db.Column(...)``.
Column, relationship = db.Column, db.relationship
|
|
|
|
|
|
class CRUDMixin(object):
    """Mixin that adds convenience methods for CRUD (create, read, update, delete) operations."""

    @classmethod
    def create(cls, **kwargs):
        """Create a new record and save it to the database.

        :param kwargs: passed unchanged to the class constructor.
        :returns: the new instance, added to the session and committed
            (via :meth:`save`).
        """
        instance = cls(**kwargs)
        return instance.save()

    def update(self, commit=True, **kwargs):
        """Update specific fields of a record.

        Each keyword argument is assigned as an attribute on the instance.

        :param commit: when true, call :meth:`save` (session add + commit)
            after assigning the attributes.
        :returns: the instance itself, whether or not it was committed.
        """
        for attr, value in kwargs.items():
            setattr(self, attr, value)
        # Explicit branch instead of the ``commit and self.save() or self``
        # idiom, which obscures that the result is always ``self``.
        if commit:
            return self.save()
        return self

    def save(self, commit=True):
        """Add the record to the session and optionally commit.

        :param commit: when true, commit the session immediately.
        :returns: the instance itself.
        """
        db.session.add(self)
        if commit:
            db.session.commit()
        return self

    def delete(self, commit=True):
        """Remove the record from the database.

        :param commit: when true, commit the session immediately.
        :returns: ``None`` (previously this leaked the falsy ``commit`` value
            or ``None``, which carried no information).
        """
        db.session.delete(self)
        if commit:
            db.session.commit()
        return None
|
|
|
|
|
|
class Model(CRUDMixin, db.Model):
    """Abstract base for application models, combining ``db.Model`` with the
    create/update/save/delete helpers from ``CRUDMixin``."""

    # Marks this class as abstract for SQLAlchemy's declarative mapping:
    # only concrete subclasses are mapped to tables.
    __abstract__ = True
|
|
|
|
|
|
class PkModel(Model):
    """Base model class that includes CRUD convenience methods, plus adds a 'primary key' column named ``id``."""

    __abstract__ = True

    id = Column(db.Integer, primary_key=True)

    @classmethod
    def get_by_id(cls, record_id):
        """Get record by ID.

        Accepted inputs:

        * an ``int``;
        * an integral ``float`` such as ``3.0``;
        * a string of digits such as ``"42"``.

        Anything else is not queried and yields ``None``. This includes
        booleans, non-integral or non-finite floats, and digit-like strings
        that ``int()`` cannot parse.

        :param record_id: candidate primary-key value.
        :returns: the instance returned by ``db.session.get``, or ``None``
            when ``record_id`` is not a usable id.
        """
        if isinstance(record_id, bool):
            # bool subclasses int; ``True`` must not be looked up as id 1.
            return None
        if isinstance(record_id, basestring):
            # isdigit() rejects signs/whitespace, but it also accepts
            # characters such as '²' that int() refuses, so guard the parse.
            if not record_id.isdigit():
                return None
            try:
                pk = int(record_id)
            except ValueError:
                return None
        elif isinstance(record_id, int):
            pk = record_id
        elif isinstance(record_id, float) and record_id.is_integer():
            # is_integer() is False for nan/inf (which made int() raise) and
            # for values like 3.7 (which int() silently truncated to 3).
            pk = int(record_id)
        else:
            return None
        return db.session.get(cls, pk)
|
|
|
|
|
|
def reference_col(
    tablename, nullable=False, pk_name="id", foreign_key_kwargs=None, column_kwargs=None
):
    """Return a ``Column`` holding a foreign key to ``<tablename>.<pk_name>``.

    No explicit column type is passed. The column is built only from the
    ``ForeignKey``, ``nullable`` and any ``column_kwargs``.

    Usage: ::

        category_id = reference_col('category')
        category = relationship('Category', backref='categories')

    :param tablename: name of the referenced table.
    :param nullable: forwarded to ``Column`` (default ``False``).
    :param pk_name: name of the referenced column (default ``"id"``).
    :param foreign_key_kwargs: extra keyword arguments for ``db.ForeignKey``.
    :param column_kwargs: extra keyword arguments for ``Column``.
    """
    target = f"{tablename}.{pk_name}"
    foreign_key = db.ForeignKey(target, **(foreign_key_kwargs or {}))
    return Column(foreign_key, nullable=nullable, **(column_kwargs or {}))
|
|
|
|
|
|
class SqliteDecimal(TypeDecorator):
    """Store ``Decimal`` values in an SQLite INTEGER column as scaled integers.

    With ``Column(SqliteDecimal(2))``, ``Decimal('12.34')`` is stored as
    ``1234``, and ``1234`` is loaded back as ``Decimal('12.34')``.

    Adapted from code by zhukailei: https://stackoverflow.com/a/52526847
    """

    impl = Integer
    # The conversion depends only on ``scale`` (a hashable constructor
    # argument), so SQLAlchemy may include this type in statement cache keys.
    # Without this attribute SQLAlchemy 1.4+ warns and disables caching for
    # statements that use the type.
    cache_ok = True

    def __init__(self, scale):
        """:param scale: number of digits kept to the right of the decimal point."""
        super().__init__()
        self.scale = scale
        self.multiplier_int = 10 ** self.scale

    def process_bind_param(self, value, dialect):
        """Convert a Python value into the scaled integer sent to SQLite.

        ``None`` passes through unchanged. Digits beyond ``scale`` are rounded
        half-to-even rather than truncated. For example, the float ``0.29``
        (whose binary value is 0.28999...) is stored as 29, not 28, at scale 2.
        """
        if value is None:
            return None
        scaled = Decimal(value) * self.multiplier_int
        # round() on a Decimal with no ndigits returns an int using
        # ROUND_HALF_EVEN, independent of the active decimal context.
        return round(scaled)

    def process_result_value(self, value, dialect):
        """Convert the stored integer back into a ``Decimal``.

        For example, ``1234`` becomes ``Decimal('12.34')`` at scale 2.
        ``None`` passes through unchanged.
        """
        if value is None:
            return None
        return Decimal(value) / self.multiplier_int
|