import sys, math

### takes a filename as input and returns a list of lists, one for each row
### of the data.
### Note: both this function and the one below demonstrate how list
### comprehensions and return values can make your code more compact and
### efficient. This is more efficient than creating a bunch of temporary
### lists, because we let the Python interpreter manage the creation and
### modification of lists rather than doing it ourselves.
### The resulting code is denser and a bit harder to read at first, but it
### avoids having temp1, temp2, ... variables scattered throughout.
### To get a feel for what it's doing, try running pieces of it from
### the Python interpreter.
### Blank lines and lines beginning with # are skipped.
def getData(filename) :
    try :
        with open(filename) as f :
            return [[i.strip() for i in line.strip().split(',')]
                    for line in f if line.strip() and not line.startswith('#')]
    except IOError :
        ### sys.exit with a string prints it to stderr and exits with status 1
        sys.exit('Unable to open %s' % filename)

### takes a filename as input and returns a list of tuples containing
### each attribute name and a list of its possible values.
### assumptions: each line has the form   name: value1, value2, ...
### and a line beginning with # is a comment.
def getAttributes(filename) :
    try :
        with open(filename) as f :
            return [(item[0].strip(), [i.strip() for i in item[1].split(',')])
                    for item in [line.strip().split(':') for line in f
                                 if line.strip() and not line.startswith('#')]]
    except IOError :
        sys.exit('Unable to open %s' % filename)

### unique - a helper function that takes as input a list and returns
### a new list with duplicates removed (order is not preserved).
def unique(inlist) :
    return list(set(inlist))

### takes as input a list of attribute values of the form [v1,v2,...,vn].
### Returns a float indicating the entropy in this data.
def entropy(data) :
    ### you complete this.
    raise NotImplementedError('entropy')

### remainder - this is the amount of entropy left in the data after
### splitting on an attribute.
### Assume the input is of the form:
### [(value1, class1), (value2, class2), ..., (valuen, classn)]
def remainder(data) :
    ### you complete this.
    raise NotImplementedError('remainder')

### selectAttribute - choose the index of the attribute in the current
### dataset that minimizes the remainder.
### data is in the form [[a1, a2, ..., c1], [b1, b2, ..., c2], ... ]
### where the a's and b's are attributes and the c's are classifications.
def selectAttribute(data, attributes) :
    ### you complete this.
    raise NotImplementedError('selectAttribute')

### a TreeNode is an object that has either:
### 1. An attribute to be tested and a set of children; one for each possible
###    value of the attribute.
### 2. A value (if it is a leaf in a tree).
class TreeNode :
    def __init__(self, attribute, value) :
        self.attribute = attribute
        self.value = value
        self.children = {}

    ### __repr__ must return a string, so convert in case the value is not one
    def __repr__(self) :
        if self.attribute :
            return str(self.attribute)
        else :
            return str(self.value)

    ### a node with no children is a leaf
    def isLeaf(self) :
        return not self.children

    ### the input will be:
    ### data - an object to classify - [v1, v2, ..., vn]
    ### attributes - the attribute list produced by getAttributes:
    ###   [(a1, [v1, v2]), (a2, [v3, v4]), ...]
    def classify(self, data, attributes) :
        ### you do this part.
        raise NotImplementedError('classify')

### a tree is simply a data structure composed of nodes. The root of the tree
### is itself a node, so we don't need a separate 'Tree' class. We
### just need a function that takes in a dataset and a list of
### attributes, builds a tree, and returns the root node.
### makeTree is a recursive function. Our base case is that our
### dataset has entropy 0, so no further tests have to be made. There
### are two other degenerate base cases: when there is no more data to
### use, and when we have no data for a particular value. In these cases
### we use either the default value or the majority value.
### The recursive step is to select the attribute with the largest
### information gain (equivalently, the smallest remainder, since the
### entropy before the split is the same for every attribute) and split on it.
### assume: input looks like this:
### dataset: [[v1, v2, ..., vn, c1], [v1, v2, ..., c2] ... ]
### attributes: [(a1, [v1, v2, ..., vn]), (a2, [v1, v2, ..., vm]), ...]
### returns a TreeNode
def makeTree(dataset, attributes, defaultValue) :
    ### you do this part.
    raise NotImplementedError('makeTree')

### you might also find it helpful to have a print function. You should also
### write a main method that processes the command-line args and does one of
### two things:
### 1. train - call makeTree to construct a decision tree and use pickle
###    (or cPickle) to save the tree to a file.
### 2. test - read the tree in from a file and call classify on each member
###    of the test set.
### 'Usage: dt.py {-train | -test} filename'
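
### A minimal sketch of the entropy / remainder / selectAttribute chain
### described in the comments above, assuming class labels are plain,
### hashable values. The names sketchEntropy, sketchRemainder and
### sketchSelectAttribute are hypothetical, chosen so they do not replace the
### stubs. Entropy here is H = -sum(p * log2(p)) over the class proportions p.
import math
from collections import Counter

def sketchEntropy(labels):
    ### proportion of each distinct label, then -sum(p * log2(p))
    total = float(len(labels))
    return -sum((n / total) * math.log(n / total, 2)
                for n in Counter(labels).values())

def sketchRemainder(pairs):
    ### pairs look like [(value1, class1), ..., (valuen, classn)]
    ### group the class labels by attribute value
    total = float(len(pairs))
    byValue = {}
    for value, label in pairs:
        byValue.setdefault(value, []).append(label)
    ### weighted average of the entropy within each group
    return sum(len(labels) / total * sketchEntropy(labels)
               for labels in byValue.values())

def sketchSelectAttribute(data, attributes):
    ### data rows look like [a1, ..., an, c]; attributes[i] describes column i.
    ### Return the column index whose split leaves the smallest remainder.
    return min(range(len(attributes)),
               key=lambda i: sketchRemainder([(row[i], row[-1]) for row in data]))

### Minimizing the remainder is the same as maximizing information gain,
### because the entropy before the split does not depend on the attribute.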
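
### A minimal sketch of the recursion described in the makeTree comments,
### under stated assumptions: rows look like [v1, ..., vn, c], attributes[i]
### describes column i of the *current* rows, and chooseIndex is a
### hypothetical parameter standing in for selectAttribute. The "no more
### data to use" case is read here as "no attributes left to test". To run on
### its own, the sketch builds plain tuples instead of TreeNode objects: a
### leaf is a class label and an internal node is (attributeName,
### {value: subtree}). All function names here are hypothetical.
from collections import Counter

def sketchMajority(rows):
    ### most common class label among the rows
    return Counter(row[-1] for row in rows).most_common(1)[0][0]

def sketchMakeTree(rows, attributes, defaultValue, chooseIndex):
    if not rows:
        ### no data for this value: fall back to the default
        return defaultValue
    classes = set(row[-1] for row in rows)
    if len(classes) == 1:
        ### entropy 0: every row has the same class, so no further tests
        return classes.pop()
    majority = sketchMajority(rows)
    if not attributes:
        ### nothing left to test: use the majority class
        return majority
    i = chooseIndex(rows, attributes)
    name, values = attributes[i]
    remaining = attributes[:i] + attributes[i + 1:]
    children = {}
    for v in values:
        ### keep the rows with value v and drop column i so that column
        ### positions stay aligned with the remaining attributes
        subset = [row[:i] + row[i + 1:] for row in rows if row[i] == v]
        children[v] = sketchMakeTree(subset, remaining, majority, chooseIndex)
    return (name, children)

def sketchClassify(tree, row, attributes):
    ### attributes is the full, original list, used to map names back to columns
    index = dict((name, i) for i, (name, _) in enumerate(attributes))
    while isinstance(tree, tuple):
        name, children = tree
        tree = children[row[index[name]]]
    return tree

### Passing the majority class down as the next defaultValue means a value
### with no training rows inherits the most common class of its parent.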
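
### A minimal sketch of the train/test main method described above, under
### stated assumptions: the tree is stored in a separate file whose name
### ('tree.pkl' by default) is a hypothetical choice, and train/test are
### callables supplied by the caller (in this file they would wrap makeTree
### and classify). In Python 3, pickle uses its C implementation
### automatically; in Python 2, cPickle offers the same dump/load functions.
### All names here are hypothetical.
import pickle
import sys

USAGE = 'Usage: dt.py {-train | -test} filename'

def parseArgs(argv):
    ### returns ('train' or 'test', filename); exits with the usage message otherwise
    if len(argv) != 3 or argv[1] not in ('-train', '-test'):
        sys.exit(USAGE)
    return argv[1][1:], argv[2]

def saveTree(tree, path):
    ### pickle data must be written in binary mode
    with open(path, 'wb') as f:
        pickle.dump(tree, f)

def loadTree(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

def sketchMain(argv, train, test, treeFile='tree.pkl'):
    mode, filename = parseArgs(argv)
    if mode == 'train':
        ### build a tree from the training file and save it for a later -test run
        saveTree(train(filename), treeFile)
    else:
        ### reload the saved tree and hand it to the test routine
        return test(loadTree(treeFile), filename)

### pickle stores class instances by reference to their class, so a saved
### TreeNode can only be loaded where TreeNode is importable (e.g. dt.py itself).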