From 6e9e9ad557479354938257f04c8fab283b8251e8 Mon Sep 17 00:00:00 2001
From: Steven Armstrong <steven@icarus.ethz.ch>
Date: Tue, 12 Nov 2019 00:40:58 +0100
Subject: [PATCH] implement log server to capture nested logging output

Signed-off-by: Steven Armstrong <steven@icarus.ethz.ch>
---
 cdist/core/code.py |  4 +++
 cdist/install.py   | 23 +++++++++++++-
 cdist/log.py       | 77 ++++++++++++++++++++++++++++++++++++++--------
 3 files changed, 90 insertions(+), 14 deletions(-)

diff --git a/cdist/core/code.py b/cdist/core/code.py
index 1550880a..a7d9b7ca 100644
--- a/cdist/core/code.py
+++ b/cdist/core/code.py
@@ -116,6 +116,10 @@ class Code(object):
         if dry_run:
             self.env['__cdist_dry_run'] = '1'
 
+        if '__cdist_log_server_socket_to_export' in os.environ:
+            self.env['__cdist_log_server_socket'] = os.environ['__cdist_log_server_socket_to_export']
+
+
     def _run_gencode(self, cdist_object, which):
         cdist_type = cdist_object.cdist_type
         script = os.path.join(self.local.type_path,
diff --git a/cdist/install.py b/cdist/install.py
index b88ad016..3f94ca68 100644
--- a/cdist/install.py
+++ b/cdist/install.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 #
-# 2013 Steven Armstrong (steven-cdist at armstrong.cc)
+# 2013-2019 Steven Armstrong (steven-cdist at armstrong.cc)
 #
 # This file is part of cdist.
 #
@@ -20,11 +20,32 @@
 #
 #
 
+import os
+import logging
+import tempfile
+
 import cdist.config
 import cdist.core
 
 
 class Install(cdist.config.Config):
+
+    @classmethod
+    def onehost(cls, host, host_tags, host_base_path, host_dir_name, args,
+                parallel, configuration, remove_remote_files_dirs=False):
+        # Start a log server so nested `cdist config` runs have a place to
+        # send their logs to.
+        log_server_socket_dir = tempfile.mkdtemp()
+        log_server_socket = os.path.join(log_server_socket_dir, 'log-server')
+        cls._register_path_for_removal(log_server_socket_dir)
+        log = logging.getLogger(host)
+        log.debug('Starting logging server on: %s', log_server_socket)
+        os.environ['__cdist_log_server_socket_to_export'] = log_server_socket
+        cdist.log.setupLogServer(log_server_socket)
+
+        super().onehost(host, host_tags, host_base_path, host_dir_name, args,
+                parallel, configuration, remove_remote_files_dirs=False)
+
     def object_list(self):
         """Short name for object list retrieval.
         In install mode, we only care about install objects.
diff --git a/cdist/log.py b/cdist/log.py
index 790059df..94dd11e8 100644
--- a/cdist/log.py
+++ b/cdist/log.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 #
 # 2010-2013 Nico Schottelius (nico-cdist at schottelius.org)
+# 2019-2020 Steven Armstrong
 #
 # This file is part of cdist.
 #
@@ -20,9 +21,17 @@
 #
 #
 
-import logging
-import sys
+import asyncio
+import contextlib
 import datetime
+import logging
+import logging.handlers
+import os
+import pickle
+import struct
+import sys
+import threading
+import time
 
 
 # Define additional cdist logging levels.
@@ -89,20 +98,25 @@ class DefaultLog(logging.Logger):
         super().__init__(name)
         self.propagate = False
 
-        formatter = CdistFormatter(self.FORMAT)
+        if '__cdist_log_server_socket' in os.environ:
+            log_server_socket = os.environ['__cdist_log_server_socket']
+            socket_handler = logging.handlers.SocketHandler(log_server_socket, None)
+            self.addHandler(socket_handler)
+        else:
+            formatter = CdistFormatter(self.FORMAT)
 
-        stdout_handler = logging.StreamHandler(sys.stdout)
-        stdout_handler.addFilter(self.StdoutFilter())
-        stdout_handler.setLevel(logging.TRACE)
-        stdout_handler.setFormatter(formatter)
+            stdout_handler = logging.StreamHandler(sys.stdout)
+            stdout_handler.addFilter(self.StdoutFilter())
+            stdout_handler.setLevel(logging.TRACE)
+            stdout_handler.setFormatter(formatter)
 
-        stderr_handler = logging.StreamHandler(sys.stderr)
-        stderr_handler.addFilter(self.StderrFilter())
-        stderr_handler.setLevel(logging.ERROR)
-        stderr_handler.setFormatter(formatter)
+            stderr_handler = logging.StreamHandler(sys.stderr)
+            stderr_handler.addFilter(self.StderrFilter())
+            stderr_handler.setLevel(logging.ERROR)
+            stderr_handler.setFormatter(formatter)
 
-        self.addHandler(stdout_handler)
-        self.addHandler(stderr_handler)
+            self.addHandler(stdout_handler)
+            self.addHandler(stderr_handler)
 
     def verbose(self, msg, *args, **kwargs):
         self.log(logging.VERBOSE, msg, *args, **kwargs)
@@ -152,4 +166,41 @@ def setupParallelLogging():
     logging.setLoggerClass(ParallelLog)
 
 
+async def handle_log_client(reader, writer):
+    while True:
+        chunk = await reader.read(4)
+        if len(chunk) < 4:
+            return
+
+        data_size = struct.unpack('>L', chunk)[0]
+        data = bytearray(data_size)
+        view = memoryview(data)
+        data_pending = data_size
+        data = await reader.read(data_size)
+
+        obj = pickle.loads(data)
+        record = logging.makeLogRecord(obj)
+        logger = logging.getLogger(record.name)
+        logger.handle(record)
+
+
+def run_log_server(server_address):
+    # Get a new loop inside the current thread to run the log server.
+    loop = asyncio.new_event_loop()
+    loop.create_task(asyncio.start_unix_server(handle_log_client, server_address))
+    loop.run_forever()
+
+
+def setupLogServer(log_server_socket):
+    """Run a asyncio based unix socket log server in a background thread.
+    """
+    with contextlib.suppress(FileNotFoundError):
+        os.remove(log_server_socket)
+    t = threading.Thread(target=run_log_server, args=(log_server_socket,))
+    # Deamonizing the thread means we don't have to care about stoping it.
+    # It will die together with the main process.
+    t.daemon = True
+    t.start()
+
+
 setupDefaultLogging()