forked from ungleich-public/cdist
		
	Further improve parallel execution.
This commit is contained in:
		
					parent
					
						
							
								d1a044cc23
							
						
					
				
			
			
				commit
				
					
						23fbabe303
					
				
			
		
					 2 changed files with 11 additions and 14 deletions
				
			
		| 
						 | 
					@ -29,7 +29,7 @@ import itertools
 | 
				
			||||||
import tempfile
 | 
					import tempfile
 | 
				
			||||||
import socket
 | 
					import socket
 | 
				
			||||||
import multiprocessing
 | 
					import multiprocessing
 | 
				
			||||||
from cdist.mputil import mp_pool_run
 | 
					from cdist.mputil import mp_pool_run, mp_sig_handler
 | 
				
			||||||
import atexit
 | 
					import atexit
 | 
				
			||||||
import shutil
 | 
					import shutil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -138,12 +138,8 @@ class Config(object):
 | 
				
			||||||
        if args.parallel:
 | 
					        if args.parallel:
 | 
				
			||||||
            import signal
 | 
					            import signal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            def sigterm_handler(signum, frame):
 | 
					            signal.signal(signal.SIGTERM, mp_sig_handler)
 | 
				
			||||||
                log.trace("signal %s, killing whole process group", signum)
 | 
					            signal.signal(signal.SIGHUP, mp_sig_handler)
 | 
				
			||||||
                os.killpg(os.getpgrp(), signal.SIGKILL)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            signal.signal(signal.SIGTERM, sigterm_handler)
 | 
					 | 
				
			||||||
            signal.signal(signal.SIGHUP, sigterm_handler)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # FIXME: Refactor relict - remove later
 | 
					        # FIXME: Refactor relict - remove later
 | 
				
			||||||
        log = logging.getLogger("cdist")
 | 
					        log = logging.getLogger("cdist")
 | 
				
			||||||
| 
						 | 
					@ -221,8 +217,7 @@ class Config(object):
 | 
				
			||||||
                cls.onehost(*process_args[0])
 | 
					                cls.onehost(*process_args[0])
 | 
				
			||||||
            except cdist.Error as e:
 | 
					            except cdist.Error as e:
 | 
				
			||||||
                failed_hosts.append(host)
 | 
					                failed_hosts.append(host)
 | 
				
			||||||
        # Catch errors in parallel mode when joining
 | 
					        elif args.parallel:
 | 
				
			||||||
        if args.parallel:
 | 
					 | 
				
			||||||
            log.trace("Multiprocessing start method is {}".format(
 | 
					            log.trace("Multiprocessing start method is {}".format(
 | 
				
			||||||
                multiprocessing.get_start_method()))
 | 
					                multiprocessing.get_start_method()))
 | 
				
			||||||
            log.trace(("Starting multiprocessing Pool for {} "
 | 
					            log.trace(("Starting multiprocessing Pool for {} "
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
# -*- coding: utf-8 -*-
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# 2016 Darko Poljak (darko.poljak at gmail.com)
 | 
					# 2016-2017 Darko Poljak (darko.poljak at gmail.com)
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# This file is part of cdist.
 | 
					# This file is part of cdist.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
| 
						 | 
					@ -31,6 +31,11 @@ import logging
 | 
				
			||||||
log = logging.getLogger("cdist-mputil")
 | 
					log = logging.getLogger("cdist-mputil")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def mp_sig_handler(signum, frame):
 | 
				
			||||||
 | 
					    log.trace("signal %s, SIGKILL whole process group", signum)
 | 
				
			||||||
 | 
					    os.killpg(os.getpgrp(), signal.SIGKILL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def mp_pool_run(func, args=None, kwds=None, jobs=multiprocessing.cpu_count()):
 | 
					def mp_pool_run(func, args=None, kwds=None, jobs=multiprocessing.cpu_count()):
 | 
				
			||||||
    """Run func using concurrent.futures.ProcessPoolExecutor with jobs jobs
 | 
					    """Run func using concurrent.futures.ProcessPoolExecutor with jobs jobs
 | 
				
			||||||
       and supplied iterables of args and kwds with one entry for each
 | 
					       and supplied iterables of args and kwds with one entry for each
 | 
				
			||||||
| 
						 | 
					@ -56,8 +61,5 @@ def mp_pool_run(func, args=None, kwds=None, jobs=multiprocessing.cpu_count()):
 | 
				
			||||||
                retval.append(f.result())
 | 
					                retval.append(f.result())
 | 
				
			||||||
            return retval
 | 
					            return retval
 | 
				
			||||||
        except KeyboardInterrupt:
 | 
					        except KeyboardInterrupt:
 | 
				
			||||||
            log.trace("KeyboardInterrupt, killing process group")
 | 
					            mp_sig_handler(signal.SIGINT, None)
 | 
				
			||||||
            # When Ctrl+C in terminal then kill whole process group.
 | 
					 | 
				
			||||||
            # Otherwise there remain processes in sleeping state.
 | 
					 | 
				
			||||||
            os.killpg(os.getpgrp(), signal.SIGKILL)
 | 
					 | 
				
			||||||
            raise
 | 
					            raise
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue