View the generated decision tree:
In [2]: clf.tree
Out[2]:
{'tearRate': {'normal': {'astigmatic': {'no': {'age': {'pre': 'soft',
        'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}},
        'young': 'soft'}},
      'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses',
          'presbyopic': 'no lenses',
          'young': 'hard'}},
        'myope': 'hard'}}}},
  'reduced': 'no lenses'}}
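Reading this nested-dict representation by hand is straightforward: each non-leaf level has the form {feature: {feature_value: sub_tree_or_label}}. As a small sketch (clf is the classifier trained earlier), the following walks one path of the tree shown above:

tree = clf.tree

# Follow one path by hand: normal tear rate -> no astigmatism -> young age.
label = tree['tearRate']['normal']['astigmatic']['no']['age']['young']
print(label)  # 'soft'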
Visualizing the Decision Tree
A decision tree expressed directly as nested dictionaries is hard for people to read, so we need a visualization tool to render the tree structure; here I use Graphviz. To that end I implemented a function that converts the dictionary representation of the tree into the contents of a Graphviz Dot file. The basic idea is to recursively collect all of the tree's nodes and the edges connecting them, serialize these nodes and edges into a Dot-format string, write it to a file, and then render the graph.
The tree's nodes and edges are collected recursively; a uuid is attached to each node as an id so that nodes sharing the same label can still be told apart.
import uuid
from collections import namedtuple

def get_nodes_edges(self, tree=None, root_node=None):
    ''' Return all nodes and edges in the tree.
    '''
    Node = namedtuple('Node', ['id', 'label'])
    Edge = namedtuple('Edge', ['start', 'end', 'label'])

    if tree is None:
        tree = self.tree

    # A leaf (class label) contributes no internal nodes or edges here.
    if type(tree) is not dict:
        return [], []

    nodes, edges = [], []

    if root_node is None:
        label = list(tree.keys())[0]
        root_node = Node._make([uuid.uuid4(), label])
        nodes.append(root_node)

    for edge_label, sub_tree in tree[root_node.label].items():
        node_label = list(sub_tree.keys())[0] if type(sub_tree) is dict else sub_tree
        sub_node = Node._make([uuid.uuid4(), node_label])
        nodes.append(sub_node)

        edge = Edge._make([root_node, sub_node, edge_label])
        edges.append(edge)

        sub_nodes, sub_edges = self.get_nodes_edges(sub_tree, root_node=sub_node)
        nodes.extend(sub_nodes)
        edges.extend(sub_edges)

    return nodes, edges
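Assuming this function is a method on the trained classifier clf (as the self parameter suggests), the collected nodes and edges can be inspected directly. A small usage sketch:

nodes, edges = clf.get_nodes_edges()

print(len(nodes), 'nodes and', len(edges), 'edges')
for edge in edges[:3]:
    # Each edge stores the parent node, the child node, and the feature value on that branch.
    print(edge.start.label, '--{}-->'.format(edge.label), edge.end.label)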
Generating the Dot file content
def dotify(self, tree=None):
    ''' Get the content of the Graphviz Dot file for the tree.
    '''
    if tree is None:
        tree = self.tree

    content = 'digraph decision_tree {\n'
    nodes, edges = self.get_nodes_edges(tree)

    # Declare every node with its uuid as the identifier and its feature/class as the label.
    for node in nodes:
        content += '    "{}" [label="{}"];\n'.format(node.id, node.label)

    # Declare every edge, labeled with the feature value of the branch.
    for edge in edges:
        start, label, end = edge.start, edge.label, edge.end
        content += '    "{}" -> "{}" [label="{}"];\n'.format(start.id, end.id, label)

    content += '}'

    return content
The generated Dot file content looks like the following (note that the sample output shown here comes from a different dataset, with word features and spam/ham labels, rather than the contact lens data):
digraph decision_tree {
"959b4c0c-1821-446d-94a1-c619c2decfcd" [label="call"];
"18665160-b058-437f-9b2e-05df2eb55661" [label="to"];
"2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" [label="your"];
"bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" [label="areyouunique"];
"ca091fc7-8a4e-4970-9ec3-485a4628ad29" [label="02073162414"];
"aac20872-1aac-499d-b2b5-caf0ef56eff3" [label="ham"];
"18aa8685-a6e8-4d76-bad5-ccea922bb14d" [label="spam"];
"3f7f30b1-4dbb-4459-9f25-358ad3c6d50b" [label="spam"];
"44d1f972-cd97-4636-b6e6-a389bf560656" [label="spam"];
"7f3c8562-69b5-47a9-8ee4-898bd4b6b506" [label="i"];
"a6f22325-8841-4a81-bc04-4e7485117aa1" [label="spam"];
"c181fe42-fd3c-48db-968a-502f8dd462a4" [label="ldn"];
"51b9477a-0326-4774-8622-24d1d869a283" [label="ham"];
"16f6aecd-c675-4291-867c-6c64d27eb3fc" [label="spam"];
"adb05303-813a-4fe0-bf98-c319eb70be48" [label="spam"];
"959b4c0c-1821-446d-94a1-c619c2decfcd" -> "18665160-b058-437f-9b2e-05df2eb55661" [label="0"];
"18665160-b058-437f-9b2e-05df2eb55661" -> "2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" [label="0"];
"2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" -> "bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" [label="0"];
"bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" -> "ca091fc7-8a4e-4970-9ec3-485a4628ad29" [label="0"];
"ca091fc7-8a4e-4970-9ec3-485a4628ad29" -> "aac20872-1aac-499d-b2b5-caf0ef56eff3" [label="0"];
"ca091fc7-8a4e-4970-9ec3-485a4628ad29" -> "18aa8685-a6e8-4d76-bad5-ccea922bb14d" [label="1"];
"bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" -> "3f7f30b1-4dbb-4459-9f25-358ad3c6d50b" [label="1"];
"2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" -> "44d1f972-cd97-4636-b6e6-a389bf560656" [label="1"];
"18665160-b058-437f-9b2e-05df2eb55661" -> "7f3c8562-69b5-47a9-8ee4-898bd4b6b506" [label="1"];
"7f3c8562-69b5-47a9-8ee4-898bd4b6b506" -> "a6f22325-8841-4a81-bc04-4e7485117aa1" [label="0"];
"7f3c8562-69b5-47a9-8ee4-898bd4b6b506" -> "c181fe42-fd3c-48db-968a-502f8dd462a4" [label="1"];
"c181fe42-fd3c-48db-968a-502f8dd462a4" -> "51b9477a-0326-4774-8622-24d1d869a283" [label="0"];
"c181fe42-fd3c-48db-968a-502f8dd462a4" -> "16f6aecd-c675-4291-867c-6c64d27eb3fc" [label="1"];
"959b4c0c-1821-446d-94a1-c619c2decfcd" -> "adb05303-813a-4fe0-bf98-c319eb70be48" [label="1"];
}
With this, we can use Graphviz to draw the decision tree:
with open('lenses.dot', 'w') as f:
    dot = clf.dotify()
    f.write(dot)
dot -Tgif lenses.dot -o lenses.gif
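If you prefer to stay inside Python, the same rendering step can be driven through the standard-library subprocess module (a sketch that assumes the Graphviz dot executable is on the PATH):

import subprocess

# Invoke the Graphviz command-line tool to render the Dot file to a GIF image.
subprocess.run(['dot', '-Tgif', 'lenses.dot', '-o', 'lenses.gif'], check=True)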
The result looks like this:
Classifying with the Generated Decision Tree
To predict the class of unseen data, we simply follow the tree's nodes recursively until we reach a leaf node. Since the recursion is a plain tail call it could also be optimized into a loop (an iterative version is sketched after the recursive implementation below). The code is as follows:
def classify(self, data_vect, feat_names=None, tree=None):
    ''' Classify a data vector with the built decision tree.
    '''
    if tree is None:
        tree = self.tree

    if feat_names is None:
        feat_names = self.feat_names

    # Recursive base case: a leaf node is simply a class label.
    if type(tree) is not dict:
        return tree

    # Pick the feature tested at this node and descend into the matching branch.
    feature = list(tree.keys())[0]
    value = data_vect[feat_names.index(feature)]
    sub_tree = tree[feature][value]

    return self.classify(data_vect, feat_names, sub_tree)
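As noted above, the recursion here is a simple tail call, so the same traversal can be written as a plain loop. A minimal iterative sketch against the same nested-dict representation (the feature order in feat_names below is an assumption about how the lens data was loaded):

def classify_iter(tree, feat_names, data_vect):
    ''' Walk the nested-dict tree iteratively until a leaf (class label) is reached.
    '''
    node = tree
    while type(node) is dict:
        feature = list(node.keys())[0]
        value = data_vect[feat_names.index(feature)]
        node = node[feature][value]
    return node

# Example against the lens tree shown earlier.
feat_names = ['age', 'prescript', 'astigmatic', 'tearRate']
print(classify_iter(clf.tree, feat_names, ['young', 'myope', 'no', 'normal']))  # 'soft'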
Storing the Decision Tree
Since the decision tree is represented as a dictionary, we can serialize it to disk with the built-in pickle or json modules, and load the tree structure back from disk later; when the dataset is large this saves the time of rebuilding the tree.
import pickle

def dump_tree(self, filename, tree=None):
    ''' Save the decision tree to disk.
    '''
    if tree is None:
        tree = self.tree

    # pickle writes binary data, so open the file in binary mode.
    with open(filename, 'wb') as f:
        pickle.dump(tree, f)

def load_tree(self, filename):
    ''' Load a tree structure from disk.
    '''
    with open(filename, 'rb') as f:
        tree = pickle.load(f)

    self.tree = tree
    return tree
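Because the tree is nothing but nested dicts and strings, the json module mentioned above works just as well and yields a human-readable file. A sketch outside the class (the file name is just an example):

import json

# Save the nested-dict tree as human-readable JSON.
with open('lenses_tree.json', 'w') as f:
    json.dump(clf.tree, f, indent=2)

# Load it back.
with open('lenses_tree.json') as f:
    tree = json.load(f)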
Summary
This article implemented a decision tree step by step, using the ID3 algorithm to choose the best splitting feature and Graphviz to visualize the resulting tree.