文章目录
- 一、简介
- 二、Datasets基本使用
- 2.1 加载在线数据集(load_dataset)
- 2.2 加载数据集某一项任务(load_dataset)
- 2.3 按照数据集划分进行加载(load_dataset)
- 2.4 查看数据集(index and slice)
- 2.4.1 加载数据
- 2.4.2 查看数据集中的第i条
- 2.4.3 仅查看某字段的的前2条
- 2.4.4 看数据里包含哪几个字段(表头)
- 2.4.5 查看标头有具体的特征
- 2.5 数据集划分(train_test_split)
- 2.5.1 train_test_split函数自动划分
- 2.5.2 按照标签均衡划分
- 2.6 数据选取与过滤(select and filter)
- 2.7 数据映射(map)
- 2.7.1 普通映射
- 2.7.2 多batch映射和多线程映射
- 2.7.3 去除原始字段
- 2.8 保存与加载(save_to_disk/load_from_disk)
- 三、Datasets 加载本地数据集
- 3.1 直接加载文件作为数据集
- 3.2 加载文件夹内全部文件作为数据集
- 3.3 通过预先加载的其他格式 转换加载数据集
- 3.4 通过自定义加载脚本加载数据集
- 3.4.1 自己写加载脚本
- 四、DataCollator
- 4.1 负杂的写法,通过自己写collec funtion或者map
- 4.2 简单的方法,直接运用Datacollator
本文为 https://space.bilibili.com/21060026/channel/collectiondetail?sid=1357748的视频学习笔记
项目地址为:https://github.com/zyds/transformers-code
一、简介
- dataset库是一个非常简单易用的数据集加载库,可以方便快捷的从本地或者HuggingFace Hub加载数据集
- 公开数据集地址: https://huggingface.co/datasets
- 文档地址:https://huggingface.co/docs/datasets/index
二、Datasets基本使用
2.1 加载在线数据集(load_dataset)
尝试一个中文标题数据集 https://huggingface.co/datasets/madao33/new-title-chinese
from datasets import load_dataset
datasets = load_dataset("../../data/madao33/new-title-chinese")
datasets
输出是一个DatasetDict,包含了train和validation,每个数据集包含features(指有哪些字段)
2.2 加载数据集某一项任务(load_dataset)
如glue,不是一个任务,而是一个任务的集合。 数据集地址https://huggingface.co/datasets/nyu-mll/glue
让我们下载到本地:
git clone https://huggingface.co/datasets/nyu-mll/glue
# 这里第二个参数就是具体的数据集
glue_dataset = load_dataset("../../data/glue", "cola")
glue_dataset
输出如下:
2.3 按照数据集划分进行加载(load_dataset)
如2.2,已经将数据集分成train、validation和test,我们也可以指定train
通过split参数
# 仅加载train
dataset = load_dataset("../../data/madao33/new-title-chinese", split="train")
# 加载 10-100条
dataset = load_dataset("../../data/madao33/new-title-chinese", split="train[10:100]")
# 加载前50%,注意只能写百分数,不能写小数
dataset = load_dataset("../../data/madao33/new-title-chinese", split="train[:50%]")
# 加载训练集的50%,验证集的前10%
# 加载训练集的50%,验证集的前10%
dataset = load_dataset("../../data/madao33/new-title-chinese", split=["train[:50%]", "validation[50%:]"])
2.4 查看数据集(index and slice)
2.4.1 加载数据
datasets = load_dataset("../../data/madao33/new-title-chinese")
datasets
2.4.2 查看数据集中的第i条
# 查看数据集中的第0条,以dict返回
datasets["train"][0]
# 查看数据集中的前2条,还是会以Dict的形式返回,每个字段里是一个list
datasets["train"][:2]
输出如下:
2.4.3 仅查看某字段的的前2条
# 仅title字段的的前2条
datasets["train"]["title"][:5]
输出如下:
2.4.4 看数据里包含哪几个字段(表头)
# 看数据里包含哪几个字段
datasets["train"].column_names
2.4.5 查看标头有具体的特征
# 查看标有的具体特征
datasets["train"].features
输出如下:值的类型都是string
2.5 数据集划分(train_test_split)
2.5.1 train_test_split函数自动划分
dataset = datasets["train"]
# 指定划分的比例
dataset.train_test_split(test_size=0.1)
2.5.2 按照标签均衡划分
比如分类问题,希望标签是均衡的
dataset.train_test_split(test_size=0.1, stratify_by_column="label") # 分类数据集可以按照比例划分
2.6 数据选取与过滤(select and filter)
取的数据,还作为一个Dataset。 和查看数据不一样,查看数据返回的是一个Dict
# 选取
datasets["train"].select([0, 1])
# 过滤,过滤后的结果需要一个值来接收,操作不是一个inplace,而是拷贝出一个副本
filter_dataset = datasets["train"].filter(lambda example: "中国" in example["title"])
2.7 数据映射(map)
数据处理一般会结合数据映射的方法,我们可以定义一个函数,对每个数据进行处理。
比如这里,对每条数据增加一个prefix,然后进行返回
def add_prefix(example):example["title"] = 'Prefix: ' + example["title"]return example
# 然后将函数传进来
prefix_dataset = datasets.map(add_prefix)
prefix_dataset["train"][:10]["title"]
一般数据处理会有tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("../../models/hfl/rbt3")
def preprocess_function(example, tokenizer=tokenizer):model_inputs = tokenizer(example["content"], max_length=512, truncation=True)labels = tokenizer(example["title"], max_length=32, truncation=True)# label就是title编码的结果model_inputs["labels"] = labels["input_ids"]return model_inputs
2.7.1 普通映射
processed_datasets = datasets.map(preprocess_function)
2.7.2 多batch映射和多线程映射
# 多batch映射
processed_datasets = datasets.map(preprocess_function, batched=True)
# 多线程映射
processed_datasets = datasets.map(preprocess_function, num_proc=4)
2.7.3 去除原始字段
# 去除原始字段
processed_datasets = datasets.map(preprocess_function, batched=True, remove_columns=datasets["train"].column_names)
2.8 保存与加载(save_to_disk/load_from_disk)
# 序列化保存到本地
processed_datasets.save_to_disk("./processed_data")
# 反序列化加载
processed_datasets = load_from_disk("./processed_data")
三、Datasets 加载本地数据集
3.1 直接加载文件作为数据集
如果本地数据集比较规范,比如csv文件
# 没有split,返回DataSetDict. 因为本地只有一个文件
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv")
# 没有split,返回DataSetDict. 则返回一个纯粹的dataset
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
# 用from_csv来加载,结果和load_dataset是一样的
dataset = Dataset.from_csv("./ChnSentiCorp_htl_all.csv")
3.2 加载文件夹内全部文件作为数据集
# 加载文件夹中所有的文件
dataset = load_dataset("csv", data_dir=["./all_data/", "./all_data/ChnSentiCorp_htl_all copy.csv"], split='train')
# 加载 指定的多个文件
dataset = load_dataset("csv", data_files=["./all_data/ChnSentiCorp_htl_all.csv", "./all_data/ChnSentiCorp_htl_all copy.csv"], split='train')
3.3 通过预先加载的其他格式 转换加载数据集
如dict、pandas、list
import pandas as pd
data = pd.read_csv("./ChnSentiCorp_htl_all.csv")
data.head()
dataset = Dataset.from_pandas(data)
注意,list不能直接加载,需要明确一下数据字段
# List格式的数据需要内嵌{},明确数据字段
data = [{"text": "abc"}, {"text": "def"}]
# data = ["abc", "def"]
Dataset.from_list(data)
3.4 通过自定义加载脚本加载数据集
# filed 指限定某个字段的数据
load_dataset("json", data_files="./cmrc2018_trial.json", field="data")
3.4.1 自己写加载脚本
# 也可以自己写加载脚本
dataset = load_dataset("./load_script.py", split="train")
load_script的实现
首先需要定义一个类,类继承 datasets.GeneratorBasedBuilder
class CMRC2018TRIAL(datasets.GeneratorBasedBuilder):
,然后实现三个抽象方法
-
_info(self) -> DatasetInfo: info方法, 定义数据集的信息,这里要对数据的字段进行定义:
-
_split_generators(self, dl_manager: DownloadManager): 返回一个list,List里放着
datasets.SplitGenerator
,涉及两个参数:name
和gen_kwargs
name
: 指定数据集的划分
gen_kwargs
: 指定要读取的文件的路径, 与_generate_examples的入参数一致 -
_generate_examples(self, filepath): 生成具体的样本, 使用yield
需要额外指定key, id从0开始自增就可以. 注意,返回的必须是一个tuple。
load_script.py 完整实现如下:
import json
import datasets
from datasets import DownloadManager, DatasetInfoclass CMRC2018TRIAL(datasets.GeneratorBasedBuilder):def _info(self) -> DatasetInfo:"""info方法, 定义数据集的信息,这里要对数据的字段进行定义:return:"""return datasets.DatasetInfo(description="CMRC2018 trial",features=datasets.Features({"id": datasets.Value("string"),"context": datasets.Value("string"),"question": datasets.Value("string"),# answers 包括text和answer_start两个字段,还需要datasets.features.Sequence"answers": datasets.features.Sequence({"text": datasets.Value("string"),"answer_start": datasets.Value("int32"),})}))def _split_generators(self, dl_manager: DownloadManager):"""返回datasets.SplitGenerator涉及两个参数: name和gen_kwargsname: 指定数据集的划分gen_kwargs: 指定要读取的文件的路径, 与_generate_examples的入参数一致:param dl_manager::return: [ datasets.SplitGenerator ]"""return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": "./cmrc2018_trial.json"})]def _generate_examples(self, filepath):"""生成具体的样本, 使用yield需要额外指定key, id从0开始自增就可以:param filepath::return:"""# Yields (key, example) tuples from the datasetwith open(filepath, encoding="utf-8") as f:data = json.load(f)for example in data["data"]:for paragraph in example["paragraphs"]:context = paragraph["context"].strip()for qa in paragraph["qas"]:question = qa["question"].strip()id_ = qa["id"]answer_starts = [answer["answer_start"] for answer in qa["answers"]]answers = [answer["text"].strip() for answer in qa["answers"]]yield id_, {"context": context,"question": question,"id": id_,"answers": {"answer_start": answer_starts,"text": answers,},}
四、DataCollator
4.1 负杂的写法,通过自己写collec funtion或者map
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split='train')
dataset = dataset.filter(lambda x: x["review"] is not None) # 去除空数据
def process_function(examples):tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)tokenized_examples["labels"] = examples["label"]return tokenized_examples
tokenized_dataset = dataset.map(process_function, batched=True, remove_columns=dataset.column_names)
4.2 简单的方法,直接运用Datacollator
transformers 提供了一些预制的Datacollator
collator = DataCollatorWithPadding(tokenizer=tokenizer)
dl = DataLoader(tokenized_dataset, batch_size=4, collate_fn=collator, shuffle=True)
num = 0
for batch in dl:print(batch["input_ids"].size())num += 1if num > 10:break
输出如下,可以看到一个批次里数据都没有这么长。 这样每个batch都会去最长的字段,可以节省一些计算资源。
但有一点需要注意,如果我们自己的数据里。除了input_ids、token_type_ids、attention_mask、labels外,我们还有其他的自定义的自读含,就不要用官方的 datacollector了(没法对其他字段做padding)。