forked from ungleich-public/cdist
		
	Further improve parallel execution.
This commit is contained in:
		
					parent
					
						
							
								d1a044cc23
							
						
					
				
			
			
				commit
				
					
						23fbabe303
					
				
			
		
					 2 changed files with 11 additions and 14 deletions
				
			
		| 
						 | 
				
			
			@ -29,7 +29,7 @@ import itertools
 | 
			
		|||
import tempfile
 | 
			
		||||
import socket
 | 
			
		||||
import multiprocessing
 | 
			
		||||
from cdist.mputil import mp_pool_run
 | 
			
		||||
from cdist.mputil import mp_pool_run, mp_sig_handler
 | 
			
		||||
import atexit
 | 
			
		||||
import shutil
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -138,12 +138,8 @@ class Config(object):
 | 
			
		|||
        if args.parallel:
 | 
			
		||||
            import signal
 | 
			
		||||
 | 
			
		||||
            def sigterm_handler(signum, frame):
 | 
			
		||||
                log.trace("signal %s, killing whole process group", signum)
 | 
			
		||||
                os.killpg(os.getpgrp(), signal.SIGKILL)
 | 
			
		||||
 | 
			
		||||
            signal.signal(signal.SIGTERM, sigterm_handler)
 | 
			
		||||
            signal.signal(signal.SIGHUP, sigterm_handler)
 | 
			
		||||
            signal.signal(signal.SIGTERM, mp_sig_handler)
 | 
			
		||||
            signal.signal(signal.SIGHUP, mp_sig_handler)
 | 
			
		||||
 | 
			
		||||
        # FIXME: Refactor relict - remove later
 | 
			
		||||
        log = logging.getLogger("cdist")
 | 
			
		||||
| 
						 | 
				
			
			@ -221,8 +217,7 @@ class Config(object):
 | 
			
		|||
                cls.onehost(*process_args[0])
 | 
			
		||||
            except cdist.Error as e:
 | 
			
		||||
                failed_hosts.append(host)
 | 
			
		||||
        # Catch errors in parallel mode when joining
 | 
			
		||||
        if args.parallel:
 | 
			
		||||
        elif args.parallel:
 | 
			
		||||
            log.trace("Multiprocessing start method is {}".format(
 | 
			
		||||
                multiprocessing.get_start_method()))
 | 
			
		||||
            log.trace(("Starting multiprocessing Pool for {} "
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,6 @@
 | 
			
		|||
# -*- coding: utf-8 -*-
 | 
			
		||||
#
 | 
			
		||||
# 2016 Darko Poljak (darko.poljak at gmail.com)
 | 
			
		||||
# 2016-2017 Darko Poljak (darko.poljak at gmail.com)
 | 
			
		||||
#
 | 
			
		||||
# This file is part of cdist.
 | 
			
		||||
#
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +31,11 @@ import logging
 | 
			
		|||
log = logging.getLogger("cdist-mputil")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mp_sig_handler(signum, frame):
 | 
			
		||||
    log.trace("signal %s, SIGKILL whole process group", signum)
 | 
			
		||||
    os.killpg(os.getpgrp(), signal.SIGKILL)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mp_pool_run(func, args=None, kwds=None, jobs=multiprocessing.cpu_count()):
 | 
			
		||||
    """Run func using concurrent.futures.ProcessPoolExecutor with jobs jobs
 | 
			
		||||
       and supplied iterables of args and kwds with one entry for each
 | 
			
		||||
| 
						 | 
				
			
			@ -56,8 +61,5 @@ def mp_pool_run(func, args=None, kwds=None, jobs=multiprocessing.cpu_count()):
 | 
			
		|||
                retval.append(f.result())
 | 
			
		||||
            return retval
 | 
			
		||||
        except KeyboardInterrupt:
 | 
			
		||||
            log.trace("KeyboardInterrupt, killing process group")
 | 
			
		||||
            # When Ctrl+C in terminal then kill whole process group.
 | 
			
		||||
            # Otherwise there remain processes in sleeping state.
 | 
			
		||||
            os.killpg(os.getpgrp(), signal.SIGKILL)
 | 
			
		||||
            mp_sig_handler(signal.SIGINT, None)
 | 
			
		||||
            raise
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue