007

    xiaoxiao2022-07-02  110

    神经网络基础

    学习链接:https://www.cnblogs.com/pinard/category/894694.html

    FastText模型进行文本分类

    # -*- coding: utf-8 -*-
    """Text classification of the cnews data files with a fastText supervised model.

    Author: Shanv

    Pipeline: read tab-separated ``label<TAB>content`` files, segment the text
    with jieba, write fastText-formatted files, train a classifier, predict the
    test set, and report macro-averaged precision / recall / F1.
    """
    import codecs
    import datetime
    from collections import Counter

    import numpy as np
    import pandas as pd

    # NOTE: jieba, fastText and prettytable are imported where they are used so
    # that the pure helpers (load_data, make_fasttext_data, count_PR, ...) can be
    # imported and tested without those packages installed.

    # Prefix fastText expects in front of every label in a training file.
    _LABEL_PREFIX = '__label__'


    def load_data(filePath):
        """Read a ``label<TAB>content`` file into a DataFrame.

        Lines that do not split into exactly two tab-separated fields, or whose
        content is empty, are skipped.

        :param filePath: path of the UTF-8 encoded input file
        :return: DataFrame with columns ``title`` (the label) and ``content``
        """
        labels = []
        contents = []
        # The context manager guarantees the handle is closed.
        with codecs.open(filePath, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                try:
                    label, content = line.strip().split('\t')
                except ValueError:
                    # Malformed line (wrong number of fields): skip it, but do
                    # not hide unrelated errors the way a bare ``except`` would.
                    continue
                if content:
                    labels.append(label)
                    contents.append(content)
        return pd.DataFrame({'title': labels, 'content': contents})


    def cut_word(content):
        """Segment every text with jieba, dropping tokens of length 1.

        :param content: pandas Series of raw texts
        :return: Series whose elements are lists of tokens
        """
        import jieba

        return content.map(lambda x: [w for w in jieba.cut(x) if len(w) != 1])


    def make_fasttext_data(data_cuted, outFile, train=True, label=None):
        """Write segmented texts to a file in fastText format.

        With ``train=True`` each line is ``__label__<label> tok tok ...``;
        otherwise each line holds only the space-joined tokens.

        :param data_cuted: iterable of token lists
        :param outFile: path of the output file (overwritten)
        :param train: whether to prepend the label to each line
        :param label: iterable of labels aligned with ``data_cuted``;
            required when ``train`` is true
        :raises ValueError: if ``train`` is true and ``label`` is missing or its
            length differs from ``data_cuted``
        """
        texts = list(data_cuted)
        if train:
            if label is None:
                raise ValueError('label is required when train=True')
            # Pair labels with texts by position; indexing label[i] would depend
            # on the index labels when a pandas Series is passed.
            labels = list(label)
            if len(labels) != len(texts):
                raise ValueError('label and data_cuted must have the same length')
            lines = [
                _LABEL_PREFIX + lab + ' ' + ' '.join(text) + '\n'
                for lab, text in zip(labels, texts)
            ]
        else:
            lines = [' '.join(text) + '\n' for text in texts]

        with codecs.open(outFile, 'w', 'utf-8') as fw:
            fw.writelines(lines)


    def _clean_label(entry):
        """Return one predicted label without the ``__label__`` prefix.

        ``entry`` may be a label string or a sequence of labels for one text;
        in the latter case the first label is used ('' if it is empty).
        """
        if not isinstance(entry, str):
            entry = entry[0] if len(entry) else ''
        # str.strip('__label__') would remove any of the characters
        # '_', 'l', 'a', 'b', 'e' from both ends; remove the exact prefix.
        if entry.startswith(_LABEL_PREFIX):
            return entry[len(_LABEL_PREFIX):]
        return entry


    def class_predict(classifier, testFile):
        """Predict one label per line of a fastText-formatted test file.

        :param classifier: object whose ``predict(list_of_texts)`` returns a
            tuple with the predicted labels as its first element
        :param testFile: path of the file to predict, one text per line
        :return: list of predicted labels with the ``__label__`` prefix removed
        """
        test_cut = []
        with codecs.open(testFile, 'r', 'utf-8') as fr:
            for line in fr:
                test_cut.append(line.rstrip().split('\t')[0])
        pred_label = classifier.predict(test_cut)[0]
        return [_clean_label(entry) for entry in pred_label]


    def count_PR(pred_label, real_label):
        """Compute per-class and macro-averaged precision, recall and F1.

        Classes are taken from ``real_label``. A class that is never predicted
        gets precision 0.0, and F1 is 0.0 when precision + recall is 0, instead
        of raising.

        :param pred_label: predicted labels
        :param real_label: true labels, aligned with ``pred_label``
        :return: ``(P, R, F1, score_dict)`` where P / R / F1 are the averages
            over the classes and ``score_dict`` maps each class to
            ``{'p': ..., 'r': ..., 'f1': ...}``; all zeros and an empty dict
            for empty input
        :raises ValueError: if the two label sequences differ in length
        """
        pred_label = list(pred_label)
        real_label = list(real_label)
        if len(pred_label) != len(real_label):
            raise ValueError('pred_label and real_label must have the same length')

        real_num = Counter(real_label)  # true count per class
        pred_num = Counter(pred_label)  # predicted count per class
        correct_pred_num = Counter(
            real for real, pred in zip(real_label, pred_label) if real == pred
        )

        score_dict = {}
        for key, n_real in real_num.items():
            n_correct = correct_pred_num[key]
            n_pred = pred_num[key]  # Counter yields 0 for unseen classes
            p = float(n_correct) / n_pred if n_pred else 0.0
            r = float(n_correct) / n_real
            f1 = (2 * p * r) / (p + r) if (p + r) else 0.0
            score_dict[key] = {'p': p, 'r': r, 'f1': f1}

        n_classes = len(score_dict)
        if not n_classes:
            return 0.0, 0.0, 0.0, score_dict

        P = sum(score['p'] for score in score_dict.values()) / n_classes
        R = sum(score['r'] for score in score_dict.values()) / n_classes
        F1 = sum(score['f1'] for score in score_dict.values()) / n_classes
        return P, R, F1, score_dict


    def main():
        """Train on cnews.train.txt, evaluate on cnews.test.txt, print scores."""
        # NOTE(review): this is the module name of older fastText Python
        # bindings; newer releases appear to ship as ``fasttext`` -- confirm
        # against the installed version.
        import fastText
        from prettytable import PrettyTable

        startTime = datetime.datetime.now()
        print('start')
        train = load_data('cnews.train.txt')
        test = load_data('cnews.test.txt')

        tags = train['title'].drop_duplicates().tolist()
        tag_dict = dict(zip(tags, range(len(tags))))
        train['label'] = train['title'].map(tag_dict)
        test['label'] = test['title'].map(tag_dict)

        # Segment train and test together, then split again by position.
        all_content = pd.concat([train['content'], test['content']], ignore_index=True)
        all_content = cut_word(all_content)

        # Generate the fastText data files.
        make_fasttext_data(all_content[:len(train)], 'new_fasttext_train.txt', label=train['title'])
        make_fasttext_data(all_content[len(train):], 'new_fasttext_test.txt', train=False)

        # Train the model.
        classifier = fastText.FastText.train_supervised(
            'new_fasttext_train.txt', label='__label__', lr=0.01, dim=300,
            epoch=8, minCount=1, wordNgrams=5)
        # Save, then reload, the model.
        classifier.save_model('my_fasttext.model')
        classifier = fastText.FastText.load_model('my_fasttext.model')

        # Predict.
        print('预测')
        pred_label = class_predict(classifier, 'new_fasttext_test.txt')

        # Precision, recall, F1.
        P, R, F1, score_dict = count_PR(pred_label, test['title'])
        print('模型准确率:', P)
        print('模型召回率:', R)
        print('模型F1值:', F1)
        print('------各个类的准确率,召回率,F1值为:------')
        tb = PrettyTable()
        tb.field_names = ['类别', '准确率', '召回率', 'F1值']
        for key, value in score_dict.items():
            tb.add_row([key, value['p'], value['r'], value['f1']])
        print(tb)

        endTime = datetime.datetime.now()
        # total_seconds() covers the whole interval; .seconds wraps every day.
        totalTime = (endTime - startTime).total_seconds()
        print(startTime, '--------', endTime)
        print('共消耗%d秒' % totalTime)


    if __name__ == '__main__':
        main()

    输出:

    感觉分数有点低,特别是家居和财经这两类的准确率分别只有 0.2 和 0.4。

    不过调整一下模型的参数(例如 lr、epoch、wordNgrams)应该可以再提高一些分数。

    最新回复(0)