糖尿病康复,内容丰富有趣,生活中的好帮手!
糖尿病康复 > 【组队学习】06.PyTorch的生态

【组队学习】06.PyTorch的生态

时间:2024-08-30 23:50:58

相关推荐

【组队学习】06.PyTorch的生态

PyTorch的主要组成模块

1.PyTorch 中的数据读取

import osimport numpy as npimport torchimport torch.nn as nnfrom torch.utils.data import Dataset, DataLoaderimport torch.optim as optimizer

PyTorch 中的数据读取训练开始的第一步,首先就是数据读取。PyTorch 为我们提供了一种十分方便的数据读取机制,即使用 Dataset 类与 DataLoader 类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。下面我们分别来看下 Dataset 类与 DataLoader 类。Dataset 类PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset 类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:__init__():构造函数,可自定义数据读取方法以及进行数据预处理;__len__():返回数据集大小;__getitem__():索引数据集中的某一个数据。

class MyDataset(Dataset):# 构造函数def __init__(self, data_tensor, target_tensor):self.data_tensor = data_tensorself.target_tensor = target_tensor# 返回数据集大小def __len__(self):return self.data_tensor.size(0)# 返回索引的数据与标签def __getitem__(self, index):return self.data_tensor[index], self.target_tensor[index]

DataLoader 类在实际项目中,如果数据量很大,考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而 DataLoader 就是基于这些需要被设计出来的。DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。

# 生成数据data_tensor = torch.randn(10, 3)target_tensor = torch.randint(2, (10,)) # 标签是0或1# 将数据封装成Datasetmy_dataset = MyDataset(data_tensor, target_tensor)# 查看数据集大小print('Dataset size:', len(my_dataset))

Dataset size: 10

tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数batch_size=2, # 输出的batch大小shuffle=True, # 数据是否打乱num_workers=0)# 进程数, 0表示只有主进程# 以循环形式输出for data, target in tensor_dataloader: print(data, target)

tensor([[-0.2421, 0.3404, 0.3580],[ 0.2852, -0.3163, 0.6388]]) tensor([1, 1])tensor([[ 1.0199, -1.2961, -1.2283],[ 0.2080, 0.7002, 0.0106]]) tensor([0, 1])tensor([[-0.4533, 0.3698, 1.1645],[ 1.1769, 0.9668, 0.7595]]) tensor([1, 0])tensor([[-0.2929, 1.4099, -1.1920],[-0.6805, -1.3196, -0.0264]]) tensor([0, 0])tensor([[ 0.0883, 0.7668, -1.2342],[-1.1929, -0.8921, 0.6323]]) tensor([1, 1])

# 输出一个batchprint('One batch tensor data: ', iter(tensor_dataloader).next())

One batch tensor data: [tensor([[ 0.0883, 0.7668, -1.2342],[ 0.2852, -0.3163, 0.6388]]), tensor([1, 1])]

结合代码,我们梳理一下 DataLoader 中的几个参数,它们分别表示:dataset:Dataset 类型,输入的数据集,必须参数;batch_size:int 类型,每个 batch 有多少个样本;shuffle:bool 类型,在每个 epoch 开始的时候,是否对数据进行重新打乱;num_workers:int 类型,加载数据的进程数,0 意味着所有的数据都会被加载进主进程,默认为 0。

2.Torchvision

PyTroch 官方为我们提供了一些常用的图片数据集,如果你需要读取这些数据集,那么无需自己实现,只需要利用 Torchvision 就可以搞定。Torchvision 是一个和 PyTorch 配合使用的 Python 包。它不只提供了一些常用数据集,还提供了几个已经搭建好的经典网络模型,以及集成了一些图像数据处理方面的工具,主要供数据预处理阶段使用。简单地说,Torchvision 库就是常用数据集 + 常见网络模型 + 常用图像处理方法。

# 以MNIST为例import torchvisionmnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=None,target_transform=None,download=True)

torchvision.datasets.MNIST是一个类,对它进行实例化,即可返回一个 MNIST 数据集对象。构造函数包括包含 5 个参数:root:是一个字符串,用于指定你想要保存 MNIST 数据集的位置。如果 download 是 Flase,则会从目标位置读取数据集;download:是布尔类型,表示是否下载数据集。如果为 True,则会自动从网上下载这个数据集,存储到 root 指定的位置。如果指定位置已经存在数据集文件,则不会重复下载;train:是布尔类型,表示是否加载训练集数据。如果为 True,则只加载训练数据。如果为 False,则只加载测试数据集。这里需要注意,并不是所有的数据集都做了训练集和测试集的划分,这个参数并不一定是有效参数,具体需要参考官方接口说明文档;transform:用于对图像进行预处理操作,例如数据增强、归一化、旋转或缩放等。这些操作我们会在下节课展开讲解;target_transform:用于对图像标签进行预处理操作。

mnist_dataset_list = list(mnist_dataset)print(mnist_dataset_list)

IOPub data rate exceeded.The notebook server will temporarily stop sending outputto the client in order to avoid crashing it.To change this limit, set the config variable`--NotebookApp.iopub_data_rate_limit`.Current values:NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)NotebookApp.rate_limit_window=3.0 (secs)

display(mnist_dataset_list[0][0])print("Image label is:", mnist_dataset_list[0][1])

Image label is: 9

3.图像处理

Torchvision 库中的torchvision.transforms包中提供了常用的图像操作,包括对 Tensor 及 PIL Image 对象的操作,例如随机切割、旋转、数据类型转换等等。按照torchvision.transforms 的功能,大致分为以下几类:数据类型转换、对 PIL.Image 和 Tensor 进行变化和变换的组合。下面我们依次来学习这些类别中的操作。

from PIL import Imagefrom torchvision import transforms

img = Image.open("E:\图片\天使.jpg") display(img)print(type(img)) # PIL.Image.Image是PIL.JpegImagePlugin.JpegImageFile的基类

<class 'PIL.JpegImagePlugin.JpegImageFile'>

# PIL.Image转换为Tensorimg1 = transforms.ToTensor()(img)print(type(img1))

<class 'torch.Tensor'>

# Tensor转换为PIL.Imageimg2 = transforms.ToPILImage()(img1) #PIL.Image.Imageprint(type(img2))

<class 'PIL.Image.Image'>

首先用读取图片,查看一下图片的类型为 PIL.JpegImagePlugin.JpegImageFile,这里需要注意,PIL.JpegImagePlugin.JpegImageFile 类是 PIL.Image.Image 类的子类。然后,用transforms.ToTensor() 将 PIL.Image 转换为 Tensor。最后,再将 Tensor 转换回 PIL.Image。

# 定义一个Resize操作resize_img_oper = transforms.Resize((200,200), interpolation=2)# 原图orig_img = Image.open("E:\图片\天使.jpg") display(orig_img)# Resize操作后的图img = resize_img_oper(orig_img)display(img)

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.warnings.warn(

torchvision.transforms提供了多种剪裁方法,例如中心剪裁、随机剪裁、四角和中心剪裁等。

# 定义剪裁操作center_crop_oper = transforms.CenterCrop((60,70))random_crop_oper = transforms.RandomCrop((80,80))five_crop_oper = transforms.FiveCrop((60,70))# 原图orig_img = Image.open("E:\图片\天使.jpg") display(orig_img)# 中心剪裁img1 = center_crop_oper(orig_img)display(img1)# 随机剪裁img2 = random_crop_oper(orig_img)display(img2)# 四角和中心剪裁imgs = five_crop_oper(orig_img)for img in imgs:display(img)

# 定义翻转操作h_flip_oper = transforms.RandomHorizontalFlip(p=1)v_flip_oper = transforms.RandomVerticalFlip(p=1)# 原图orig_img = Image.open("E:\图片\天使.jpg") display(orig_img)# 水平翻转img1 = h_flip_oper(orig_img)display(img1)# 垂直翻转img2 = v_flip_oper(orig_img)display(img2)

# 定义标准化操作norm_oper = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))# 原图orig_img = Image.open("E:\图片\天使.jpg") display(orig_img)# 图像转化为Tensorimg_tensor = transforms.ToTensor()(orig_img)# 标准化tensor_norm = norm_oper(img_tensor)# Tensor转化为图像img_norm = transforms.ToPILImage()(tensor_norm)display(img_norm)

from PIL import Imagefrom torchvision import transforms # 原图orig_img = Image.open("E:\图片\天使.jpg") display(orig_img)# 定义组合操作composed = pose([transforms.Resize((200, 200)),transforms.RandomCrop(80)])# 组合操作后的图img = composed(orig_img)display(img)

from torchvision import transformsfrom torchvision import datasets# 定义一个transformmy_transform = pose([transforms.ToTensor(),transforms.Normalize((0.5), (0.5))])# 读取MNIST数据集 同时做数据变换mnist_dataset = datasets.MNIST(root='./data',train=False,transform=my_transform,target_transform=None,download=True)# 查看变换后的数据类型item = mnist_dataset.__getitem__(0)print(type(item[0]))

<class 'torch.Tensor'>

如果觉得《【组队学习】06.PyTorch的生态》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。