神经网络基础
学习链接:https://www.cnblogs.com/pinard/category/894694.html
使用 FastText 模型进行文本分类
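# -*- coding: utf-8 -*-
# Author: Shanv
# function: segment the cnews news corpus with jieba, train a fastText text
#           classifier on it and report per-class precision / recall / F1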
import pandas as pd
import numpy as np
import datetime
import codecs
import jieba
import fastText
from prettytable import PrettyTable
#数据读取
def load_data(filePath):
    """Read a tab-separated ``label<TAB>content`` file into a DataFrame.

    Each line is stripped of surrounding whitespace and split on tabs. Lines
    that do not yield exactly two fields, or whose content field is empty,
    are skipped.

    :param filePath: path of the UTF-8 text file to read
    :return: DataFrame with columns ``title`` (the label) and ``content``
    """
    labels = []
    contents = []
    # 'utf-8-sig' drops a leading BOM (which would otherwise stick to the first
    # label) and decodes BOM-less UTF-8 exactly like 'utf-8'. The builtin open()
    # only breaks lines on \n / \r, so Unicode separators such as U+2028 inside
    # a text do not split one record into several lines.
    with open(filePath, 'r', encoding='utf-8-sig') as f:
        for line in f:
            try:
                label, content = line.strip().split('\t')
            except ValueError:
                # Not exactly two tab-separated fields: skip the malformed line.
                continue
            if content:
                labels.append(label)
                contents.append(content)
    return pd.DataFrame({'title': labels, 'content': contents})
def cut_word(content):
    """Segment every text of ``content`` with jieba, discarding 1-character tokens.

    :param content: pandas Series of raw text strings
    :return: Series (same index) whose elements are the lists of kept tokens
    """
    def _segment(text):
        # Tokens of length exactly one are dropped; the rest keep their order.
        return [token for token in jieba.cut(text) if len(token) != 1]

    return content.map(_segment)
def make_fasttext_data(data_cuted, outFile, train=True, label=None):
    """Write segmented texts to a file in fastText's plain-text format.

    One document per line, tokens joined by single spaces. When ``train`` is
    true each line is prefixed with ``__label__<label> ``.

    :param data_cuted: iterable of token lists (e.g. the Series from cut_word)
    :param outFile: path of the file to (over)write, UTF-8 encoded
    :param train: whether to write the ``__label__`` prefix
    :param label: labels aligned by position with ``data_cuted``; required
                  when ``train`` is true
    :raises ValueError: if ``train`` is true and ``label`` is None
    """
    labels = None
    if train:
        if label is None:
            raise ValueError('label is required when train=True')
        # Positional list, so a pandas Series with a non-default index (where
        # label[i] would be a label lookup) is handled correctly too.
        labels = list(label)
    # newline='\n' writes '\n' verbatim on every platform (no CRLF on Windows).
    with open(outFile, 'w', encoding='utf-8', newline='\n') as fw:
        for i, text in enumerate(data_cuted):
            outline = ' '.join(text)
            if train:
                outline = '__label__' + labels[i] + ' ' + outline
            fw.write(outline + '\n')
def class_predict(classifier, testFile):
    """Predict a class label for every line of a fastText-format test file.

    Each line has trailing whitespace removed, and only the text before the
    first tab is passed (as one list) to ``classifier.predict``.

    :param classifier: fastText classifier (anything with a compatible ``predict``)
    :param testFile: path of the UTF-8 test file, one document per line
                     (as written by make_fasttext_data with train=False)
    :return: list of predicted labels with the leading ``__label__`` removed
    """
    prefix = '__label__'
    with open(testFile, 'r', encoding='utf-8') as fr:
        test_cut = [line.rstrip().split('\t')[0] for line in fr]
    pred_label = []
    for item in classifier.predict(test_cut)[0]:
        # NOTE(review): depending on the fastText version, predict() on a list
        # may yield one sequence of (k=1) labels per line instead of a bare
        # string; keep the top label in that case.
        if isinstance(item, (list, tuple)):
            item = item[0] if len(item) else ''
        # Remove the exact prefix. str.strip('__label__') would instead strip
        # any of the characters '_', 'l', 'a', 'b', 'e' from both ends
        # (e.g. '__label__sale' -> 's').
        if item.startswith(prefix):
            item = item[len(prefix):]
        pred_label.append(item)
    return pred_label
def count_PR(pred_label, real_label):
    """Compute per-class and macro-averaged precision, recall and F1.

    Macro averages are taken over the classes that occur in ``real_label``;
    classes that appear only in ``pred_label`` enter no average. A class that
    is never predicted gets precision 0.0, and a class whose precision and
    recall are both 0 gets F1 0.0, so neither case raises.

    :param pred_label: predicted labels (any iterable)
    :param real_label: true labels, same length and order as ``pred_label``
    :return: (P, R, F1, score_dict): macro precision, macro recall, macro F1,
             and a dict mapping each true class to {'p': .., 'r': .., 'f1': ..}.
             All three averages are 0.0 when ``real_label`` is empty.
    :raises ValueError: if the two label sequences differ in length
    """
    pred_label = list(pred_label)
    real_label = list(real_label)
    if len(pred_label) != len(real_label):
        raise ValueError('pred_label and real_label must have the same length '
                         '(%d != %d)' % (len(pred_label), len(real_label)))

    correct_pred_num = {}  # per class: predictions that were correct
    real_num = {}          # per class: occurrences among the true labels
    pred_num = {}          # per class: occurrences among the predictions
    for pred, real in zip(pred_label, real_label):
        real_num[real] = real_num.get(real, 0) + 1
        pred_num[pred] = pred_num.get(pred, 0) + 1
        if pred == real:
            correct_pred_num[real] = correct_pred_num.get(real, 0) + 1

    score_dict = {}
    for key, n_real in real_num.items():
        n_correct = correct_pred_num.get(key, 0)
        n_pred = pred_num.get(key, 0)
        p = n_correct / n_pred if n_pred else 0.0  # precision
        r = n_correct / n_real                     # recall (n_real >= 1)
        f1 = (2 * p * r) / (p + r) if p + r else 0.0
        score_dict[key] = {'p': p, 'r': r, 'f1': f1}

    n_classes = len(score_dict)
    if n_classes == 0:
        return 0.0, 0.0, 0.0, score_dict
    P = sum(s['p'] for s in score_dict.values()) / n_classes
    R = sum(s['r'] for s in score_dict.values()) / n_classes
    F1 = sum(s['f1'] for s in score_dict.values()) / n_classes
    return P, R, F1, score_dict
if __name__ == '__main__':
    # End-to-end run: load cnews data, segment it, write fastText files,
    # train + save + reload a model, predict the test set and report scores.
    startTime = datetime.datetime.now()
    print('start')
    train = load_data('cnews.train.txt')
    test = load_data('cnews.test.txt')

    # Segment train and test contents in one pass, then split them back by
    # position (train rows first, test rows after).
    all_content = pd.concat([train['content'], test['content']], ignore_index=True)
    all_content = cut_word(all_content)

    # Write the fastText input files.
    make_fasttext_data(all_content.iloc[:len(train)], 'new_fasttext_train.txt',
                       label=train['title'])
    make_fasttext_data(all_content.iloc[len(train):], 'new_fasttext_test.txt',
                       train=False)

    # Train the model.
    classifier = fastText.FastText.train_supervised('new_fasttext_train.txt', label='__label__',
                                                    lr=0.01, dim=300, epoch=8,
                                                    minCount=1, wordNgrams=5)
    # Save the model, then reload it from that file.
    classifier.save_model('my_fasttext.model')
    classifier = fastText.FastText.load_model('my_fasttext.model')

    # Predict.
    print('预测')
    pred_label = class_predict(classifier, 'new_fasttext_test.txt')

    # Macro precision / recall / F1 plus a per-class table.
    P, R, F1, score_dict = count_PR(pred_label, test['title'])
    print('模型准确率:', P)
    print('模型召回率:', R)
    print('模型F1值:', F1)
    print('------各个类的准确率,召回率,F1值为:------')
    tb = PrettyTable()
    tb.field_names = ['类别', '准确率', '召回率', 'F1值']
    for key, value in score_dict.items():
        tb.add_row([key, value['p'], value['r'], value['f1']])
    print(tb)

    endTime = datetime.datetime.now()
    # total_seconds() covers the whole duration; timedelta.seconds would drop
    # the days component of long runs.
    totalTime = (endTime - startTime).total_seconds()
    print(startTime, '--------', endTime)
    print('共消耗%d秒' % totalTime)
输出:
感觉分数有点低，特别是家居和财经两类的准确率只有 0.2 和 0.4。
不过调一调模型参数，分数应该还能再提高一些。