基于Bert和通用句子编码的Spark-NLP文本分类( 二 )


我们开始写代码吧!
声明加载必要的包并启动一个Spark会话 。
import sparknlpspark = sparknlp.start() # sparknlp.start(gpu=True) >> 在GPU上训练from sparknlp.base import *from sparknlp.annotator import *from pyspark.ml import Pipelineimport pandas as pdprint("Spark NLP version", sparknlp.version())print("Apache Spark version:", spark.version)>> Spark NLP version 2.4.5>> Apache Spark version: 2.4.4然后我们可以从Github repo下载AGNews数据集(https://github.com/JohnSnowLabs/spark-nlp-workshop/tree/master/tutorials/Certification_Trainings/Public) 。
! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_train.csv! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_test.csvtrainDataset = spark.read.option("header", True).csv("news_category_train.csv")trainDataset.show(10, truncate=50)>> +--------+--------------------------------------------------+|category|description|+--------+--------------------------------------------------+|Business| Short sellers, Wall Street's dwindling band of...||Business| Private investment firm Carlyle Group, which h...||Business| Soaring crude prices plus worries about the ec...||Business| Authorities have halted oil export flows from ...||Business| Tearaway world oil prices, toppling records an...||Business| Stocks ended slightly higher on Friday but sta...||Business| Assets of the nation's retail money market mut...||Business| Retail sales bounced back a bit in July, and n...||Business|" After earning a PH.D. in Sociology, Danny Baz...||Business| Short sellers, Wall Street's dwindlingband o...|+--------+--------------------------------------------------+only showing top 10 rowsAGNews数据集有4个类:World、Sci/Tech、Sports、Business
from pyspark.sql.functions import coltrainDataset.groupBy("category").count().orderBy(col("count").desc()).show()>>+--------+-----+|category|count|+--------+-----+|World|30000||Sci/Tech|30000||Sports|30000||Business|30000|+--------+-----+testDataset = spark.read.option("header", True).csv("news_category_test.csv")testDataset.groupBy("category").count().orderBy(col("count").desc()).show()>>+--------+-----+|category|count|+--------+-----+|Sci/Tech| 1900||Sports| 1900||World| 1900||Business| 1900|+--------+-----+现在,我们可以将这个数据提供给Spark NLP DocumentAssembler,它是任何Spark datagram的Spark NLP的入口点 。
# 实际内容在description列document = DocumentAssembler().setInputCol("description").setOutputCol("document")#我们可以下载预先训练好的嵌入use = UniversalSentenceEncoder.pretrained() .setInputCols(["document"]) .setOutputCol("sentence_embeddings")# classes/labels/categories 在category列classsifierdl = ClassifierDLApproach().setInputCols(["sentence_embeddings"]).setOutputCol("class").setLabelColumn("category").setMaxEpochs(5).setEnableOutputLogs(True)use_clf_pipeline = Pipeline(stages = [document,use,classsifierdl])以上,我们获取数据集,输入,然后从使用中获取句子嵌入,然后在ClassifierDL中进行训练
现在我们开始训练 。我们将使用ClassiferDL中的.setMaxEpochs()训练5个epoch 。在Colab环境下,这大约需要10分钟才能完成 。
use_pipelineModel = use_clf_pipeline.fit(trainDataset)运行此命令时,Spark NLP会将训练日志写入主目录中的annotator_logs文件夹 。下面是得到的日志 。

基于Bert和通用句子编码的Spark-NLP文本分类

文章插图
 
如你所见,我们在不到10分钟的时间内就实现了90%以上的验证精度,而无需进行文本预处理,这通常是任何NLP建模中最耗时、最费力的一步 。
现在让我们在最早的时候得到预测 。我们将使用上面下载的测试集 。
基于Bert和通用句子编码的Spark-NLP文本分类

文章插图
 
下面是通过sklearn库中的classification_report获得测试结果 。
基于Bert和通用句子编码的Spark-NLP文本分类

文章插图
 
我们达到了89.3%的测试集精度!看起来不错!
基于Bert和globe嵌入的Spark-NLP文本预处理分类与任何文本分类问题一样,有很多有用的文本预处理技术,包括词干、词干分析、拼写检查和停用词删除,而且除了拼写检查之外,Python中几乎所有的NLP库都有应用这些技术的工具 。目前,Spark NLP库是唯一一个具备拼写检查功能的可用NLP库 。
让我们在Spark NLP管道中应用这些步骤,然后使用glove嵌入来训练文本分类器 。我们将首先应用几个文本预处理步骤(仅通过保留字母顺序进行标准化,删除停用词字和词干化),然后获取每个标记的单词嵌入(标记的词干),然后平均每个句子中的单词嵌入以获得每行的句子嵌入 。


推荐阅读