python实现基于信息增益的决策树归纳

所属分类: python / 脚本专栏 阅读数: 335
收藏 0 赞 0 分享

本文实例为大家分享了基于信息增益的决策树归纳的Python实现代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-import numpy as npimport matplotlib.mlab as mlabimport matplotlib.pyplot as pltfrom copy import copy #加载训练数据#文件格式:属性标号,是否连续【yes|no】,属性说明attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'attribute_file = open(attribute_file_dest) #文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_idtrainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'trainning_data_file = open(trainning_data_file_dest) #文件格式:class_id,class_descclass_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'class_desc_file = open(class_desc_file_dest)  root_attr_dict = {
}
for line in attribute_file :  line = line.strip()  fld_list = line.split(',')  root_attr_dict[int(fld_list[0])] = tuple(fld_list[1:]) class_dict = {
}
for line in class_desc_file :  line = line.strip()  fld_list = line.split(',')  class_dict[int(fld_list[0])] = fld_list[1]  trainning_data_dict = {
}
class_member_set_dict = {
}
for line in trainning_data_file :  line = line.strip()  fld_list = line.split(',')  rec_id = int(fld_list[0])  a1 = int(fld_list[1])  a2 = int(fld_list[2])  a3 = float(fld_list[3])  c_id = int(fld_list[4])    if c_id not in class_member_set_dict :    class_member_set_dict[c_id] = set()  class_member_set_dict[c_id].add(rec_id)  trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id)  attribute_file.close()class_desc_file.close()trainning_data_file.close() class_possibility_dict = {
}
for c_id in class_member_set_dict :  class_possibility_dict[c_id] = (len(class_member_set_dict[c_id]) + 0.0)/len(trainning_data_dict)   #等待分类的数据data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'data_to_classify_file = open(data_to_classify_file_dest)data_to_classify_dict = {
}
for line in data_to_classify_file :  line = line.strip()  fld_list = line.split(',')  rec_id = int(fld_list[0])  a1 = int(fld_list[1])  a2 = int(fld_list[2])  a3 = float(fld_list[3])  c_id = int(fld_list[4])  data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id)data_to_classify_file.close()    '''决策树的表达结点的需求:1、指示出是哪一种分区 一共3种 一是离散穷举 二是连续有分裂点 三是离散有判别集合 零是叶子结点2、保存分类所需信息3、子结点列表每个结点用Tuple类型表示元素一是整形,取值123 分别对应两种分裂类型元素二是集合类型 对于1保存所有的离散值 对于2保存分裂点 对于3保存判别集合 对于0保存分类结果类标号元素三是dict key对于1来说是某个的离散值 对于23来说只有12两种 对于2来说1代表小于等于分裂点对于3来说1代表属于判别集合'''   #对于一个成员列表,计算其熵#公式为 Info_D = - sum(pi * log2 (pi)) pi为一个元素属于Ci的概率,用|Ci|/|D|计算 ,对所有分类求和def get_entropy( member_list ) :  #成员总数  mem_cnt = len(member_list)  #首先找出member中所包含的分类  class_dict = {
}
  for mem_id in member_list :    c_id = trainning_data_dict[mem_id][3]    if c_id not in class_dict :      class_dict[c_id] = set()    class_dict[c_id].add(mem_id)    tmp_sum = 0.0  for c_id in class_dict :    pi = ( len(class_dict[c_id]) + 0.0 ) / mem_cnt    tmp_sum += pi * mlab.log2(pi)  tmp_sum = -tmp_sum  return tmp_sum     def attribute_selection_method( member_list , attribute_dict ) :  #先计算原始的熵  info_D = get_entropy(member_list)    max_info_Gain = 0.0  attr_get = 0  split_point = 0.0  for attr_id in attribute_dict :    #对于每一个属性计算划分后的熵    #信息增益等于原始的熵减去划分后的熵    info_D_new = 0    #如果是连续属性    if attribute_dict[attr_id][0] == 'yes' :      #先得到memberlist中此属性的取值序列,把序列中每一对相邻项的中值作为划分点计算熵      #找出其中最小的,作为此连续属性的划分点      value_list = []      for mem_id in member_list :        value_list.append(trainning_data_dict[mem_id][attr_id - 1])            #获取相邻元素的中值序列      mid_value_list = []      value_list.sort()      #print value_list      last_value = None      for value in value_list :        if value == last_value :          continue        if last_value is not None :          mid_value_list.append((last_value+value)/2)        last_value = value      #print mid_value_list      #对于中值序列做循环      #计算以此值做为划分点的熵      #总的熵等于两个划分的熵乘以两个划分的比重      min_info = 1000000000.0      total_mens = len(member_list) + 0.0      for mid_value in mid_value_list :        #小于mid_value的mem        less_list = []        #大于        more_list = []        for tmp_mem_id in member_list :          if trainning_data_dict[tmp_mem_id][attr_id - 1] <= mid_value :            less_list.append(tmp_mem_id)          else :            more_list.append(tmp_mem_id)        sum_info = len(less_list)/total_mens * get_entropy(less_list) \        + len(more_list)/total_mens * get_entropy(more_list)                if sum_info < min_info :          min_info = sum_info          split_point = mid_value                info_D_new = min_info    #如果是离散属性    else :      #计算划分后的熵      #采用循环累加的方式      attr_value_member_dict = {
}
 #键为attribute value , 值为memberlist      for tmp_mem_id in member_list :        attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1]        if attr_value not in attr_value_member_dict :          attr_value_member_dict[attr_value] = []        attr_value_member_dict[attr_value].append(tmp_mem_id)      #将每个离散值的熵乘以比重加到这上面      total_mens = len(member_list) + 0.0      sum_info = 0.0      for a_value in attr_value_member_dict :        sum_info += len(attr_value_member_dict[a_value])/total_mens \        * get_entropy(attr_value_member_dict[a_value])            info_D_new = sum_info        info_Gain = info_D - info_D_new    if info_Gain > max_info_Gain :      max_info_Gain = info_Gain      attr_get = attr_id    #如果是离散的  #print 'attr_get ' + str(attr_get)  if attribute_dict[attr_get][0] == 'no' :    return (1 , attr_get , split_point)  else :      return (2 , attr_get , split_point)  #第三类先不考虑 def get_decision_tree(father_node , key , member_list , attr_dict ) :  #最终的结果是新建一个结点,并且添加到father_node的sub_node_dict,对key为键  #检查memberlist 如果都是同类的,则生成一个叶子结点,set里面保存类标号  class_set = set()  for mem_id in member_list :    class_set.add(trainning_data_dict[mem_id][3])  if len(class_set) == 1 :    father_node[2][key] = (0 , (1 , class_set) , {
}
 )    return    #检查attribute_list,如果为空,产生叶子结点,类标号为memberlist中多数元素的类标号  #如果几个类的成员等量,则打印提示,并且全部添加到set里面  if not attr_dict :    class_cnt_dict = {
}
    for mem_id in member_list :      c_id = trainning_data_dict[mem_id][3]      if c_id not in class_cnt_dict :        class_cnt_dict[c_id] = 1      else :        class_cnt_dict[c_id] += 1            class_set = set()    max_cnt = 0    for c_id in class_cnt_dict :      if class_cnt_dict[c_id] > max_cnt :        max_cnt = class_cnt_dict[c_id]        class_set.clear()        class_set.add(c_id)      elif class_cnt_dict[c_id] == max_cnt :        class_set.add(c_id)        if len(class_set) > 1 :      print 'more than one class !'        father_node[2][key] = (0 , (1 , class_set ) , {
}
 )    return    #找出最好的分区方案 , 暂不考虑第三种划分方法  #比较所有离散属性和所有连续属性的所有中值点划分的信息增益  split_criterion = attribute_selection_method(member_list , attr_dict)  #print split_criterion  selected_plan_id = split_criterion[0]  selected_attr_id = split_criterion[1]    #如果采用的是离散属性做为分区方案,删除这个属性  new_attr_dict = copy(attr_dict)  if attr_dict[selected_attr_id][0] == 'no' :    del new_attr_dict[selected_attr_id]    #建立一个结点new_node,father_node[2][key] = new_node  #然后对new node的每一个key , sub_member_list,  #调用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict)  #实现递归  ele2 = ( selected_attr_id , set() )  #如果是1 , ele2保存所有离散值  if selected_plan_id == 1 :    for mem_id in member_list :      ele2[1].add(trainning_data_dict[mem_id][selected_attr_id - 1])  #如果是2,ele2保存分裂点  elif selected_plan_id == 2 :    ele2[1].add(split_criterion[2])  #如果是3则保存判别集合,先不管  else :    print 'not completed'    pass      new_node = ( selected_plan_id , ele2 , {
}
 )  father_node[2][key] = new_node    #生成KEY,并递归调用  if selected_plan_id == 1 :    #每个attr_value是一个key    attr_value_member_dict = {
}
    for mem_id in member_list :      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]      if attr_value not in attr_value_member_dict :        attr_value_member_dict[attr_value] = []      attr_value_member_dict[attr_value].append(mem_id)    for attr_value in attr_value_member_dict :      get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict)    pass  elif selected_plan_id == 2 :    #key 只有12 , 小于等于分裂点的是1 , 大于的是2    less_list = []    more_list = []    for mem_id in member_list :      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]      if attr_value <= split_criterion[2] :        less_list.append(mem_id)      else :        more_list.append(mem_id)    #if len(less_list) != 0 :    get_decision_tree(new_node , 1 , less_list , new_attr_dict)    #if len(more_list) != 0 :    get_decision_tree(new_node , 2 , more_list , new_attr_dict)    pass  #如果是3则保存判别集合,先不管  else :    print 'not completed'    pass  def get_class_sub(node , tp ) :  #  attr_id = node[1][0]  plan_id = node[0]  key = 0  if plan_id == 0 :    return node[1][1]  elif plan_id == 1 :    key = tp[attr_id - 1]  elif plan_id == 2 :    split_point = tuple(node[1][1])[0]    attr_value = tp[attr_id - 1]    if attr_value <= split_point :      key = 1    else :      key = 2  else :    print 'error'    return set()      return get_class_sub(node[2][key] , tp ) def get_class(r_node , tp) :  #tp为一组属性值  if r_node[0] != -1 :    print 'error'    return set()    if 1 in r_node[2] :    return get_class_sub(r_node[2][1] , tp)  else :    print 'error'    return set()    if __name__ == '__main__' :  root_node = ( -1 , set() , {
}
 )  mem_list = trainning_data_dict.keys()  get_decision_tree(root_node , 1 , mem_list , root_attr_dict )   #测试分类器的准确率  diff_cnt = 0  for mem_id in data_to_classify_dict :    c_id = get_class(root_node , data_to_classify_dict[mem_id][0:3])    if tuple(c_id)[0] != data_to_classify_dict[mem_id][3] :      print tuple(c_id)[0]      print data_to_classify_dict[mem_id][3]      print 'different'      diff_cnt += 1  print diff_cnt

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

更多精彩内容其他人还在看

Python处理json字符串转化为字典的简单实现

今天一个朋友给个需求: 来来 {'isOK': 1, 'isRunning': None, 'isError': None}怎么转换成字典好,一看就是json转化很简单,开始:import jsona = "{'isOK': 1, 'isRunning': None, 'isEr... 查看详情
收藏 0 赞 0 分享

简单掌握Python的Collections模块中counter结构的用法

counter 是一种特殊的字典,主要方便用来计数,key 是要计数的 item,value 保存的是个数。from collections import Counter>>> c = Counter('hello,world')Counter({'l': 3,... 查看详情
收藏 0 赞 0 分享

详解Python的collections模块中的deque双端队列结构

deque 是 double-ended queue的缩写,类似于 list,不过提供了在两端插入和删除的操作。 appendleft 在列表左侧插入 popleft 弹出列表左侧的值 extendleft 在左侧扩展例如:queue = deque()# append v... 查看详情
收藏 0 赞 0 分享

Python的collections模块中namedtuple结构使用示例

namedtuple 就是命名的 tuple,比较像 C 语言中 struct。一般情况下的 tuple 是 (item1, item2, item3,...),所有的 item 都只能按照 index 访问,没有明确的称呼,而 namedtuple 就是事先把这些 item 命... 查看详情
收藏 0 赞 0 分享

Python的collections模块中的OrderedDict有序字典

如同这个数据结构的名称所说的那样,它记录了每个键值对添加的顺序。d = OrderedDict()d['a'] = 1d['b'] = 10d['c'] = 8for letter in d: print letter输出:    abc如果初始化... 查看详情
收藏 0 赞 0 分享

简介Python的collections模块中defaultdict类型的用法

defaultdict 主要用来需要对 value 做初始化的情形。对于字典来说,key 必须是 hashable,immutable,unique 的数据,而 value 可以是任意的数据类型。如果 value 是 list,dict 等数据类型,在使用之前必须初始化为空,有些... 查看详情
收藏 0 赞 0 分享

Python中的os.path路径模块中的操作方法总结

解析路径路径解析依赖与os中定义的一些变量: os.sep-路径各部分之间的分隔符。 os.extsep-文件名与文件扩展名之间的分隔符。 os.pardir-路径中表示目录树上一级的部分。 os.curdir-路径中当前目录的部分。split()函数将路径分解为两个单独... 查看详情
收藏 0 赞 0 分享

Python中使用platform模块获取系统信息的用法教程

操作系统相关 system() : 操作系统类型(见例) version(): 操作系统版本 release(): 操作系统发布号, 例如win 7返回7, 还有如NT, 2.2.0之类. platform(aliased=0, terse=0): 操作系统信息字符串,扥... 查看详情
收藏 0 赞 0 分享

Python中的FTP通信模块ftplib的用法整理

Python中默认安装的ftplib模块定义了FTP类,其中函数有限,可用来实现简单的ftp客户端,用于上传或下载文件.FTP的工作流程及基本操作可参考协议RFC959.ftp登陆连接from ftplib import FTP #加载ftp模块ftp=FTP() #设置变量ft... 查看详情
收藏 0 赞 0 分享

Python设计足球联赛赛程表程序的思路与简单实现示例

每年意甲德甲英超西甲各大联赛的赛程表都是球迷们的必看之物,想起之前写过的一段生成赛程表的代码,用Python来写这类东西太舒服了。这个算法叫做蛇环算法。即,把所有球队排成一个环形(2列),左边对阵右边,第一支队伍不动,其他队伍顺时针循环,这样就肯定不重复了。为了方便说明,假设有8... 查看详情
收藏 0 赞 0 分享
查看更多