Merge remote-tracking branch 'steven/master'

This commit is contained in:
Nico Schottelius 2011-10-18 15:14:44 +02:00
commit 4d4287c580
8 changed files with 83 additions and 20 deletions

View file

@ -92,9 +92,10 @@ class Manifest(object):
env = os.environ.copy() env = os.environ.copy()
env.update(self.env) env.update(self.env)
env.update({ env.update({
'__self': cdist_object.name,
'__object': cdist_object.absolute_path, '__object': cdist_object.absolute_path,
'__object_id': cdist_object.object_id, '__object_id': cdist_object.object_id,
'__object_fq': cdist_object.path, '__object_fq': cdist_object.name,
'__type': cdist_object.type.absolute_path, '__type': cdist_object.type.absolute_path,
'__cdist_manifest': script, '__cdist_manifest': script,
}) })

View file

@ -88,8 +88,11 @@ class Object(object):
return self.__class__(self.type.__class__(type_path, type_name), object_path, object_id=object_id) return self.__class__(self.type.__class__(type_path, type_name), object_path, object_id=object_id)
def __init__(self, cdist_type, base_path, object_id=None): def __init__(self, cdist_type, base_path, object_id=None):
if object_id and object_id.startswith('/'): if object_id:
raise IllegalObjectIdError(object_id, 'object_id may not start with /') if object_id.startswith('/'):
raise IllegalObjectIdError(object_id, 'object_id may not start with /')
if '.cdist' in object_id:
raise IllegalObjectIdError(object_id, 'object_id may not contain \'.cdist\'')
self.type = cdist_type # instance of Type self.type = cdist_type # instance of Type
self.base_path = base_path self.base_path = base_path
self.object_id = object_id self.object_id = object_id

View file

@ -158,9 +158,6 @@ class Emulator(object):
requirement_type_name = requirement_parts[0] requirement_type_name = requirement_parts[0]
requirement_object_id = requirement_parts[1] requirement_object_id = requirement_parts[1]
# Instantiate type which fails if type does not exist
requirement_type = core.Type(self.type_base_path, requirement_type_name)
# FIXME: Add support for omitted object id == singleton # FIXME: Add support for omitted object id == singleton
#if len(requirement_parts) == 1: #if len(requirement_parts) == 1:
#except IndexError: #except IndexError:
@ -170,6 +167,11 @@ class Emulator(object):
# Remove / if existent in object id # Remove / if existent in object id
requirement_object_id = requirement_object_id.lstrip('/') requirement_object_id = requirement_object_id.lstrip('/')
# Instantiate type which fails if type does not exist
requirement_type = core.Type(self.type_base_path, requirement_type_name)
# Instantiate object which fails if the object_id is illegal
requirement_object = core.Object(requirement_type, self.object_base_path, requirement_object_id)
# Construct cleaned up requirement with only one / :-) # Construct cleaned up requirement with only one / :-)
requirement = requirement_type_name + '/' + requirement_object_id requirement = requirement_type_name + '/' + requirement_object_id
self.cdist_object.requirements.append(requirement) self.cdist_object.requirements.append(requirement)

View file

@ -73,3 +73,10 @@ class EmulatorTestCase(unittest.TestCase):
os.environ['require'] = '__does-not-exist/some-id' os.environ['require'] = '__does-not-exist/some-id'
emu = emulator.Emulator(argv) emu = emulator.Emulator(argv)
self.assertRaises(core.NoSuchTypeError, emu.run) self.assertRaises(core.NoSuchTypeError, emu.run)
def test_illegal_object_id_requirement(self):
argv = ['__file', '/tmp/foobar']
os.environ.update(self.env)
os.environ['require'] = '__file/bad/id/with/.cdist/inside'
emu = emulator.Emulator(argv)
self.assertRaises(core.IllegalObjectIdError, emu.run)

View file

@ -27,6 +27,8 @@ import shutil
import string import string
import random import random
import logging import logging
import io
import sys
import cdist import cdist
from cdist.exec import local from cdist.exec import local
@ -48,6 +50,8 @@ class ManifestTestCase(unittest.TestCase):
return tempfile.mkstemp(prefix='tmp.cdist.test.', **kwargs) return tempfile.mkstemp(prefix='tmp.cdist.test.', **kwargs)
def setUp(self): def setUp(self):
self.orig_environ = os.environ
os.environ = os.environ.copy()
self.temp_dir = self.mkdtemp() self.temp_dir = self.mkdtemp()
self.target_host = 'localhost' self.target_host = 'localhost'
out_path = self.temp_dir out_path = self.temp_dir
@ -58,17 +62,52 @@ class ManifestTestCase(unittest.TestCase):
self.log = logging.getLogger(self.target_host) self.log = logging.getLogger(self.target_host)
def tearDown(self): def tearDown(self):
os.environ = self.orig_environ
shutil.rmtree(self.temp_dir) shutil.rmtree(self.temp_dir)
def test_initial_manifest_environment(self): def test_initial_manifest_environment(self):
initial_manifest = os.path.join(self.local.manifest_path, "dump_environment") initial_manifest = os.path.join(self.local.manifest_path, "dump_environment")
handle, output_file = self.mkstemp(dir=self.temp_dir)
os.environ['__cdist_test_out'] = output_file
self.manifest.run_initial_manifest(initial_manifest) self.manifest.run_initial_manifest(initial_manifest)
with open(output_file, 'r') as fd:
output_string = fd.read()
output_dict = {}
for line in output_string.split('\n'):
if line:
key,value = line.split(': ')
output_dict[key] = value
self.assertTrue(output_dict['PATH'].startswith(self.local.bin_path))
self.assertEqual(output_dict['__target_host'], self.local.target_host)
self.assertEqual(output_dict['__global'], self.local.out_path)
self.assertEqual(output_dict['__cdist_type_base_path'], self.local.type_path)
self.assertEqual(output_dict['__manifest'], self.local.manifest_path)
def test_type_manifest_environment(self): def test_type_manifest_environment(self):
cdist_type = core.Type(self.local.type_path, '__dump_environment') cdist_type = core.Type(self.local.type_path, '__dump_environment')
cdist_object = core.Object(cdist_type, self.local.object_path, 'whatever') cdist_object = core.Object(cdist_type, self.local.object_path, 'whatever')
handle, output_file = self.mkstemp(dir=self.temp_dir)
os.environ['__cdist_test_out'] = output_file
self.manifest.run_type_manifest(cdist_object) self.manifest.run_type_manifest(cdist_object)
with open(output_file, 'r') as fd:
output_string = fd.read()
output_dict = {}
for line in output_string.split('\n'):
if line:
key,value = line.split(': ')
output_dict[key] = value
self.assertTrue(output_dict['PATH'].startswith(self.local.bin_path))
self.assertEqual(output_dict['__target_host'], self.local.target_host)
self.assertEqual(output_dict['__global'], self.local.out_path)
self.assertEqual(output_dict['__cdist_type_base_path'], self.local.type_path)
self.assertEqual(output_dict['__type'], cdist_type.absolute_path)
self.assertEqual(output_dict['__object'], cdist_object.absolute_path)
self.assertEqual(output_dict['__self'], cdist_object.name)
self.assertEqual(output_dict['__object_id'], cdist_object.object_id)
self.assertEqual(output_dict['__object_fq'], cdist_object.path)
def test_debug_env_setup(self): def test_debug_env_setup(self):
self.log.setLevel(logging.DEBUG) self.log.setLevel(logging.DEBUG)
manifest = cdist.core.manifest.Manifest(self.target_host, self.local) manifest = cdist.core.manifest.Manifest(self.target_host, self.local)

View file

@ -1,7 +1,9 @@
#!/bin/sh #!/bin/sh
echo "PATH: $PATH" cat > $__cdist_test_out << DONE
echo "__target_host: $__target_host" PATH: $PATH
echo "__global: $__global" __target_host: $__target_host
echo "__cdist_type_base_path: $__cdist_type_base_path" __global: $__global
echo "__manifest: $__manifest" __cdist_type_base_path: $__cdist_type_base_path
__manifest: $__manifest
DONE

View file

@ -1,10 +1,13 @@
#!/bin/sh #!/bin/sh
echo "PATH: $PATH" cat > $__cdist_test_out << DONE
echo "__target_host: $__target_host" PATH: $PATH
echo "__global: $__global" __target_host: $__target_host
echo "__cdist_type_base_path: $__cdist_type_base_path" __global: $__global
echo "__type: $__type" __cdist_type_base_path: $__cdist_type_base_path
echo "__object: $__object" __type: $__type
echo "__object_id: $__object_id" __self: $__self
echo "__object_fq: $__object_fq" __object: $__object
__object_id: $__object_id
__object_fq: $__object_fq
DONE

View file

@ -53,12 +53,18 @@ class ObjectClassTestCase(unittest.TestCase):
class ObjectIdTestCase(unittest.TestCase): class ObjectIdTestCase(unittest.TestCase):
def test_illegal_object_id(self): def test_object_id_starts_with_slash(self):
cdist_type = core.Type(type_base_path, '__third') cdist_type = core.Type(type_base_path, '__third')
illegal_object_id = '/object_id/may/not/start/with/slash' illegal_object_id = '/object_id/may/not/start/with/slash'
with self.assertRaises(core.IllegalObjectIdError): with self.assertRaises(core.IllegalObjectIdError):
core.Object(cdist_type, object_base_path, illegal_object_id) core.Object(cdist_type, object_base_path, illegal_object_id)
def test_object_id_contains_dotcdist(self):
cdist_type = core.Type(type_base_path, '__third')
illegal_object_id = 'object_id/may/not/contain/.cdist/anywhere'
with self.assertRaises(core.IllegalObjectIdError):
core.Object(cdist_type, object_base_path, illegal_object_id)
class ObjectTestCase(unittest.TestCase): class ObjectTestCase(unittest.TestCase):