Synced with sloria 0.11.0

This commit is contained in:
Oleg Lavrovsky 2016-09-18 23:28:46 +02:00
parent 161a9c5bb7
commit 5c4ac761a7
30 changed files with 379 additions and 151 deletions

View File

@ -88,6 +88,6 @@ For a full migration command reference, run `python manage.py db --help`.
## Credits
Developed by [Oleg Lavrovsky](http://github.com/loleg) based on [Steven Loria](http://github.com/sloria/)'s [cookiecutter](http://github.com/audreyr/cookiecutter/) template.
Developed by [Oleg Lavrovsky](http://github.com/loleg) based on [Steven Loria's flask-cookiecutter](https://github.com/sloria/cookiecutter-flask).
With thanks to [Swisscom](http://swisscom.com)'s F. Wieser and M.-C. Gasser for conceptual inputs and financial support of the first release of this project.

10
autoapp.py Executable file
View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""Create an application instance."""
from flask.helpers import get_debug_flag
from dribdat.app import create_app
from dribdat.settings import DevConfig, ProdConfig
CONFIG = DevConfig if get_debug_flag() else ProdConfig
app = create_app(CONFIG)

View File

@ -0,0 +1 @@
"""Main application package."""

View File

@ -3,23 +3,23 @@
from flask import Flask, render_template
from flaskext.markdown import Markdown
from dribdat.settings import ProdConfig
from dribdat import commands, public, user, admin
from dribdat.assets import assets
from dribdat.extensions import (
hashing,
cache,
csrf_protect,
db,
login_manager,
migrate,
debug_toolbar,
opbeat,
)
from dribdat import public, user, admin
from dribdat.settings import ProdConfig
def create_app(config_object=ProdConfig):
"""An application factory, as explained here:
http://flask.pocoo.org/docs/patterns/appfactories/
"""An application factory, as explained here: http://flask.pocoo.org/docs/patterns/appfactories/.
:param config_object: The configuration object to use.
"""
@ -30,23 +30,26 @@ def create_app(config_object=ProdConfig):
register_errorhandlers(app)
register_filters(app)
register_loggers(app)
register_shellcontext(app)
register_commands(app)
return app
def register_extensions(app):
"""Register Flask extensions."""
assets.init_app(app)
hashing.init_app(app)
cache.init_app(app)
db.init_app(app)
login_manager.init_app(app)
debug_toolbar.init_app(app)
opbeat.init_app(app)
migrate.init_app(app, db)
Markdown(app)
return None
def register_blueprints(app):
"""Register Flask blueprints."""
app.register_blueprint(public.views.blueprint)
app.register_blueprint(user.views.blueprint)
app.register_blueprint(admin.views.blueprint)
@ -54,14 +57,36 @@ def register_blueprints(app):
def register_errorhandlers(app):
"""Register error handlers."""
def render_error(error):
"""Render error template."""
# If a HTTPException, pull the `code` attribute; default to 500
error_code = getattr(error, 'code', 500)
return render_template("{0}.html".format(error_code)), error_code
return render_template('{0}.html'.format(error_code)), error_code
for errcode in [401, 404, 500]:
app.errorhandler(errcode)(render_error)
return None
def register_shellcontext(app):
"""Register shell context objects."""
def shell_context():
"""Shell context objects."""
return {
'db': db,
'User': user.models.User}
app.shell_context_processor(shell_context)
def register_commands(app):
"""Register Click commands."""
app.cli.add_command(commands.test)
app.cli.add_command(commands.lint)
app.cli.add_command(commands.clean)
app.cli.add_command(commands.urls)
def register_filters(app):
@app.template_filter()
def pretty_date(value):
@ -70,6 +95,7 @@ def register_filters(app):
def format_date(value, format='%Y-%m-%d'):
return value.strftime(format)
def register_loggers(app):
# if os.environ.get('HEROKU') is not None:
# app.logger.info('hello Heroku!')

View File

@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
"""Application assets."""
from flask_assets import Bundle, Environment
css = Bundle(
@ -22,5 +23,5 @@ js = Bundle(
assets = Environment()
assets.register("js_all", js)
assets.register("css_all", css)
assets.register('js_all', js)
assets.register('css_all', css)

126
dribdat/commands.py Normal file
View File

@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
"""Click commands."""
import os
from glob import glob
from subprocess import call
import click
from flask import current_app
from flask.cli import with_appcontext
from werkzeug.exceptions import MethodNotAllowed, NotFound
HERE = os.path.abspath(os.path.dirname(__file__))
PROJECT_ROOT = os.path.join(HERE, os.pardir)
TEST_PATH = os.path.join(PROJECT_ROOT, 'tests')
@click.command()
def test():
"""Run the tests."""
import pytest
rv = pytest.main([TEST_PATH, '--verbose'])
exit(rv)
@click.command()
@click.option('-f', '--fix-imports', default=False, is_flag=True,
help='Fix imports using isort, before linting')
def lint(fix_imports):
"""Lint and check code style with flake8 and isort."""
skip = ['requirements']
root_files = glob('*.py')
root_directories = [
name for name in next(os.walk('.'))[1] if not name.startswith('.')]
files_and_directories = [
arg for arg in root_files + root_directories if arg not in skip]
def execute_tool(description, *args):
"""Execute a checking tool with its arguments."""
command_line = list(args) + files_and_directories
click.echo('{}: {}'.format(description, ' '.join(command_line)))
rv = call(command_line)
if rv != 0:
exit(rv)
if fix_imports:
execute_tool('Fixing import order', 'isort', '-rc')
execute_tool('Checking code style', 'flake8')
@click.command()
def clean():
"""Remove *.pyc and *.pyo files recursively starting at current directory.
Borrowed from Flask-Script, converted to use Click.
"""
for dirpath, dirnames, filenames in os.walk('.'):
for filename in filenames:
if filename.endswith('.pyc') or filename.endswith('.pyo'):
full_pathname = os.path.join(dirpath, filename)
click.echo('Removing {}'.format(full_pathname))
os.remove(full_pathname)
@click.command()
@click.option('--url', default=None,
help='Url to test (ex. /static/image.png)')
@click.option('--order', default='rule',
help='Property on Rule to order by (default: rule)')
@with_appcontext
def urls(url, order):
"""Display all of the url matching routes for the project.
Borrowed from Flask-Script, converted to use Click.
"""
rows = []
column_length = 0
column_headers = ('Rule', 'Endpoint', 'Arguments')
if url:
try:
rule, arguments = (
current_app.url_map
.bind('localhost')
.match(url, return_rule=True))
rows.append((rule.rule, rule.endpoint, arguments))
column_length = 3
except (NotFound, MethodNotAllowed) as e:
rows.append(('<{}>'.format(e), None, None))
column_length = 1
else:
rules = sorted(
current_app.url_map.iter_rules(),
key=lambda rule: getattr(rule, order))
for rule in rules:
rows.append((rule.rule, rule.endpoint, None))
column_length = 2
str_template = ''
table_width = 0
if column_length >= 1:
max_rule_length = max(len(r[0]) for r in rows)
max_rule_length = max_rule_length if max_rule_length > 4 else 4
str_template += '{:' + str(max_rule_length) + '}'
table_width += max_rule_length
if column_length >= 2:
max_endpoint_length = max(len(str(r[1])) for r in rows)
# max_endpoint_length = max(rows, key=len)
max_endpoint_length = (
max_endpoint_length if max_endpoint_length > 8 else 8)
str_template += ' {:' + str(max_endpoint_length) + '}'
table_width += 2 + max_endpoint_length
if column_length >= 3:
max_arguments_length = max(len(str(r[2])) for r in rows)
max_arguments_length = (
max_arguments_length if max_arguments_length > 9 else 9)
str_template += ' {:' + str(max_arguments_length) + '}'
table_width += 2 + max_arguments_length
click.echo(str_template.format(*column_headers[:column_length]))
click.echo('-' * table_width)
for row in rows:
click.echo(str_template.format(*row[:column_length]))

View File

@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
"""Python 2/3 compatibility module."""
import sys
PY2 = int(sys.version[0]) == 2
if PY2:
text_type = unicode
text_type = unicode # noqa
binary_type = str
string_types = (str, unicode)
unicode = unicode
basestring = basestring
string_types = (str, unicode) # noqa
unicode = unicode # noqa
basestring = basestring # noqa
else:
text_type = str
binary_type = bytes

View File

@ -1,11 +1,9 @@
# -*- coding: utf-8 -*-
"""Database module, including the SQLAlchemy database object and DB-related
utilities.
"""
"""Database module, including the SQLAlchemy database object and DB-related utilities."""
from sqlalchemy.orm import relationship
from .extensions import db
from .compat import basestring
from .extensions import db
# Alias common SQLAlchemy names
Column = db.Column
@ -13,9 +11,7 @@ relationship = relationship
class CRUDMixin(object):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)
operations.
"""
"""Mixin that adds convenience methods for CRUD (create, read, update, delete) operations."""
@classmethod
def create(cls, **kwargs):
@ -25,7 +21,7 @@ class CRUDMixin(object):
def update(self, commit=True, **kwargs):
"""Update specific fields of a record."""
for attr, value in kwargs.iteritems():
for attr, value in kwargs.items():
setattr(self, attr, value)
return commit and self.save() or self
@ -41,39 +37,41 @@ class CRUDMixin(object):
db.session.delete(self)
return commit and db.session.commit()
class Model(CRUDMixin, db.Model):
"""Base model class that includes CRUD convenience methods."""
__abstract__ = True
# From Mike Bayer's "Building the app" talk
# https://speakerdeck.com/zzzeek/building-the-app
class SurrogatePK(object):
"""A mixin that adds a surrogate integer 'primary key' column named
``id`` to any declarative-mapped class.
"""
"""A mixin that adds a surrogate integer 'primary key' column named ``id`` to any declarative-mapped class."""
__table_args__ = {'extend_existing': True}
id = db.Column(db.Integer, primary_key=True)
@classmethod
def get_by_id(cls, id):
def get_by_id(cls, record_id):
"""Get record by ID."""
if any(
(isinstance(id, basestring) and id.isdigit(),
isinstance(id, (int, float))),
(isinstance(record_id, basestring) and record_id.isdigit(),
isinstance(record_id, (int, float))),
):
return cls.query.get(int(id))
return cls.query.get(int(record_id))
return None
def ReferenceCol(tablename, nullable=False, pk_name='id', **kwargs):
def reference_col(tablename, nullable=False, pk_name='id', **kwargs):
"""Column that adds primary key foreign key reference.
Usage: ::
category_id = ReferenceCol('category')
category_id = reference_col('category')
category = relationship('Category', backref='categories')
"""
return db.Column(
db.ForeignKey("{0}.{1}".format(tablename, pk_name)),
db.ForeignKey('{0}.{1}'.format(tablename, pk_name)),
nullable=nullable, **kwargs)

View File

@ -1,25 +1,17 @@
# -*- coding: utf-8 -*-
"""Extensions module. Each extension is initialized in the app factory located
in app.py
"""
"""Extensions module. Each extension is initialized in the app factory located in app.py."""
from flask.ext.hashing import Hashing
hashing = Hashing()
from flask_login import LoginManager
login_manager = LoginManager()
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy()
from flask_migrate import Migrate
migrate = Migrate()
from flask_cache import Cache
cache = Cache()
from flask_caching import Cache
from flask_debugtoolbar import DebugToolbarExtension
debug_toolbar = DebugToolbarExtension()
from flask_login import LoginManager
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from flask_wtf.csrf import CsrfProtect
from opbeat.contrib.flask import Opbeat
opbeat = Opbeat()
hashing = Hashing()
csrf_protect = CsrfProtect()
login_manager = LoginManager()
db = SQLAlchemy()
migrate = Migrate()
cache = Cache()
debug_toolbar = DebugToolbarExtension()

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""The public module, including the homepage and user auth."""
from . import views
from . import views # noqa

View File

@ -12,14 +12,18 @@ from dribdat.user.models import User
from wtforms.validators import AnyOf, required, length
class LoginForm(Form):
username = TextField('Username', validators=[DataRequired()])
"""Login form."""
username = StringField('Username', validators=[DataRequired()])
password = PasswordField('Password', validators=[DataRequired()])
def __init__(self, *args, **kwargs):
"""Create instance."""
super(LoginForm, self).__init__(*args, **kwargs)
self.user = None
def validate(self):
"""Validate the form."""
initial_validation = super(LoginForm, self).validate()
if not initial_validation:
return False

View File

@ -22,8 +22,9 @@ def get_current_event():
return event
@login_manager.user_loader
def load_user(id):
return User.get_by_id(int(id))
def load_user(user_id):
"""Load user by ID."""
return User.get_by_id(int(user_id))
@blueprint.route("/")
def home():
@ -50,6 +51,7 @@ def login():
@blueprint.route('/logout/')
@login_required
def logout():
"""Logout."""
logout_user()
flash('You are logged out.', 'info')
return redirect(url_for('public.home'))
@ -57,6 +59,7 @@ def logout():
@blueprint.route("/register/", methods=['GET', 'POST'])
def register():
"""Register new user."""
form = RegisterForm(request.form, csrf_enabled=False)
if request.args.get('name'):
form.username.data = request.args.get('name')

View File

@ -1,22 +1,26 @@
# -*- coding: utf-8 -*-
"""Application configuration."""
import os
os_env = os.environ
class Config(object):
SECRET_KEY = os_env.get('DRIBDAT_SECRET', 'secret-key') # TODO: Change me
"""Base configuration."""
SECRET_KEY = os.environ.get('DRIBDAT_SECRET', 'jaNo-Ol771--yS6se87-2y')
APP_DIR = os.path.abspath(os.path.dirname(__file__)) # This directory
PROJECT_ROOT = os.path.abspath(os.path.join(APP_DIR, os.pardir))
ASSETS_DEBUG = False
DEBUG_TB_ENABLED = False # Disable Debug toolbar
DEBUG_TB_INTERCEPT_REDIRECTS = False
CACHE_TYPE = 'simple' # Can be "memcached", "redis", etc.
SQLALCHEMY_TRACK_MODIFICATIONS = False
CACHE_TYPE = 'simple' # Can be "memcached", "redis", etc.
SERVER_NAME = os_env.get('SERVER_URL', 'localhost:5000')
class ProdConfig(Config):
"""Production configuration."""
ENV = 'prod'
DEBUG = False
SQLALCHEMY_DATABASE_URI = os_env.get('DRIBDAT_DB', 'postgresql://localhost/example')
@ -25,6 +29,7 @@ class ProdConfig(Config):
class DevConfig(Config):
"""Development configuration."""
ENV = 'dev'
DEBUG = True
DB_NAME = 'dev.db'
@ -37,6 +42,8 @@ class DevConfig(Config):
class TestConfig(Config):
"""Test configuration."""
TESTING = True
DEBUG = True
SQLALCHEMY_DATABASE_URI = 'sqlite://'

View File

@ -24,9 +24,6 @@
contact the organisers of the events listed here and they
should be able to help you - or just fork our open source project
and set it up to your liking.</p>
<p>Developed by <a href="http://github.com/loleg">Oleg Lavrovsky</a>
based on <a href="http://github.com/sloria/">Steven Loria</a>'s
<a href="http://github.com/audreyr/cookiecutter/">cookiecutter</a> template.</p>
<p>
<a href="https://github.com/loleg/dribdat" class="btn btn-primary">Visit on GitHub &raquo;</a>
<a href="https://github.com/loleg/dribdat/issues" class="btn btn-defalt">Issues</a>

View File

@ -8,18 +8,19 @@
<p>Create an account here to be able to submit projects.</p>
<br/>
<form id="registerForm" class="form form-register" method="POST" action="" role="form">
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{{ form.hidden_tag() }}
<div class="form-group">
{{form.username.label}}
{{form.username(placeholder="", class_="form-control")}}
{{form.username(placeholder="Username", class_="form-control")}}
</div>
<div class="form-group">
{{form.email.label}}
{{form.email(placeholder="name@domain.tld", class_="form-control")}}
{{form.email(placeholder="Email", class_="form-control")}}
</div>
<div class="form-group">
{{form.teamname.label}}
{{form.teamname(placeholder="Dream Team", class_="form-control")}}
{{form.teamname(placeholder="My Team Name", class_="form-control")}}
</div>
<div class="form-group">
{{form.webpage_url.label}}
@ -27,11 +28,11 @@
</div>
<div class="form-group">
{{form.password.label}}
{{form.password(placeholder="**********", class_="form-control")}}
{{form.password(placeholder="Password", class_="form-control")}}
</div>
<div class="form-group">
{{form.confirm.label}}
{{form.confirm(placeholder="**********", class_="form-control")}}
{{form.confirm(placeholder="Password (again)", class_="form-control")}}
</div>
<p><input class="btn btn-default btn-submit" type="submit" value="Register"></p>
</form>

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""The user module."""
from . import views
from . import views # noqa
from .constants import USER_ROLE, ADMIN, USER, USER_STATUS, ACTIVE

View File

@ -1,37 +1,42 @@
# -*- coding: utf-8 -*-
from flask_wtf import Form
from wtforms import TextField, PasswordField
from wtforms import TextField, PasswordField, StringField
from wtforms.validators import DataRequired, Email, EqualTo, Length
from .models import User
class RegisterForm(Form):
username = TextField('Username',
validators=[DataRequired(), Length(min=3, max=25)])
email = TextField('Email',
validators=[DataRequired(), Email(), Length(min=6, max=40)])
"""Register form."""
username = StringField('Username',
validators=[DataRequired(), Length(min=3, max=25)])
email = StringField('Email',
validators=[DataRequired(), Email(), Length(min=6, max=40)])
password = PasswordField('Password',
validators=[DataRequired(), Length(min=6, max=40)])
confirm = PasswordField('Verify password',
[DataRequired(), EqualTo('password', message='Passwords must match')])
# DRIBDAT fields
teamname = TextField(u'Team name')
webpage_url = TextField(u'Team web link')
password = PasswordField('Password',
validators=[DataRequired(), Length(min=6, max=40)])
confirm = PasswordField('Verify password',
[DataRequired(), EqualTo('password', message='Passwords must match')])
def __init__(self, *args, **kwargs):
"""Create instance."""
super(RegisterForm, self).__init__(*args, **kwargs)
self.user = None
def validate(self):
"""Validate the form."""
initial_validation = super(RegisterForm, self).validate()
if not initial_validation:
return False
user = User.query.filter_by(username=self.username.data).first()
if user:
self.username.errors.append("Username already registered")
self.username.errors.append('Username already registered')
return False
user = User.query.filter_by(email=self.email.data).first()
if user:
self.email.errors.append("Email already registered")
self.email.errors.append('Email already registered')
return False
return True

View File

@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
"""User models."""
import datetime as dt
from flask_login import UserMixin
@ -19,41 +20,49 @@ from sqlalchemy import or_
class Role(SurrogatePK, Model):
__tablename__ = 'roles'
name = Column(db.String(80), unique=True, nullable=False)
user_id = ReferenceCol('users', nullable=True)
user_id = reference_col('users', nullable=True)
user = relationship('User', backref='roles')
def __init__(self, name, **kwargs):
"""Create instance."""
db.Model.__init__(self, name=name, **kwargs)
def __repr__(self):
"""Represent instance as a unique string."""
return '<Role({name})>'.format(name=self.name)
class User(UserMixin, SurrogatePK, Model):
"""A user of the app."""
__tablename__ = 'users'
username = Column(db.String(80), unique=True, nullable=False)
email = Column(db.String(80), unique=True, nullable=False)
teamname = Column(db.String(128), nullable=True)
webpage_url = Column(db.String(128), nullable=True)
#: The hashed password
password = Column(db.String(128), nullable=True)
created_at = Column(db.DateTime, nullable=False, default=dt.datetime.utcnow)
active = Column(db.Boolean(), default=False)
is_admin = Column(db.Boolean(), default=False)
def __init__(self, username=None, email=None, password=None, **kwargs):
"""Create instance."""
if username and email:
db.Model.__init__(self, username=username, email=email, **kwargs)
if password:
self.set_password(password)
def set_password(self, password):
"""Set password."""
self.password = hashing.hash_value(password)
def check_password(self, value):
"""Check password."""
return hashing.check_value(self.password, value)
def __repr__(self):
"""Represent instance as a unique string."""
return '<User({username!r})>'.format(username=self.username)

View File

@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-
"""User views."""
from flask import Blueprint, render_template
from flask_login import login_required
blueprint = Blueprint("user", __name__, url_prefix='/users',
static_folder="../static")
blueprint = Blueprint('user', __name__, url_prefix='/users', static_folder='../static')
@blueprint.route("/")
@blueprint.route('/')
@login_required
def members():
return render_template("users/members.html")
"""List members."""
return render_template('users/members.html')

View File

@ -3,9 +3,8 @@
from flask import flash
def flash_errors(form, category="warning"):
def flash_errors(form, category='warning'):
"""Flash all errors for a form."""
for field, errors in form.errors.items():
for error in errors:
flash("{0} - {1}"
.format(getattr(form, field).label.text, error), category)
flash('{0} - {1}'.format(getattr(form, field).label.text, error), category)

View File

@ -1,5 +1,3 @@
# Included because many Paas's require a requirements.txt file in the project root
# Just installs the production requirements.
-r requirements/prod.txt
psycopg2==2.6.1

View File

@ -2,9 +2,16 @@
-r prod.txt
# Testing
pytest>=2.9.1
webtest
pytest==2.9.0
WebTest==2.0.20
factory-boy==2.6.1
# Management script
Flask-Script
# Lint and code style
flake8==2.5.4
flake8-blind-except==0.1.0
flake8-debugger==1.4.0
flake8-docstrings==0.2.5
flake8-isort==1.2
flake8-quotes==0.2.4
isort==4.2.2
pep8-naming==0.3.3

View File

@ -1,45 +1,48 @@
# DRIBDAT requirements
requests>=2.9.1
Flask-Markdown>=0.3
pyquery>=1.2.11
# Everything needed in production
setuptools==20.2.2
wheel==0.29.0
# Flask
Flask==0.10.1
Flask==0.11.1
MarkupSafe==0.23
Werkzeug==0.11.5
Werkzeug==0.11.4
Jinja2==2.8
itsdangerous==0.24
click>=5.0
# Database
Flask-SQLAlchemy==2.1
psycopg2==2.6.1
SQLAlchemy==1.0.12
# Migrations
Flask-Migrate==1.8.0
Flask-Migrate==2.0.0
# Forms
Flask-WTF==0.12
WTForms==2.1
# Deployment
gunicorn>=19.4.5
gunicorn>=19.1.1
# Assets
Flask-Assets==0.11
Flask-Assets==0.12
cssmin>=0.2.0
jsmin>=2.2.1
jsmin>=2.0.11
# Auth
Flask-Login==0.3.2
Flask-Hashing==1.1
Flask-Bcrypt==0.7.1
# Caching
Flask-Cache>=0.13.1
Flask-Caching>=1.0.0
# Debug toolbar
Flask-DebugToolbar==0.10.0
# Monitoring
opbeat==3.3
# Other
requests>=2.9.1
Flask-Markdown>=0.3
pyquery>=1.2.11

View File

@ -0,0 +1 @@
"""Tests for the app."""

View File

@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""Defines fixtures available to all tests."""
import os
import pytest
from webtest import TestApp
from dribdat.settings import TestConfig
from dribdat.app import create_app
from dribdat.database import db as _db
from dribdat.settings import TestConfig
from .factories import UserFactory
@pytest.yield_fixture(scope='function')
def app():
"""An application for the tests."""
_app = create_app(TestConfig)
ctx = _app.test_request_context()
ctx.push()
@ -31,17 +31,21 @@ def testapp(app):
@pytest.yield_fixture(scope='function')
def db(app):
"""A database for the tests."""
_db.app = app
with app.app_context():
_db.create_all()
yield _db
# Explicitly close DB connection
_db.session.close()
_db.drop_all()
@pytest.fixture
def user(db):
"""A user for the tests."""
user = UserFactory(password='myprecious')
db.session.commit()
return user

View File

@ -1,23 +1,31 @@
# -*- coding: utf-8 -*-
from factory import Sequence, PostGenerationMethodCall
"""Factories to help in tests."""
from factory import PostGenerationMethodCall, Sequence
from factory.alchemy import SQLAlchemyModelFactory
from dribdat.user.models import User
from dribdat.database import db
from dribdat.user.models import User
class BaseFactory(SQLAlchemyModelFactory):
"""Base factory."""
class Meta:
"""Factory configuration."""
abstract = True
sqlalchemy_session = db.session
class UserFactory(BaseFactory):
username = Sequence(lambda n: "user{0}".format(n))
email = Sequence(lambda n: "user{0}@example.com".format(n))
"""User factory."""
username = Sequence(lambda n: 'user{0}'.format(n))
email = Sequence(lambda n: 'user{0}@example.com'.format(n))
password = PostGenerationMethodCall('set_password', 'example')
active = True
class Meta:
"""Factory configuration."""
model = User

View File

@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-
"""Test configs."""
from dribdat.app import create_app
from dribdat.settings import ProdConfig, DevConfig
from dribdat.settings import DevConfig, ProdConfig
def test_production_config():
"""Production config."""
app = create_app(ProdConfig)
assert app.config['ENV'] == 'prod'
assert app.config['DEBUG'] is False
@ -12,6 +14,7 @@ def test_production_config():
def test_dev_config():
"""Development config."""
app = create_app(DevConfig)
assert app.config['ENV'] == 'dev'
assert app.config['DEBUG'] is True

View File

@ -1,38 +1,41 @@
# -*- coding: utf-8 -*-
import pytest
"""Test forms."""
from dribdat.public.forms import LoginForm
from dribdat.user.forms import RegisterForm
from .factories import UserFactory
class TestRegisterForm:
"""Register form."""
def test_validate_user_already_registered(self, user):
# Enters username that is already registered
"""Enter username that is already registered."""
form = RegisterForm(username=user.username, email='foo@bar.com',
password='example', confirm='example')
password='example', confirm='example')
assert form.validate() is False
assert 'Username already registered' in form.username.errors
def test_validate_email_already_registered(self, user):
# enters email that is already registered
"""Enter email that is already registered."""
form = RegisterForm(username='unique', email=user.email,
password='example', confirm='example')
password='example', confirm='example')
assert form.validate() is False
assert 'Email already registered' in form.email.errors
def test_validate_success(self, db):
"""Register with success."""
form = RegisterForm(username='newusername', email='new@test.test',
password='example', confirm='example')
password='example', confirm='example')
assert form.validate() is True
class TestLoginForm:
"""Login form."""
def test_validate_success(self, user):
"""Login successful."""
user.set_password('example')
user.save()
form = LoginForm(username=user.username, password='example')
@ -40,12 +43,14 @@ class TestLoginForm:
assert form.user == user
def test_validate_unknown_username(self, db):
"""Unknown username."""
form = LoginForm(username='unknown', password='example')
assert form.validate() is False
assert 'Unknown username' in form.username.errors
assert form.user is None
def test_validate_invalid_password(self, user):
"""Invalid password."""
user.set_password('example')
user.save()
form = LoginForm(username=user.username, password='wrongpassword')
@ -53,6 +58,7 @@ class TestLoginForm:
assert 'Invalid password' in form.password.errors
def test_validate_inactive_user(self, user):
"""Inactive user."""
user.active = False
user.set_password('example')
user.save()

View File

@ -3,19 +3,20 @@
See: http://webtest.readthedocs.org/
"""
import pytest
from flask import url_for
from dribdat.user.models import User
from .factories import UserFactory
class TestLoggingIn:
"""Login."""
def test_can_log_in_returns_200(self, user, testapp):
"""Login successful."""
# Goes to homepage
res = testapp.get("/")
res = testapp.get('/')
# Fills out login form in navbar
form = res.forms['loginForm']
form['username'] = user.username
@ -25,7 +26,8 @@ class TestLoggingIn:
assert res.status_code == 200
def test_sees_alert_on_log_out(self, user, testapp):
res = testapp.get("/")
"""Show alert on logout."""
res = testapp.get('/')
# Fills out login form in navbar
form = res.forms['loginForm']
form['username'] = user.username
@ -37,8 +39,9 @@ class TestLoggingIn:
assert 'You are logged out.' in res
def test_sees_error_message_if_password_is_incorrect(self, user, testapp):
"""Show error if password is incorrect."""
# Goes to homepage
res = testapp.get("/")
res = testapp.get('/')
# Fills out login form, password incorrect
form = res.forms['loginForm']
form['username'] = user.username
@ -46,11 +49,12 @@ class TestLoggingIn:
# Submits
res = form.submit()
# sees error
assert "Invalid password" in res
assert 'Invalid password' in res
def test_sees_error_message_if_username_doesnt_exist(self, user, testapp):
"""Show error if username doesn't exist."""
# Goes to homepage
res = testapp.get("/")
res = testapp.get('/')
# Fills out login form, password incorrect
form = res.forms['loginForm']
form['username'] = 'unknown'
@ -58,19 +62,21 @@ class TestLoggingIn:
# Submits
res = form.submit()
# sees error
assert "Unknown user" in res
assert 'Unknown user' in res
class TestRegistering:
"""Register a user."""
def test_can_register(self, user, testapp):
"""Register a new user."""
old_count = len(User.query.all())
# Goes to homepage
res = testapp.get("/")
res = testapp.get('/')
# Clicks Create Account button
res = res.click("Create account")
res = res.click('Create account')
# Fills out the form
form = res.forms["registerForm"]
form = res.forms['registerForm']
form['username'] = 'foobar'
form['email'] = 'foo@bar.com'
form['password'] = 'secret'
@ -82,10 +88,11 @@ class TestRegistering:
assert len(User.query.all()) == old_count + 1
def test_sees_error_message_if_passwords_dont_match(self, user, testapp):
"""Show error if passwords don't match."""
# Goes to registration page
res = testapp.get(url_for("public.register"))
res = testapp.get(url_for('public.register'))
# Fills out form, but passwords don't match
form = res.forms["registerForm"]
form = res.forms['registerForm']
form['username'] = 'foobar'
form['email'] = 'foo@bar.com'
form['password'] = 'secret'
@ -93,15 +100,16 @@ class TestRegistering:
# Submits
res = form.submit()
# sees error message
assert "Passwords must match" in res
assert 'Passwords must match' in res
def test_sees_error_message_if_user_already_registered(self, user, testapp):
"""Show error if user already registered."""
user = UserFactory(active=True) # A registered user
user.save()
# Goes to registration page
res = testapp.get(url_for("public.register"))
res = testapp.get(url_for('public.register'))
# Fills out form, but username is already registered
form = res.forms["registerForm"]
form = res.forms['registerForm']
form['username'] = user.username
form['email'] = 'foo@bar.com'
form['password'] = 'secret'
@ -109,4 +117,4 @@ class TestRegistering:
# Submits
res = form.submit()
# sees error
assert "Username already registered" in res
assert 'Username already registered' in res

View File

@ -4,14 +4,17 @@ import datetime as dt
import pytest
from dribdat.user.models import User, Role
from dribdat.user.models import Role, User
from .factories import UserFactory
@pytest.mark.usefixtures('db')
class TestUser:
"""User tests."""
def test_get_by_id(self):
"""Get user by ID."""
user = User('foo', 'foo@bar.com')
user.save()
@ -19,18 +22,21 @@ class TestUser:
assert retrieved == user
def test_created_at_defaults_to_datetime(self):
"""Test creation date."""
user = User(username='foo', email='foo@bar.com')
user.save()
assert bool(user.created_at)
assert isinstance(user.created_at, dt.datetime)
def test_password_is_nullable(self):
"""Test null password."""
user = User(username='foo', email='foo@bar.com')
user.save()
assert user.password is None
def test_factory(self, db):
user = UserFactory(password="myprecious")
"""Test user factory."""
user = UserFactory(password='myprecious')
db.session.commit()
assert bool(user.username)
assert bool(user.email)
@ -40,15 +46,22 @@ class TestUser:
assert user.check_password('myprecious')
def test_check_password(self):
user = User.create(username="foo", email="foo@bar.com",
password="foobarbaz123")
"""Check password."""
user = User.create(username='foo', email='foo@bar.com',
password='foobarbaz123')
assert user.check_password('foobarbaz123') is True
assert user.check_password("barfoobaz") is False
assert user.check_password('barfoobaz') is False
def test_full_name(self):
"""User full name."""
user = UserFactory(first_name='Foo', last_name='Bar')
assert user.full_name == 'Foo Bar'
def test_roles(self):
"""Add a role to a user."""
role = Role(name='admin')
role.save()
u = UserFactory()
u.roles.append(role)
u.save()
assert role in u.roles
user = UserFactory()
user.roles.append(role)
user.save()
assert role in user.roles