cdist/cdist/inventory.py

405 lines
14 KiB
Python

# -*- coding: utf-8 -*-
#
# 2016 Darko Poljak (darko.poljak at gmail.com)
#
# This file is part of cdist.
#
# cdist is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cdist is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cdist. If not, see <http://www.gnu.org/licenses/>.
#
#
import cdist
import logging
import os
import os.path
import itertools
import sys
import cdist.configuration
from cdist.hostsource import hostfile_process_line
# Directory name of an inventory db, both inside the cdist package and
# under the directory returned by cdist.home_dir().
DIST_INVENTORY_DB_NAME = "inventory"
# Absolute path of the inventory db located in the directory that contains
# cdist.__file__; used as the last-resort default inventory location.
dist_inventory_db = os.path.abspath(
    os.path.join(os.path.dirname(cdist.__file__), DIST_INVENTORY_DB_NAME)
)
def determine_default_inventory_dir(args, configuration):
    """Store the default inventory directory in ``args.inventory_dir``.

    Candidates, from highest to lowest priority:

    1. a non-empty ``inventory_dir`` entry of ``configuration``;
    2. ``DIST_INVENTORY_DB_NAME`` inside ``cdist.home_dir()`` when that
       call returns a truthy value (presumably ~/.cdist when HOME is set);
    3. the inventory directory bundled with cdist (``dist_inventory_db``).
    """
    configured = (configuration['inventory_dir']
                  if 'inventory_dir' in configuration else None)
    if configured:
        args.inventory_dir = configured
        return
    home = cdist.home_dir()
    args.inventory_dir = (os.path.join(home, DIST_INVENTORY_DB_NAME)
                          if home else dist_inventory_db)
def contains_all(big, little):
    """Tell whether every element of ``little`` also occurs in ``big``.

    Both arguments are converted to sets, so their elements must be
    hashable. An empty ``little`` is always contained.
    """
    wanted = set(little)
    available = set(big)
    return wanted <= available
def contains_any(big, little):
    """Tell whether at least one element of ``little`` occurs in ``big``.

    Stops at the first match; an empty ``little`` gives False.
    """
    return any(item in big for item in little)
def check_always_true(x, y):
    """Accept-everything filter with the same signature as contains_all.

    Both arguments are ignored; the result is always True.
    """
    return True
def rstrip_nl(s):
    """Return ``s`` with all trailing newline characters removed.

    Only ``\\n`` is stripped; other trailing whitespace is kept.
    """
    return s.rstrip("\n")
class Inventory:
    """Base class for manipulating the cdist inventory db.

    The db is the directory ``db_basedir``: each regular file in it is a
    host (the file name is the host name), and each non-empty line of a
    host file is one tag of that host.
    """

    def __init__(self, db_basedir=dist_inventory_db, configuration=None):
        self.db_basedir = db_basedir
        # A missing or empty configuration is normalized to an empty dict.
        self.configuration = configuration if configuration else {}
        self.log = logging.getLogger("inventory")
        self.init_db()

    def init_db(self):
        """Make sure the db base directory exists.

        Creates ``db_basedir`` (with parents) when it does not exist yet.

        Raises:
            cdist.Error: if ``db_basedir`` exists but is not a directory.
        """
        self.log.trace("Init db: %s", self.db_basedir)
        if not os.path.exists(self.db_basedir):
            os.makedirs(self.db_basedir, exist_ok=True)
        elif not os.path.isdir(self.db_basedir):
            raise cdist.Error(("Invalid inventory db basedir \'{}\',"
                               " must be a directory").format(self.db_basedir))

    @staticmethod
    def strlist_to_list(slist):
        """Split comma-separated ``slist`` into a list of non-empty items.

        A falsy ``slist`` (e.g. ``None`` or ``""``) gives ``[]``.
        """
        if not slist:
            return []
        return [x for x in slist.split(',') if x]

    def _input_values(self, source):
        """Yield input values from ``source``.

        ``source`` is either a file name (``'-'`` means stdin), in which
        case every line is passed through ``hostfile_process_line`` and
        each truthy result is yielded; or an iterable, whose truthy items
        are yielded unchanged. A falsy non-str ``source`` yields nothing.

        Raises:
            cdist.Error: if reading the file fails with IOError/OSError.
        """
        if isinstance(source, str):
            import fileinput
            try:
                # Pass a real 1-tuple; ``(source)`` is just ``source``.
                with fileinput.FileInput(files=(source,)) as f:
                    for x in f:
                        result = hostfile_process_line(
                            x, strip_func=rstrip_nl)
                        if result:
                            yield result
            except (IOError, OSError) as e:
                # Chain the original error so the reason is not lost.
                raise cdist.Error("Error reading from \'{}\'".format(
                    source)) from e
        elif source:
            for x in source:
                if x:
                    yield x

    def _host_path(self, host):
        """Return the path of ``host``'s file inside the db."""
        return os.path.join(self.db_basedir, host)

    def _all_hosts(self):
        """Return the names of all entries in the db directory."""
        return os.listdir(self.db_basedir)

    def _check_host(self, hostpath):
        """Return True if ``hostpath`` exists, False otherwise.

        Raises:
            cdist.Error: if ``hostpath`` exists but is not a regular file.
        """
        if not os.path.exists(hostpath):
            return False
        if not os.path.isfile(hostpath):
            raise cdist.Error(("Host path \'{}\' exists, but is not"
                               " a valid file").format(hostpath))
        return True

    def _read_host_tags(self, hostpath):
        """Return the set of non-empty lines (tags) of ``hostpath``."""
        result = set()
        with open(hostpath, "rt") as f:
            for line in f:
                tag = line.rstrip("\n")
                if tag:
                    result.add(tag)
        return result

    def _get_host_tags(self, host):
        """Return ``host``'s tags, or None if the host is not in the db."""
        hostpath = self._host_path(host)
        if self._check_host(hostpath):
            return self._read_host_tags(hostpath)
        return None

    def _write_host_tags(self, host, tags):
        """Overwrite ``host``'s file with ``tags``, one tag per line.

        Tags are written sorted so the file content does not depend on
        set iteration order.

        Returns:
            True if the host exists and was written, False otherwise.
        """
        hostpath = self._host_path(host)
        if not self._check_host(hostpath):
            return False
        with open(hostpath, "wt") as f:
            f.writelines("{}\n".format(tag) for tag in sorted(tags))
        return True

    @classmethod
    def commandline(cls, args):
        """Manipulate inventory db.

        Resolves the inventory directory into ``args.inventory_dir`` and
        runs the command selected by ``args.subcommand`` ("list",
        "add-host", "del-host", "add-tag" or "del-tag").

        Raises:
            cdist.Error: for an unknown subcommand.
        """
        log = logging.getLogger("inventory")
        # The comma-separated tag list from the command line becomes a list.
        if 'taglist' in args:
            args.taglist = cls.strlist_to_list(args.taglist)
        cfg = cdist.configuration.Configuration(args)
        configuration = cfg.get_config(section='GLOBAL')
        determine_default_inventory_dir(args, configuration)

        log.debug("Using inventory: %s", args.inventory_dir)
        log.trace("Inventory args: %s", vars(args))
        log.trace("Inventory command: %s", args.subcommand)

        if args.subcommand == "list":
            c = InventoryList(hosts=args.host, istag=args.tag,
                              hostfile=args.hostfile,
                              db_basedir=args.inventory_dir,
                              list_only_host=args.list_only_host,
                              has_all_tags=args.has_all_tags,
                              configuration=configuration)
        elif args.subcommand == "add-host":
            c = InventoryHost(hosts=args.host, hostfile=args.hostfile,
                              db_basedir=args.inventory_dir,
                              configuration=configuration)
        elif args.subcommand == "del-host":
            c = InventoryHost(hosts=args.host, hostfile=args.hostfile,
                              all=args.all, db_basedir=args.inventory_dir,
                              action="del", configuration=configuration)
        elif args.subcommand == "add-tag":
            c = InventoryTag(hosts=args.host, tags=args.taglist,
                             hostfile=args.hostfile, tagfile=args.tagfile,
                             db_basedir=args.inventory_dir,
                             configuration=configuration)
        elif args.subcommand == "del-tag":
            c = InventoryTag(hosts=args.host, tags=args.taglist,
                             hostfile=args.hostfile, tagfile=args.tagfile,
                             all=args.all, db_basedir=args.inventory_dir,
                             action="del", configuration=configuration)
        else:
            raise cdist.Error("Unknown inventory command \'{}\'".format(
                args.subcommand))
        c.run()
class InventoryList(Inventory):
    """List inventory hosts with their tags.

    Without ``hosts``/``hostfile`` every host in the db is listed. With
    them, the input values are host names to list, or, when ``istag`` is
    set, tags to filter all db hosts by (any tag matches by default, all
    tags must match when ``has_all_tags`` is set).
    """

    def __init__(self, hosts=None, istag=False, hostfile=None,
                 list_only_host=False, has_all_tags=False,
                 db_basedir=dist_inventory_db, configuration=None):
        super().__init__(db_basedir, configuration)
        self.hosts = hosts
        self.istag = istag
        self.hostfile = hostfile
        self.list_only_host = list_only_host
        self.has_all_tags = has_all_tags

    def _print(self, host, tags):
        """Print ``host`` alone, or followed by its sorted, comma-joined tags."""
        if not self.list_only_host:
            print("{} {}".format(host, ",".join(sorted(tags))))
            return
        print("{}".format(host))

    def _do_list(self, it_tags, it_hosts, check_func):
        """Yield ``(host, tags)`` for db hosts that pass ``check_func``.

        ``check_func`` is called with the host's tags and the set built
        from ``it_tags`` (empty when ``it_tags`` is None). Hosts missing
        from the db are skipped.
        """
        param_tags = set()
        if it_tags is not None:
            param_tags = set(it_tags)
            self.log.trace("param_tags: %s", param_tags)
        for host in it_hosts:
            self.log.trace("host: %s", host)
            tags = self._get_host_tags(host)
            if tags is None:
                self.log.debug("Host \'%s\' not found, skipped", host)
                continue
            self.log.trace("tags: %s", tags)
            if check_func(tags, param_tags):
                yield host, tags

    def entries(self):
        """Yield the ``(host, tags)`` pairs selected by this listing."""
        if self.hosts or self.hostfile:
            it = itertools.chain(self._input_values(self.hosts),
                                 self._input_values(self.hostfile))
            if self.istag:
                self.log.trace("Listing by tag(s)")
                it_hosts, it_tags = self._all_hosts(), it
                check_func = (contains_all if self.has_all_tags
                              else contains_any)
            else:
                self.log.trace("Listing by host(s)")
                it_hosts, it_tags, check_func = it, None, check_always_true
        else:
            self.log.trace("Listing all hosts")
            it_hosts = self._all_hosts()
            it_tags, check_func = None, check_always_true
        yield from self._do_list(it_tags, it_hosts, check_func)

    def host_entries(self):
        """Yield only the host names of ``entries()``."""
        yield from (host for host, _tags in self.entries())

    def run(self):
        """Print every selected entry."""
        for entry in self.entries():
            self._print(*entry)
class InventoryHost(Inventory):
    """Add hosts to, or delete hosts from, the inventory db."""

    def __init__(self, hosts=None, hostfile=None,
                 db_basedir=dist_inventory_db, all=False, action="add",
                 configuration=None):
        """Validate and store the arguments.

        Args:
            hosts: iterable of host names, or None.
            hostfile: file to read host names from ('-' for stdin), or None.
            db_basedir: inventory db directory.
            all: with action "del", delete every host in the db.
            action: "add" or "del".
            configuration: configuration passed on to Inventory.

        Raises:
            cdist.Error: on an unknown action or when neither ``hosts``
                nor ``hostfile`` is given.
        """
        super().__init__(db_basedir, configuration)
        self.actions = ("add", "del")
        if action not in self.actions:
            # self.actions is a tuple, so it is joined directly (a tuple
            # has no .keys()).
            raise cdist.Error("Invalid action \'{}\', valid actions are:"
                              " {}\n".format(action, ", ".join(self.actions)))
        self.action = action
        self.hosts = hosts
        self.hostfile = hostfile
        self.all = all

        if not self.hosts and not self.hostfile:
            raise cdist.Error("Host(s) missing")

    def _new_hostpath(self, hostpath):
        """Create an empty host file at ``hostpath``."""
        with open(hostpath, "w"):
            pass

    def _action(self, host):
        """Apply the configured action to a single ``host``.

        "add" creates an empty host file when none exists. In every other
        case the path must be a regular file, and "del" then removes it;
        deleting a host that is not in the db therefore raises.

        Raises:
            cdist.Error: if the host path is not a regular file.
        """
        if self.action == "add":
            self.log.debug("Adding host \'%s\'", host)
        elif self.action == "del":
            self.log.debug("Deleting host \'%s\'", host)
        hostpath = self._host_path(host)
        self.log.trace("hostpath: %s", hostpath)
        if self.action == "add" and not os.path.exists(hostpath):
            self._new_hostpath(hostpath)
        else:
            if not os.path.isfile(hostpath):
                raise cdist.Error(("Host path \'{}\' is"
                                   " not a valid file").format(hostpath))
            if self.action == "del":
                os.remove(hostpath)

    def run(self):
        """Apply the action to all db hosts ("del" with ``all``) or to the
        hosts from ``hosts`` followed by ``hostfile``."""
        if self.action == "del" and self.all:
            self.log.trace("Doing for all hosts")
            it = self._all_hosts()
        else:
            self.log.trace("Doing for specified hosts")
            it = itertools.chain(self._input_values(self.hosts),
                                 self._input_values(self.hostfile))
        for host in it:
            self._action(host)
class InventoryTag(Inventory):
    """Add tags to, or delete tags from, hosts in the inventory db."""

    def __init__(self, hosts=None, tags=None, hostfile=None, tagfile=None,
                 db_basedir=dist_inventory_db, all=False, action="add",
                 configuration=None):
        """Validate and store the arguments.

        Args:
            hosts: iterable of host names, or None.
            tags: iterable of tags, or None.
            hostfile: file to read host names from ('-' for stdin), or None.
            tagfile: file to read tags from ('-' for stdin), or None.
            db_basedir: inventory db directory.
            all: with action "del", remove every tag of the selected hosts.
            action: "add" or "del".
            configuration: configuration passed on to Inventory.

        Without ``hosts``/``hostfile`` every host in the db is processed.

        Raises:
            cdist.Error: on an unknown action, when tags are needed but
                none are given, or when both hosts and tags would be read
                from stdin.
        """
        super().__init__(db_basedir, configuration)
        self.actions = ("add", "del")
        if action not in self.actions:
            # self.actions is a tuple, so it is joined directly (a tuple
            # has no .keys()).
            raise cdist.Error("Invalid action \'{}\', valid actions are:"
                              " {}\n".format(action, ", ".join(self.actions)))
        self.action = action
        self.hosts = hosts
        self.tags = tags
        self.hostfile = hostfile
        self.tagfile = tagfile
        self.all = all

        self.allhosts = not self.hosts and not self.hostfile
        # "del" with ``all`` clears every tag and never reads input tags
        # (see run() and _action()), so tags are only required otherwise.
        clear_all = self.action == "del" and self.all
        if not clear_all and not self.tags and not self.tagfile:
            raise cdist.Error("Tag(s) missing")
        if self.hostfile == "-" and self.tagfile == "-":
            raise cdist.Error("Cannot read both, hosts and tags, from stdin")

    def _read_input_tags(self):
        """Collect tags from ``tags`` then ``tagfile`` into ``input_tags``."""
        self.input_tags = set(itertools.chain(
            self._input_values(self.tags),
            self._input_values(self.tagfile)))

    def _action(self, host):
        """Update and write back the tags of a single ``host``.

        A host missing from the db is reported on stderr and skipped.
        "del" with ``all`` clears all tags; otherwise each input tag is
        added ("add") or removed if present ("del").
        """
        host_tags = self._get_host_tags(host)
        if host_tags is None:
            print("Host \'{}\' does not exist, skipping".format(host),
                  file=sys.stderr)
            return
        self.log.trace("existing host_tags: %s", host_tags)
        if self.action == "del" and self.all:
            host_tags = set()
        else:
            for tag in self.input_tags:
                if self.action == "add":
                    self.log.debug("Adding tag \'%s\' for host \'%s\'",
                                   tag, host)
                    host_tags.add(tag)
                elif self.action == "del":
                    self.log.debug("Deleting tag \'%s\' for host \'%s\'",
                                   tag, host)
                    host_tags.discard(tag)
        self.log.trace("new host tags: %s", host_tags)
        if not self._write_host_tags(host, host_tags):
            self.log.trace("%s does not exist, skipped", host)

    def run(self):
        """Apply the action to all db hosts (no hosts given) or to the
        hosts from ``hosts`` followed by ``hostfile``."""
        if self.allhosts:
            self.log.trace("Doing for all hosts")
            it = self._all_hosts()
        else:
            self.log.trace("Doing for specified hosts")
            it = itertools.chain(self._input_values(self.hosts),
                                 self._input_values(self.hostfile))
        # Input tags are only needed unless all tags are being removed.
        if not (self.action == "del" and self.all):
            self._read_input_tags()
        for host in it:
            self._action(host)