import pyotp
import os
from os.path import join

from flask import Flask, request
from flask_restful import Resource, Api

from .schemas import (OTPSchema, CreateOTPSchema, DeleteOTPSchema,
                      ListAccountSchema)
from .config import etcd_client, env_vars
from .helper import is_valid_otp, create_admin_if_dont_exists

app = Flask(__name__)
api = Api(app)

# Runs once at import time, before any request is served.
create_admin_if_dont_exists(etcd_client)


class Verify(Resource):
    """Endpoint that validates a request body against ``OTPSchema``."""

    @staticmethod
    def post():
        """Validate the JSON body.

        Returns:
            ``({"message": "Verified"}, 200)`` when the schema accepts the
            body, ``(schema errors, 400)`` when it does not, and
            ``({"message": "No Data"}, 400)`` when the body is missing or
            empty.
        """
        data = request.json
        if not data:
            return {"message": "No Data"}, 400

        schema = OTPSchema(data)
        if not schema.is_valid():
            return schema.get_errors(), 400
        return {"message": "Verified"}, 200


class Create(Resource):
    """Endpoint that creates a new account with a fresh random seed."""

    @staticmethod
    def post():
        """Create the account named in the JSON body.

        The body is validated with ``CreateOTPSchema``; ``name`` and
        ``realm`` are then read from it. ``realm`` may be a single value or
        a list and is always stored as a list.

        Returns:
            ``(payload, 200)`` with the new credentials on success,
            ``(schema errors, 400)`` when validation fails, and
            ``({"message": "Account already exists"}, 409)`` when the key is
            already present in etcd.
        """
        data = request.json
        schema = CreateOTPSchema(data)
        if not schema.is_valid():
            return schema.get_errors(), 400

        # NOTE(review): os.path.join uses the OS path separator, while List
        # splits keys on "/"; this presumably only runs on POSIX hosts.
        _key = join(env_vars.get("BASE_PREFIX"), data["name"])

        # Previously this case fell through and implicitly returned None,
        # so callers could not tell "already exists" from a real result.
        if etcd_client.get(_key) is not None:
            return {"message": "Account already exists"}, 409

        realm = data["realm"]
        realms = realm if isinstance(realm, list) else [realm]

        _value = {"seed": pyotp.random_base32(), "realm": realms}
        etcd_client.put(_key, _value, value_in_json=True)
        return {
            "message": "Account Created",
            "credentials": {
                "name": data["name"],
                "realm": _value["realm"],
                "seed": _value["seed"],
            },
        }, 200


class Delete(Resource):
    """Endpoint that removes an account key from etcd."""

    @staticmethod
    def post():
        """Delete the account named in the JSON body.

        Returns:
            ``{"message": "Account Deleted"}`` once the delete call has been
            issued, or ``(schema errors, 400)`` when ``DeleteOTPSchema``
            rejects the body. The delete is issued without first checking
            that the key exists.
        """
        data = request.json
        schema = DeleteOTPSchema(data)
        if not schema.is_valid():
            return schema.get_errors(), 400

        _key = join(env_vars.get("BASE_PREFIX"), data["name"])
        # Uses the wrapped low-level client, unlike get/put elsewhere.
        etcd_client.client.delete(_key)
        return {"message": "Account Deleted"}


class List(Resource):
    """Endpoint that lists every account stored under the base prefix."""

    @staticmethod
    def post():
        """Return all accounts as ``{name: {"seed": ..., "realm": ...}}``.

        The account name is the last ``/``-separated component of each etcd
        key. Returns ``(schema errors, 400)`` when ``ListAccountSchema``
        rejects the body.
        """
        data = request.json
        schema = ListAccountSchema(data)
        if not schema.is_valid():
            return schema.get_errors(), 400

        result = etcd_client.get_prefix(
            env_vars.get("BASE_PREFIX"), value_in_json=True
        )
        return {
            "{}".format(entry.key.split("/")[-1]): {
                "seed": entry.value["seed"],
                "realm": entry.value["realm"],
            }
            for entry in result
        }


api.add_resource(Verify, "/verify/")
api.add_resource(Create, "/create/")
api.add_resource(Delete, "/delete/")
api.add_resource(List, "/list/")


def main():
    """Start the Flask development server on the configured ``PORT``.

    The port is read with ``env_vars.get("PORT", int)``; the second argument
    is presumably a cast applied to the raw value (confirm against the
    ``env_vars`` implementation).
    """
    # NOTE(review): debug=True enables the interactive Werkzeug debugger,
    # and host="::" listens on all IPv6 interfaces. Together they allow
    # remote code execution if the port is reachable; confirm this entry
    # point is only used for local development.
    app.run(debug=True, host="::", port=env_vars.get("PORT", int))


if __name__ == "__main__":
    # Delegate to main() so both entry points share one server configuration.
    main()