用TextCNN模型解决文本分类问题
创始人
2025-05-29 23:56:13

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

TextCNN

(封面图由文心一格生成)

用TextCNN模型解决文本分类问题

TextCNN模型是一种使用卷积神经网络(CNN)进行文本分类的模型,它可以有效地处理自然语言文本的特征提取和分类任务。在本文中,我们将详细介绍TextCNN模型的原理和实现,并结合一个具体的案例和代码,展示如何使用TextCNN模型来解决文本分类问题。

1. TextCNN模型介绍

TextCNN是一种使用卷积神经网络(CNN)进行文本分类的模型,其基本思路是将文本数据表示成矩阵形式,然后使用卷积层和池化层对矩阵进行处理,提取文本的局部特征,最后将处理后的特征输入到全连接层中进行分类。下面将详细描述TextCNN模型的原理。

(1)文本表示

在TextCNN模型中,首先需要将文本数据表示成矩阵形式。一种常见的方式是将文本表示成词向量的形式,然后将词向量拼接起来组成矩阵。具体地,我们先将文本中的每个单词映射成一个固定长度的词向量,例如使用word2vec或glove等预训练的词向量模型来生成词向量。然后,将所有词向量按照顺序拼接成一个矩阵,如下所示:

X=[x1,1x1,2⋯x1,dx2,1x2,2⋯x2,d⋮⋮⋱⋮xn,1xn,2⋯xn,d]X = \begin{bmatrix} x_{1,1} & x_{1,2} & \cdots & x_{1,d} \\ x_{2,1} & x_{2,2} & \cdots & x_{2,d} \\ \vdots & \vdots & \ddots & \vdots \\ x_{n,1} & x_{n,2} & \cdots & x_{n,d} \end{bmatrix}X=​x1,1​x2,1​⋮xn,1​​x1,2​x2,2​⋮xn,2​​⋯⋯⋱⋯​x1,d​x2,d​⋮xn,d​​

其中,nnn为文本的长度,ddd为词向量的维度,xi,jx_{i,j}xi,j​表示第iii个单词的第jjj个维度的值。这样,每个文本就可以表示成一个矩阵的形式,进一步输入到CNN中进行处理。

(2)卷积层

在TextCNN模型中,卷积层用于提取文本的局部特征。卷积操作可以看作是一个特征检测器,通过对输入矩阵和卷积核进行卷积操作,将输入矩阵中的某些特征提取出来。在TextCNN模型中,我们可以使用多个不同大小的卷积核来提取不同长度的文本特征。

具体地,我们将输入矩阵和一个大小为h×dh\times dh×d的卷积核进行卷积操作,得到一个特征映射CiC_iCi​,其中hhh为卷积核的高度,ddd为词向量的维度。然后,我们将卷积核向右移动一个步长,再对输入矩阵和卷积核进行卷积操作,得到另一个特征映射Ci+1C_{i+1}Ci+1​。如此重复,直到卷积核移动到输入矩阵的右端,最终得到一组特征映射C=[C1,C2,...,Ck]C=[C_1,C_2,...,C_k]C=[C1​,C2​,...,Ck​]。这里,kkk为卷积核的个数,表示使用kkk个不同大小的卷积核进行卷积操作。

卷积操作可以表示为:

Ci=f(w⋅xi+b)C_i = f(\mathbf{w}\cdot \mathbf{x}_i + b)Ci​=f(w⋅xi​+b)

其中,xi\mathbf{x}_ixi​表示输入矩阵的第iii行,w\mathbf{w}w为卷积核参数,bbb为偏置项,fff为激活函数,通常使用ReLU函数。

(3)池化层

在TextCNN模型中,池化层用于对卷积层输出的特征映射进行降维,提取更加重要的特征。一种常见的池化方式是最大池化,即对每个特征映射中的每个通道,取其最大值作为该通道的输出,最终得到一个降维后的特征向量。

具体地,我们将特征映射CCC中的每个特征图分别进行最大池化,得到一个池化后的特征向量p=[p1,p2,...,pk]p=[p_1,p_2,...,p_k]p=[p1​,p2​,...,pk​],其中pip_ipi​为特征映射CiC_iCi​的最大值。这里,kkk为卷积核的个数,与卷积层输出的特征映射个数相同。

(4)全连接层

在TextCNN模型中,全连接层用于将池化层输出的特征向量映射到分类结果的维度。具体地,我们将池化层输出的特征向量ppp输入到一个全连接层中,得到模型的输出结果。

全连接层可以表示为:

y=g(Wp+b)y = g(\mathbf{W}p + \mathbf{b})y=g(Wp+b)

其中,W\mathbf{W}W为全连接层的权重参数,b\mathbf{b}b为偏置项,ggg为激活函数,通常使用softmax函数来将输出结果转化为概率分布。

2. 模型优缺点分析

(1)优点

  • TextCNN模型具有良好的特征提取能力,能够很好地捕捉文本中的上下文信息。

  • 使用卷积神经网络,可以处理不同长度的文本,避免了传统文本分类模型需要对文本进行固定长度截断的问题。

  • 采用了dropout等正则化方法,可以有效避免过拟合问题。

(2) 缺点

  • 对于较长的文本,TextCNN模型的效果可能不如LSTM等循环神经网络模型。

  • TextCNN模型不能很好地处理词序信息,而往往是将整个句子看作一个向量进行处理。

3. 案例与代码

我们以AG数据集为例,使用TextCNN模型进行文本分类。AG数据集是一个新闻分类数据集,共有4个类别:World、Sports、Business、Science/Technology。我们将数据集按照8:1:1的比例分成训练集、验证集和测试集,使用PyTorch框架实现TextCNN模型。

(1)数据集准备

首先,我们下载并解压AG数据集,然后将数据集中的文本和标签分别保存到两个文件中。代码如下:

import os
import torchtext
from torchtext import data, datasets# 下载AG数据集
# 在这个网址下载数据集:https://github.com/mhjabreel/CharCnn_Keras/tree/master/data/ag_news_csv# 定义数据集的Field
TEXT = data.Field(sequential=True, tokenize='spacy', lower=True)
LABEL = data.LabelField(dtype=torch.float)# 加载数据集
train_data, valid_data, test_data = datasets.TabularDataset.splits(path="./ag_news_csv",train="train.csv",validation="valid.csv",test="test.csv",format="csv",fields=[("label", LABEL), ("text", TEXT)],skip_header=True)
# 构建词表
TEXT.build_vocab(train_data, max_size=25000, vectors="glove.6B.100d")
LABEL.build_vocab(train_data)# 定义迭代器
BATCH_SIZE = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits((train_data, valid_data, test_data),batch_size=BATCH_SIZE,device=device,shuffle=True)

在代码中,我们使用了torchtext库来处理数据集,首先定义了数据集的Field,其中将文本数据分词,并使用glove预训练词向量初始化词表。然后,使用TabularDataset类将数据集加载到内存中,并根据8:1:1的比例分成训练集、验证集和测试集。最后,使用BucketIterator定义了三个迭代器,用于模型的训练、验证和测试。

(2) TextCNN模型实现

下面是TextCNN模型的实现代码:

import torch.nn as nn
import torch.nn.functional as Fclass TextCNN(nn.Module):def init(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):super().init()# 词嵌入层self.embedding = nn.Embedding(vocab_size, embedding_dim)# 卷积层self.convs = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(fs, embedding_dim)) for fs in filter_sizes])# 全连接层self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)# dropout层self.dropout = nn.Dropout(dropout)def forward(self, x):# x: [sent len, batch size]# 词嵌入embedded = self.embedding(x)# embedded: [sent len, batch size, emb dim]# 将batch size和channel维度交换embedded = embedded.permute(1, 0, 2)# embedded: [batch size, sent len, emb dim]# 添加channel维度embedded = embedded.unsqueeze(1)# embedded: [batch size, 1, sent len, emb dim]# 卷积池化conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]# conved: [batch size, n_filters, sent len - filter_sizes[n] + 1]# 最大池化pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]# pooled: [batch size, n_filters]# 拼接cat = self.dropout(torch.cat(pooled, dim=1))# cat: [batch size, n_filters * len(filter_sizes)]# 全连接层output = self.fc(cat)# output: [batch size, output dim]return output

在代码中,我们首先定义了TextCNN类,构建了包括词嵌入层、卷积层、全连接层和dropout层等模块。在forward函数中,我们对输入数据进行词嵌入、卷积和池化操作,将处理后的特征输入到全连接层中进行分类。

(3)训练模型

下面是TextCNN模型的训练代码:

import torch.optim as optim
import time# 模型参数
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
N_FILTERS = 100
FILTER_SIZES = [3, 4, 5]
OUTPUT_DIM = len(LABEL.vocab)
DROPOUT = 0.5# 模型实例化
model = TextCNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT)
model.to(device)# 优化器和损失函数
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()# 训练函数
def train(model, iterator, optimizer, criterion):epoch_loss = 0epoch_acc = 0model.train()for batch in iterator:optimizer.zero_grad()predictions = model(batch.text).squeeze(1)loss = criterion(predictions, batch.label.long())acc = binary_accuracy(predictions, batch.label.long())loss.backward()optimizer.step()epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / len(iterator), epoch_acc / len(iterator)# 测试函数
def evaluate(model, iterator, criterion):epoch_loss = 0epoch_acc = 0model.eval()with torch.no_grad():for batch in iterator:predictions = model(batch.text).squeeze(1)loss = criterion(predictions, batch.label.long())acc = binary_accuracy(predictions, batch.label.long())epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / len(iterator), epoch_acc / len(iterator)# 准确率函数
def binary_accuracy(preds, y):rounded_preds = torch.round(torch.sigmoid(preds))correct = (rounded_preds == y).float()acc = correct.sum() / len(correct)return acc# 训练模型
N_EPOCHS = 5
best_valid_loss = float('inf')for epoch in range(N_EPOCHS):start_time = time.time()train_loss, train_acc = train(model, train_iterator, optimizer, criterion)valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)end_time = time.time()epoch_mins, epoch_secs = divmod(end_time - start_time, 60)if valid_loss < best_valid_loss:best_valid_loss = valid_losstorch.save(model.state_dict(), 'textcnn-model.pt')print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

在代码中,我们定义了train函数和evaluate函数来分别训练模型和测试模型,同时定义了binary_accuracy函数来计算模型的准确率。在训练过程中,我们使用Adam优化器和交叉熵损失函数,训练模型5个epoch,并保存在验证集上表现最好的模型。

(4)测试模型

最后,我们使用测试集对训练好的TextCNN模型进行测试,并输出模型的准确率。

# 加载模型
model.load_state_dict(torch.load('textcnn-model.pt'))# 测试模型
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

相关内容

热门资讯

【实验报告】实验一 图像的... 实验目的熟悉Matlab图像运算的基础——矩阵运算;熟悉图像矩阵的显示方法࿰...
MATLAB | 全网最详细网... 一篇超超超长,超超超全面网络图绘制教程,本篇基本能讲清楚所有绘制要点&#...
大模型落地比趋势更重要,NLP... 全球很多人都开始相信,以ChatGPT为代表的大模型,将带来一场NLP领...
Linux学习之端口、网络协议... 端口:设备与外界通讯交流的出口 网络协议:   网络协议是指计算机通信网...
kuernetes 资源对象分... 文章目录1. pod 状态1.1 容器启动错误类型1.2 ImagePullBackOff 错误1....
STM32实战项目-数码管 程序实现功能: 1、上电后,数码管间隔50ms计数; 2、...
TM1638和TM1639差异... TM1638和TM1639差异说明 ✨本文不涉及具体的单片机代码驱动内容,值针对芯...
Qt+MySql开发笔记:Qt... 若该文为原创文章,转载请注明原文出处 本文章博客地址:https://h...
Java内存模型中的happe... 第29讲 | Java内存模型中的happen-before是什么? Java 语言...
《扬帆优配》算力概念股大爆发,... 3月22日,9股封单金额超亿元,工业富联、鸿博股份、鹏鼎控股分别为3.0...
CF1763D Valid B... CF1763D Valid Bitonic Permutations 题目大意 拱形排列࿰...
SQL语法 DDL、DML、D... 文章目录1 SQL通用语法2 SQL分类3 DDL 数据定义语言3.1 数据库操作3.2 表操作3....
文心一言 VS ChatGPT... 3月16号,百度正式发布了『文心一言』,这是国内公司第一次发布类Chat...
CentOS8提高篇5:磁盘分...        首先需要在虚拟机中模拟添加一块新的硬盘设备,然后进行分区、格式化、挂载等...
Linux防火墙——SNAT、... 目录 NAT 一、SNAT策略及作用 1、概述 SNAT应用环境 SNAT原理 SNAT转换前提条...
部署+使用集群的算力跑CPU密... 我先在开头做一个总结,表达我最终要做的事情和最终环境是如何的,然后我会一...
Uploadifive 批量文... Uploadifive 批量文件上传_uploadifive 多个上传按钮_asing1elife的...
C++入门语法基础 文章目录:1. 什么是C++2. 命名空间2.1 域的概念2.2 命名...
2023年全国DAMA-CDG... DAMA认证为数据管理专业人士提供职业目标晋升规划,彰显了职业发展里程碑及发展阶梯定义...
php实现助记词转TRX,ET... TRX助记词转地址网上都是Java,js或其他语言开发的示例,一个简单的...
【分割数据集操作集锦】毕设记录 1. 按要求将CSV文件转成json文件 有时候一些网络模型的源码会有data.json这样的文件里...
Postman接口测试之断言 如果你看文字部分还是不太理解的话,可以看看这个视频,详细介绍postma...
前端学习第三阶段-第4章 jQ... 4-1 jQuery介绍及常用API导读 01-jQuery入门导读 02-JavaScri...
4、linux初级——Linu... 目录 一、用CRT连接开发板 1、安装CRT调试工具 2、连接开发板 3、开机后ctrl+c...
Urban Radiance ... Urban Radiance Fields:城市辐射场 摘要:这项工作的目标是根据扫描...
天干地支(Java) 题目描述 古代中国使用天干地支来记录当前的年份。 天干一共有十个,分别为:...
SpringBoot雪花ID长... Long类型精度丢失 最近项目中使用雪花ID作为主键,雪花ID是19位Long类型数...
对JSP文件的理解 JSP是java程序。(JSP本质还是一个Servlet) JSP是&#...
【03173】2021年4月高... 一、单向填空题1、大量应用软件开发工具,开始于A、20世纪70年代B、20世纪 80年...
LeetCode5.最长回文子... 目录题目链接题目分析解题思路暴力中心向两边拓展搜索 题目链接 链接 题目分析 简单来说࿰...