初识Pytorch(一) -- Transforms笔记

初识Pytorch(一) -- Transforms笔记
山河忽晚测试torch是否安装成功
1 | import torch |
蚂蚁蜜蜂分类数据集和下载连接:https://download.pytorch.org/tutorial/hymenoptera_data.zip
1 函数功能查看
1 | dir(torch.cuda) 打开,看见 |
python中__call__
的用法
- 在类中双下划线表示是类的内置函数
- 有
__call__
的情况下,调用对象时,可以直接传参到call中 - 没有
__call__
的情况下,调用对象需要加上“.
”来调用其中的方法
**python中__getitem__的用法**
- 在进行索引取值时自动调用,可以查看原对象中
__getitem__
的return值
不知道返回值时
print()
print(type())
- 打断点 debug
2 torch数据加载(读取)
pytorch官网文档地址:https://docs.pytorch.org
1 | from torch.utils.data import Dataset |
2.1 查看Dataset说明
说明文档查看:
1 | # help(Dataset) |
2.2 初始化类操作
1 | class MyData(Dataset): |
2.3 查看数据
1 | root_dir = './dataset/train' |
3 图片数据读取
将图片数据读取为其他格式:(torch.Tensor
,
numpy.ndarray
, string/blobname
)
3.1 PIL读取图片
PIL读取图片后,得到的数据类型为PIL.Image.Image
1 | from PIL import Image |
3.2 Numpy转换PIL格式
利用numpy.array()
,对PIL图片进行转换
转换后的图片类型为numpy.ndarray
1 | import numpy as np |
3.3 Opencv打开图片
利用Opencv读取图片,获取numpy型图片数据
安装Opencv:pip install opencv-python
1 | import cv2 |
4 TensorBoard 使用
TensorBoard的安装:pip install tensorboard
TensorBoard用来显示模型训练到xx步时,模型的output是什么样,如损失函数、训练结果
启动命令:tensorboard --logdir=logs --port=6007
logdir相对路径logs,port手动设置为6007
4.1 .add_scalar() 用法
1 | from torch.utils.tensorboard import SummaryWriter |
在终端输入:tensorboard --logdir=logs --port=6007
打开网页查看训练数据
4.2 .add_image() 用法
从PIL到numpy,需要在add_image()
中指定shape中每一个数字/维表示的含义
1 | from torch.utils.tensorboard import SummaryWriter |
5 Transforms 使用
transform基本上是对图片进行变化,使用方式是把一些特定格式的图片,通过Transform工具输出为我们想要的结果。
具体使用方式直接查看transforms.py文件中相关方法的使用说明。
- 使用某方法时,首先关注它的输入和输出类型
- 多看transforms.py文件内容
- 关注方法需要什么参数
5.1 Transform 的结构和用法
使用pycharm左侧工具栏的结构功能查看transforms.py的结构
可以把transforms.py看成一个工具箱,里面有各种工具如totensor(把一些数据类型转化为tensor类型)、resize、、、
1 | from torchvision import transforms |
首先要搞懂tensor数据类型,通过Transforms.ToTensor去解决两个问题
- transforms该如何使用(python)
- 为什么需要Tensor数据类型
- a.tensor数据类型是包装了神经网络理论基础的参数
1 | img_path = 'dataset/train/ants/6240338_93729615ec.jpg' |
tensor_img是一个tensor的数据类型,它的参数:
_backward_hooks
:神经网络中的反向传播,根据结果对参数进行调整_grad
:梯度_grad_fn
:梯度的方法data
:图片的具体数据
5.2 使用tensorboard记录
1 | from torch.utils.tensorboard import SummaryWriter |
5.3 常见的Transforms
需关注transforms.py中各函数的
- 输入——PIL ——
Image.open()
- 输出——tensor ——
ToTensor()
- 作用——narrays ——
cv.imread()
5.3.1 Compose类
把不同的transform结合在一起,比如有张图片要处理,经过compose类时,首先进行一个中心的裁剪,再转为tensor数据类型…
Compose()
中的参数是一个列表,列表内容是transforms类型
1 | Example: |
5.3.2 ToTensor类
输入必须时PIL Image 或 ndarray 三通道图片
输出为Tensor数据类型
1 | from PIL import Image |
5.3.3 Normalize类
归一化一个 tensor 类型的 image,根据它的均值和标准差
数据 = (输入 - 均值) / 标准差
1 | # Normalize |
5.3.4 Resize类
给定尺寸进行缩放,如果只给定了一个数,Resize就会根据最小的边去等比缩放
1 | # Resize |
5.3.5 Compose - Resize
应用Compose 将Resize步骤整合,并使用tensorboard记录日志
1 | # Compose - Resize |
5.3.6 RandomCrop类
随机裁剪,输入一个序列如(h, w)
或一个整数值size
,会按输入裁剪为一个(h,w)
或一个正方形(size, size)
1 | # RandomCrop |
5.4 结合 Datasets 使用
pytorch官方提供的数据集:https://docs.pytorch.org/vision/stable/datasets.html
如CIFAR-10 dataset,需要设置参数:
1 | torchvision.datasets.CIFAR10( |
Parameters:
- root (str or pathlib.Path) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
- train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g, transforms.RandomCrop
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
示例代码
1 | import torchvision |
6 Dataloader
dataloader是一个数据加载器,把数据加载到神经网络中,通过参数设置如何去dataset中取数据。
官网说明:https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
1 | torch.utils.data.DataLoader( |
- dataset (Dataset) :自定义的dataset,告诉我们数据集在什么位置,以及具体数据的索引等等。
- batch_size (int, optional) :每次读取数据的量
- shuffle (bool, optional) :是否打乱数据顺序,为True则进行打乱操作,默认为False
- num_workers (int, optional):加载数据时使用单个进程或多个进程,多个进程更快,默认情况是0。不为0时在Windows环境中可能会报错
- drop_last (bool, optional) :当数据量总数除不尽batch_size时,为True则最后的数据舍去,默认False
1 | import torchvision |
查看dataloader的数据
1 | for data in test_loader: |
使用tensorboard记录日志
1 | from torch.utils.tensorboard import SummaryWriter |