Source code for psychsim.modeling

from argparse import ArgumentParser
import csv
import itertools
import logging
import os

from psychsim.pwl import *
from psychsim.world import *

"""
Functions for automatic construction of PsychSim models
"""
class Domain:
    """
    Structure for representing model-building information
    @ivar idFields: fields representing unique IDs for each record
    @type idFields: str[]
    @ivar filename: root filename for all model-related files
    @type filename: str
    @ivar fields: key of mappings from fields of data to model variables
    @type fields: strS{->}dict
    @ivar data: table of records with relevant variables
    @ivar variations: list of dependency variations to explore
    @ivar models: list of model name codes to explore
    @ivar targets: set of fields to predict
    """
    def __init__(self,fname,logger=logging.getLogger()):
        self.logger = logger.getChild('Domain')
        self.filename = fname
        # Read variable/field definitions
        self.idFields = []
        self.fields = {}
        self.targets = set()
        self.readKey()
        # Read input data
        self.data = {}
        self.readInputData()
        # What is the hypothesis space?
        self.readVariations()
        if self.variations:
            varLens = [range(link['range']) for link in self.variations]
            self.models = [self.links2model(model) for model in itertools.product(*varLens)]
        # What are the previously tested hypotheses?
        self.readPredictions()
        self.logger.info('|Unmatched| = %d' % (len(self.unmatched())))
    def unmatched(self):
        return {ID: record for ID,record in self.data.items()
                if min(map(len,record['__matches__'].values())) == 0}
    def targetHistogram(self,missing=None,data=None):
        if data is None:
            data = self.data
        result = {field: {} for field in self.targets}
        for ID,record in data.items():
            for field in self.targets:
                value = record[field]
                if missing is not None and len(value.strip()) == 0:
                    value = missing
                if value not in result[field]:
                    result[field][value] = set()
                result[field][value].add(ID)
        return result
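    # Sketch of the returned structure, assuming a single hypothetical target
    # field 'outcome': {'outcome': {'yes': {'ID1','ID3'}, 'no': {'ID2'}}}, i.e.
    # a mapping from target field to observed value to the set of record IDs,
    # with blank values replaced by the `missing` argument when one is given.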
    def recordID(self,record):
        return ''.join(['%s%s' % (field,record[field]) for field in self.idFields])
    def readPredictions(self,fname=None):
        if fname is None:
            fname = '%s-predictions.csv' % (self.filename)
        if os.path.isfile(fname):
            with open(fname,'r') as csvfile:
                reader = csv.DictReader(csvfile)
                for record in reader:
                    ID = self.recordID(record)
                    if '__prediction__' not in self.data[ID]:
                        self.data[ID]['__prediction__'] = {}
                        self.data[ID]['__probability__'] = {}
                        for field in self.targets:
                            self.data[ID]['__prediction__'][field] = {}
                            self.data[ID]['__probability__'][field] = {}
                    if '__target__' not in record:
                        assert len(self.targets) == 1
                        record['__target__'] = list(self.targets)[0]
                    field = record['__target__']
                    links = []
                    for variation in self.variations:
                        code = variation['code']
                        if record[code] == '':
                            value = None
                        else:
                            value = [int(element) for element in record[code].split(':')]
                        links.append(variation['domain'].index(value))
                    model = self.links2model(links)
                    self.data[ID]['__prediction__'][field][model] = record[field]
                    if record['P(%s)' % (field)]:
                        self.data[ID]['__probability__'][field][model] = float(record['P(%s)' % (field)])
        for record in self.data.values():
            # Missing predictions, so let's enter empty tables
            if '__prediction__' not in record:
                record['__prediction__'] = {}
                record['__probability__'] = {}
                for field in self.targets:
                    record['__prediction__'][field] = {}
                    record['__probability__'][field] = {}
            record['__matches__'] = {}
            for field in self.targets:
                record['__matches__'][field] = {m for m in record['__prediction__'][field].keys()
                                                if record['__prediction__'][field][m] == record[field]}
    def writePredictions(self,fname=None):
        if fname is None:
            fname = '%s-predictions.csv' % (self.filename)
        with open(fname,'w') as csvfile:
            fields = None
            writer = None
            for ID,record in sorted(self.data.items()):
                for field in sorted(self.targets):
                    for model,prediction in sorted(record['__prediction__'][field].items()):
                        newRecord = {f: record[f] for f in self.idFields}
                        newRecord['__target__'] = field
                        newRecord[field] = prediction
                        try:
                            newRecord['P(%s)' % (field)] = record['__probability__'][field][model]
                        except KeyError:
                            newRecord['P(%s)' % (field)] = ''
                        links = self.model2links(model)
                        for variation in self.variations:
                            code = variation['code']
                            value = variation['domain'][links[code]]
                            if value is None:
                                newRecord[code] = ''
                            elif len(value) == 1:
                                newRecord[code] = '%d' % (value[0])
                            else:
                                assert len(value) == 2
                                newRecord[code] = '%d:%d' % tuple(value)
                        if fields is None:
                            fields = sorted(newRecord.keys())
                            writer = csv.DictWriter(csvfile,fields,extrasaction='ignore')
                            writer.writeheader()
                        writer.writerow(newRecord)
    def readDataFile(self,fname):
        data = {}
        with open(fname) as csvfile:
            reader = csv.DictReader(csvfile)
            for record in reader:
                ID = self.recordID(record)
                assert ID not in data,'Duplicate ID: %s' % (ID)
                data[ID] = record
        return data
    def readInputData(self,fname=None):
        if fname is None:
            fname = '%s-input.csv' % (self.filename)
        if os.path.isfile(fname):
            self.data = self.readDataFile(fname)
        else:
            raw = self.readDataFile('%s-raw.csv' % (self.filename))
            self.processData(raw)
            fields = sorted(next(iter(self.data.values())).keys())
            for row in self.data.values():
                assert set(row.keys()) == set(fields)
            with open(fname,'w') as csvfile:
                writer = csv.DictWriter(csvfile,fields,extrasaction='ignore')
                writer.writeheader()
                for ID,record in sorted(self.data.items()):
                    writer.writerow(record)
    def processData(self,raw):
        """
        Takes in raw data and extracts the relevant fields
        """
        if isinstance(raw,dict):
            raw = raw.values()
        logger = self.logger.getChild('processData')
        self.data.clear()
        for record in raw:
            ID = self.recordID(record)
            logger.debug('Processing record: %s' % (ID))
            newRecord = {field: record[field] for field in self.idFields}
            for field,entry in self.fields.items():
                if field and not entry['class'] == 'id':
                    assert field in record,'Missing field %s from record %s' % (field,ID)
                    assert entry['variable'],'Field %s has no variable' % (field)
                    newRecord[field] = record[field]
            self.data[ID] = newRecord
    def readKey(self,fname=None):
        if fname is None:
            fname = '%s-key.csv' % (self.filename)
        with open(fname) as csvfile:
            reader = csv.DictReader(row for row in csvfile if not row.startswith('#'))
            for field in reader:
                if field['class'] == 'id':
                    self.idFields.append(field['field'])
                else:
                    self.fields[field['field']] = field
                    if len(field['variable']) == 0:
                        field['variable'] = field['field']
                    if field['target'] == 'yes':
                        self.targets.add(field['field'])
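    # Sketch of the expected <root>-key.csv layout; the columns 'field', 'class',
    # 'variable', and 'target' are the ones read above, lines starting with '#'
    # are skipped, and the field names and non-'id' class values here are
    # hypothetical:
    #   field,class,variable,target
    #   household,id,,
    #   income,feature,economicStatus,no
    #   evacuated,feature,,yes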
    def readVariations(self,fname=None):
        if fname is None:
            fname = '%s-variations.csv' % (self.filename)
        self.variations = []
        if os.path.isfile(fname):
            # Read in modeling variations from file
            with open(fname) as csvfile:
                reader = csv.DictReader(row for row in csvfile if not row.startswith('#'))
                index = 0
                for link in reader:
                    # Read variables involved and possible link values
                    link['from'] = link['from'].split(';')
                    link['domain'] = [None if val == 'None' else [int(element) for element in val.split(':')]
                                      for val in link['domain'].split(';')]
                    link['range'] = int(link['range'])
                    link['index'] = index
                    self.variations.append(link)
                    # Derive downstream effects
                    link['effects'] = {}
                    link['effects'][link['to']] = link['domain'][:]
                    index += 1
        return self.variations
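    # Sketch of a <root>-variations.csv row, using hypothetical variable names;
    # the columns read above are 'code', 'from', 'to', 'domain', and 'range',
    # where 'from' is ';'-separated and 'domain' is a ';'-separated list of
    # candidate links, each either 'None' or a ':'-separated tuple of integers:
    #   code,from,to,domain,range
    #   A,income;employment,evacuated,None;1;1:2,3
    # which yields domain [None, [1], [1, 2]] and three possible link values.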
    def links2model(self,links):
        return ''.join(['%s%s' % (self.variations[i]['code'],links[i])
                        for i in range(len(self.variations))])
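    # For example (hypothetical codes): with two variations coded 'A' (range 3)
    # and 'B' (range 2), the constructor enumerates itertools.product(range(3),
    # range(2)) and links2model((0,1)) yields the model name 'A0B1'.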
def noisyOrTree(tree,value):
    if isinstance(tree,dict):
        return {'if': tree['if'],
                True: noisyOrTree(tree[True],value),
                False: noisyOrTree(tree[False],value)}
    else:
        return tree*(1.-value)
def leaf2matrix(tree,key):
    if isinstance(tree,dict):
        return {'if': tree['if'],
                True: leaf2matrix(tree[True],key),
                False: leaf2matrix(tree[False],key)}
    else:
        prob = 1.-tree
        return {'distribution': [(setTrueMatrix(key),prob),(setFalseMatrix(key),1.-prob)]}
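# Sketch of what leaf2matrix does to a single probability leaf: for some state
# key k and a leaf value of 0.25, the leaf becomes
#   {'distribution': [(setTrueMatrix(k), 0.75), (setFalseMatrix(k), 0.25)]}
# i.e. the numeric leaf is interpreted as the probability that k becomes False.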
if __name__ == '__main__':
    logging.basicConfig(level=logging.ERROR)
    parser = ArgumentParser()
    parser.add_argument('-d','--debug',default='WARNING',help='Level of logging detail')
    # Positional argument for input file
    parser.add_argument('input',nargs='?',default='seattle',
                        help='Root name of CSV files for input/output [default: %(default)s]')
    args = vars(parser.parse_args())
    # Extract logging level from command-line argument
    level = getattr(logging, args['debug'].upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid debug level: %s' % args['debug'])
    logging.getLogger().setLevel(level)
    domain = Domain(args['input'])
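# Example invocation, assuming the psychsim package is on the Python path and the
# seattle-*.csv files are in the working directory:
#   python -m psychsim.modeling seattle -d INFO
# which builds the Domain from seattle-key.csv, seattle-input.csv (or
# seattle-raw.csv), seattle-variations.csv, and seattle-predictions.csv.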