uncloud-mravi/uncloud/hack/db.py
2020-02-09 12:12:15 +01:00

127 lines
3.8 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# 2020 Nico Schottelius (nico.schottelius at ungleich.ch)
#
# This file is part of uncloud.
#
# uncloud is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# uncloud is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with uncloud. If not, see <http://www.gnu.org/licenses/>.
#
#
import inspect
import json
import logging
from contextlib import contextmanager
from functools import wraps

import etcd3

from uncloud import UncloudException
# Module-level logger named after this module (DB.set uses it for debug output).
log = logging.getLogger(__name__)
@contextmanager
def _etcd_errors():
    """Re-raise etcd connection errors from the managed body as UncloudException."""
    try:
        yield
    except etcd3.exceptions.ConnectionFailedError as e:
        raise UncloudException(
            'Cannot connect to etcd: is etcd running and reachable? {}'.format(e)
        ) from e
    except etcd3.exceptions.ConnectionTimeoutError as e:
        raise UncloudException('etcd connection timeout. {}'.format(e)) from e


def readable_errors(func):
    """Decorate *func* so etcd connection failures surface as UncloudException.

    Both plain functions and generator functions are supported. For a
    generator function the translation also covers errors raised while the
    returned generator is being iterated. A plain wrapper would only cover
    generator creation, which never touches etcd.

    The original etcd exception is kept as ``__cause__``.
    """
    if inspect.isgeneratorfunction(func):
        @wraps(func)
        def gen_wrapper(*args, **kwargs):
            # The generator body runs lazily, so iteration itself must happen
            # inside the translation context.
            with _etcd_errors():
                return (yield from func(*args, **kwargs))
        return gen_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        with _etcd_errors():
            return func(*args, **kwargs)
    return wrapper
class DB(object):
    """Small wrapper around etcd3 clients that namespaces every key.

    Keys passed to the public methods are relative. ``realkey()`` maps them
    to ``<base_prefix><prefix>/<key>``. Only the first client in
    ``_db_clients`` is used for requests.
    """

    def __init__(self, config, prefix="/"):
        """Store the configuration and create the etcd clients.

        :param config: object exposing ``etcd_hosts`` (iterable of hosts) and
            ``etcd_args`` (dict of extra keyword arguments for etcd3.client).
        :param prefix: sub-prefix inserted between ``base_prefix`` and each key.
        :raises UncloudException: if creating the clients hits a missing file
            (presumably the etcd TLS certificate paths; see the message).
        """
        self.config = config

        # Root for everything
        self.base_prefix = '/nicohack'

        # Can be set from outside
        self.prefix = prefix

        try:
            self.connect()
        except FileNotFoundError as e:
            raise UncloudException(
                "Is the path to the etcd certs correct? {}".format(e)
            ) from e

    @readable_errors
    def connect(self):
        """(Re)create one etcd3 client per host listed in ``config.etcd_hosts``."""
        self._db_clients = [
            etcd3.client(host=endpoint, **self.config.etcd_args)
            for endpoint in self.config.etcd_hosts
        ]

    def realkey(self, key):
        """Return the absolute etcd key for the relative *key*.

        NOTE(review): with the default prefix "/" this produces
        '/nicohack//<key>' (double slash). It is left unchanged because
        altering the layout would presumably orphan keys already in etcd.
        """
        return "{}{}/{}".format(self.base_prefix,
                                self.prefix,
                                key)

    @readable_errors
    def get(self, key, as_json=False, **kwargs):
        """Return the value stored at *key*, or None if the key does not exist.

        The value is returned as etcd3 delivers it (not UTF-8 decoded) unless
        *as_json* is set, in which case it is parsed with json.loads.

        :param kwargs: forwarded to the etcd3 client's ``get``.
        """
        value, _ = self._db_clients[0].get(self.realkey(key), **kwargs)

        # etcd3 yields None for a missing key; json.loads(None) would raise
        # TypeError, so a missing key stays None even with as_json=True.
        if as_json and value is not None:
            value = json.loads(value)
        return value

    @readable_errors
    def get_prefix(self, key, as_json=False, **kwargs):
        """Yield ``(full_key, value)`` for every key under the relative *key*.

        Both the key and the value are decoded as UTF-8. ``full_key`` is the
        absolute etcd key, including ``base_prefix`` and ``prefix``. With
        *as_json* each value is additionally parsed with json.loads.

        :param kwargs: forwarded to the etcd3 client's ``get_prefix``.
        """
        for value, meta in self._db_clients[0].get_prefix(self.realkey(key), **kwargs):
            k = meta.key.decode("utf-8")
            value = value.decode("utf-8")
            if as_json:
                value = json.loads(value)
            yield (k, value)

    @readable_errors
    def set(self, key, value, as_json=False, **kwargs):
        """Store *value* at *key* and return the etcd3 client's ``put`` result.

        :param as_json: serialise *value* with json.dumps before storing it.
        :param kwargs: forwarded to the etcd3 client's ``put``.
        """
        if as_json:
            value = json.dumps(value)

        real = self.realkey(key)
        # Lazy %-formatting: the string is only built if debug logging is on.
        log.debug("Setting %s = %s", real, value)

        # FIXME: iterate over clients in case of failure ?
        return self._db_clients[0].put(real, value, **kwargs)

    @readable_errors
    def increment(self, key, **kwargs):
        """Add 1 to the integer stored at *key* and return the new value.

        A missing key counts as 0. The read-modify-write happens while
        holding an etcd lock named after the absolute key. Concurrent callers
        of this method therefore serialise instead of losing updates.

        :param kwargs: forwarded to set() and from there to etcd3's ``put``.
        :returns: the newly stored integer.
        """
        with self._db_clients[0].lock(self.realkey(key)):
            # Pass the relative key: get()/set() apply realkey() themselves.
            value = self.get(key)
            new_value = (0 if value is None else int(value)) + 1
            self.set(key, str(new_value), **kwargs)
        return new_value
if __name__ == '__main__':
    from types import SimpleNamespace

    # DB takes a config object exposing ``etcd_hosts`` and ``etcd_args``
    # (read by DB.connect); it has no ``url`` parameter.
    # etcd3.client() joins host and port itself, so hosts are bare names and
    # the port goes into etcd_args rather than being part of a URL.
    # NOTE(review): these servers likely need TLS options (ca_cert,
    # cert_key, cert_cert) in etcd_args as well — confirm before use.
    config = SimpleNamespace(
        etcd_hosts=[
            "etcd1.ungleich.ch",
            "etcd2.ungleich.ch",
            "etcd3.ungleich.ch",
        ],
        etcd_args={"port": 2379},
    )
    db = DB(config)