Skip to main content

AutoGluon 学习笔记

· 9 min read

3 种主要数据类型和对应的预测任务:

  • Tabular
  • MultiModel
  • Time Series

Tabular (表格数据)

它是什么?

数据以行和列的形式组织,就像一张 Excel 表格和 SQL 数据库表。

典型任务

  • 分类(Classification): 预测一个类别。
    • 二分类: 预测客户是否会流失(是/否)
    • 多分类: 预测一件商品的类别(电子产品/服装/食品)
  • 回归(Regression): 预测一个数值
    • 例如: 预测房屋的售价、预测股票明日价格

AutoGluon 中的模块: TabularPredictor

特点: AutoGluon 会自动处理各种类型的数据(数字、类别、文本),进行特征工程,并集成多个机器学习模型(如 LightGBM, XGBoost, 神经网络, K 近邻等),最终给出一个强大的集成模型。

示例代码

查看文件: ./01-tabular/01-tabular.py

MultiModel (多模态数据)

当你的数据不仅仅包含结构化的表格信息,还包含图像、文本等非结构化数据时,就需要用到多模态。

它是什么?

“模态”指的是信息的类型或形式。多模态数据是指同时包含多种类型的数据。

常见组合

  • 表格数据 + 文本: 商品信息(价格、类别) + 商品描述文案。
  • 表格数据 + 图像: 商品信息 + 商品图片。
  • 纯文本: 可以看作是单模态,但 AutoGluon 也通过这个模块处理。
  • 纯图像: 同上。

典型任务

  • 多模态分类/回归: 利用所有可用信息进行预测。例如: 根据商品的描述文案和价格,预测其销量(回归);根据社交媒体帖子的图片和文字,判断其情感(分类)。
  • 自然语言处理 (NLP): 如文本分类、情感分析。
  • 计算机视觉 (CV): 如图像分类。

AutoGluon 中的模块: MultiModalPredictor

特点: 它的底层依赖于深度学习模型,特别是 Transformer 架构(如用于文本的 BERT,用于图像的 ViT)。它能自动为不同的模态选择合适的预训练模型,并将它们的信息融合起来进行预测。对于纯文本或纯图像任务,它也是一个非常强大的工具。

示例代码

from autogluon.multimodal import MultiModalPredictor
import pandas as pd

# 假设数据有两列: ‘image’ (图片路径) 和 ‘text’ (文本),以及目标列 ‘label’
train_data = pd.DataFrame({
'image': ['path/to/img1.jpg', 'path/to/img2.jpg', ...],
'text': ['这是一段文字A', '这是一段文字B', ...],
'label': [0, 1, ...]
})
# 创建多模态预测器
predictor = MultiModalPredictor(label='label')
predictor.fit(train_data)
# 预测
predictions = predictor.predict(test_data)

实际可应用场景

  • Text Classification
  • Image Classification
  • NER:从文本提取实体。
  • Matching:判断文本图像是否相关
  • Object Detection:在图像定位识别物体。

1. Text Classification (文本分类)

  • 它是什么? 这是最基础的文本任务。目标是给一整段文本分配一个或多个预定义的类别标签。你可以把它想象成给文本“贴标签”。

  • 典型情景:

    • 情感分析: 判断一条商品评论是“正面”、“负面”还是“中性”。
    • 垃圾邮件检测: 判断一封邮件是“垃圾邮件”还是“正常邮件”。
    • 新闻主题分类: 将一篇新闻文章归类到“体育”、“财经”、“科技”等版块。
    • 意图识别: 在聊天机器人中,判断用户的问题是想“查询天气”、“订餐”还是“投诉”。
  • 输入输出:

    • 输入: 一段文本(如句子、段落、文章)。
    • 输出: 一个或多个类别标签。

2. Image Classification (图像分类)

  • 它是什么? 这是最基础的计算机视觉任务。目标是识别一张图像整体属于哪个类别。它回答的是“这张图片是什么?”的问题。

  • 典型情景:

    • 物体识别: 识别图片中是“猫”、“狗”还是“汽车”。
    • 场景识别: 判断一张风景照是“海滩”、“森林”还是“城市”。
    • 医疗影像分析: 判断一张 X 光片是否“有肺炎迹象”。
    • 手写数字识别: 识别信封上的邮政编码数字。
  • 输入输出:

    • 输入: 一张图片。
    • 输出: 一个类别标签(代表整个图片的内容)。

3. NER (命名实体识别)

  • 它是什么? 这是一个更精细的文本信息抽取任务。目标是从非结构化的文本中找出并分类具有特定意义的实体(通常是名词)。它回答的是“文本中提到了哪些具体的人、地方、组织等?”的问题。

  • 典型情景:

    • 从新闻中提取信息: 在一句“苹果公司的 CEO蒂姆·库克昨日访问了中国北京。”中,识别出:
      • 苹果 -> 组织
      • 蒂姆·库克 -> 人物
      • 中国 -> 地点
      • 北京 -> 地点
    • 医疗记录处理: 从病历中提取“药物名称”、“疾病名称”、“症状”等实体。
    • 简历解析: 从简历中自动提取“候选人姓名”、“毕业院校”、“工作经历”等。
  • 输入输出:

    • 输入: 一段文本。
    • 输出: 文本中标记出的实体及其类型。

4. Matching (匹配)

  • 它是什么? 这是一个非常广泛的任务,核心是计算两个对象之间的相似度或相关性。它回答的是“这两个东西有多像/有多相关?”的问题。

  • 典型情景:

    • 搜索引擎: 计算用户搜索词“最新智能手机”与成千上万个网页内容之间的相关性,返回最匹配的结果。
    • 推荐系统: 计算用户 A 和用户 B 的相似度(协同过滤),或者计算电影《阿凡达》与用户喜好的匹配度。
    • 问答系统: 在一个问题库中,找到与用户提问“如何烤蛋糕?”最相似的问题及其答案。
    • 人脸验证: 判断两张人脸照片是否属于同一个人(1:1 匹配)。
    • 文本语义匹配: 判断“如何更换轮胎”和“轮胎拆卸安装步骤”这两个句子是否在表达同一个意思。
  • 输入输出:

    • 输入: 两个对象(文本、图片、用户等)。
    • 输出: 一个相似度分数,或一个二元判断(匹配/不匹配)。

5. Object Detection (目标检测)

  • 它是什么? 这是一个比图像分类更复杂的计算机视觉任务。它不仅要识别出图片中有什么物体,还要定位出它们的具体位置。它回答的是“图片里有什么,它们分别在哪儿?”的问题。

  • 典型情景:

    • 自动驾驶: 检测车辆前方的“行人”、“汽车”、“交通标志”的位置,以进行避障和决策。
    • 安防监控: 在监控视频中检测并框出“闯入者”、“车辆”。
    • 医疗影像分析: 在 CT 扫描片中定位和识别“肿瘤”的位置和大小。
    • 图片内容分析: 在一张街景照片中,同时框出多个“行人”、“交通灯”、“汽车”。
  • 输入输出:

    • 输入: 一张图片。
    • 输出: 一组边界框 以及每个框对应的物体类别置信度
      • 边界框: 一个矩形,用坐标表示 [x, y, 宽度, 高度]
      • 例如: [框1:汽车, 置信度0.95], [框2:行人, 置信度0.87]

总结与对比

任务领域核心问题输入示例输出示例
Text ClassificationNLP“这段文本是什么?”“这部电影太棒了!”正面情感
Image ClassificationCV“这张图片是什么?”一张猫的图片
NERNLP“文本里提到了谁/什么?”“马云在杭州创立了阿里巴巴。”(马云, 人物),(杭州, 地点),(阿里巴巴, 组织)
Matching通用“这两个东西像不像?”用户查询,商品描述相似度:0.82
Object DetectionCV“图里有什么,在哪儿?”一张街景图片[框出的汽车],[框出的行人]

一个生动的场景来区分它们:

假设你开发一个智能相册 App。

  • Image Classification 用来给整个相册分类:这张是“风景”,那张是“美食”。
  • Object Detection 用在某张“风景”照里:识别出图中有“山”、“湖”、“人”、“船”,并用框标出它们的位置。
  • 你对这张照片写了段描述:“我和李明西湖划船。”
  • NER 会从描述中提取出:(李明, 人物)(西湖, 地点)
  • Text Classification 会判断你这段描述的情感是“开心的”。
  • Matching 的功能是:当你搜索“划船”时,App 能通过语义匹配,找到这张照片和你的描述。

Time Series (时间序列数据)

这种数据的特点是,数据点按时间顺序排列,且相邻数据点之间通常存在相关性(即序列依赖)。

它是什么?

数据点是在连续时间或固定时间间隔上收集的。

核心特征:

  • 时间索引 (Time Index): 每个数据点都有一个时间戳(如 '2024-01-01', '2024-01-02')。
  • 序列依赖性 (Temporal Dependence): 今天的值通常与昨天、上周的值有关。
  • 趋势和季节性 (Trend & Seasonality): 长期上升/下降趋势,或周期性规律(如每日、每周、每年)。

典型任务

  • 单变量时间序列预测: 只预测一个变量。例如: 根据历史销售额预测未来销售额。
  • 多变量时间序列预测: 有多个相关的时间序列,它们可以互相提供信息。例如: 根据历史的气温、湿度、风速来预测未来的降水量。
  • 概率预测: 不仅预测未来的值,还给出预测值的不确定性区间。

AutoGluon 中的模块: TimeSeriesPredictor

特点: 专门为处理时间序列的特性而设计。它集成了从经典模型(如 ETS, ARIMA)到现代机器学习模型(如 LightGBM Temporal Fusion Transformer)的各种方法。它强调对时间索引的处理、滞后特征的创建以及处理多个相关序列。

示例代码

from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

# 加载数据,数据必须包含 ‘item_id’ (序列ID) 和 ‘timestamp’ (时间戳) 列
train_data = TimeSeriesDataFrame.from_path("train.csv")
# 创建预测器,指定预测长度 prediction_length=28
predictor = TimeSeriesPredictor(prediction_length=28).fit(train_data)
# 预测
predictions = predictor.predict(train_data)

三类对比

类型核心数据形式典型任务AutoGluon 模块关键技术
Tabular结构化的行和列分类,回归TabularPredictor梯度提升树,特征工程,模型集成
Multimodal文本、图像、表格混合多模态分类/回归,NLP,CVMultiModalPredictor深度学习,Transformer (BERT, ViT)
Time Series带时间戳的序列单/多变量预测,概率预测TimeSeriesPredictor时序模型 (ETS, ARIMA),滞后特征,TFT

如何选择?

  1. 如果你的数据是标准的 Excel 表格,没有图片、长文本列,也没有严格的时间顺序要求 -> 选择 Tabular
  2. 如果你的数据包含大段的商品描述、评论、或者图片路径,并且这些信息对预测很重要 -> 选择 Multimodal
  3. 如果你的数据是按时间顺序记录的(如每日销售额、每小时气温),并且你的目标是预测未来的值 -> 选择 Time Series

Resources