PyTorch 1:基础、加载数据


访问作者github: https://github.com/NefelibataBIGR/PyTorch_Notes ,获取笔记代码

  • 学习内容为小土堆b站pytorch课程

一、课程基础

具体代码见Python>pytorch文件中的==basic_lesson.ipynb==文件

  • 两个有用的python函数:dir( )、help( )
    image.png
import torch
torch.cuda.is_available()

输出 True 即表明 cuda 可用

二、加载数据

具体代码见Python>pytorch文件中的 ==load_data.ipynb== 文件

  • Dataset:数据集(含数据+label)
  • Dataloader:加载器(设定每次获取数据的batch_size、shuffle等)

1. Dataset

  • from torch.utils.data import Dataset
  • import os
  • 自己设置一个子类继承Dataset类,要能读取数据
class MyData(Dataset):
    '''
    创建自己的数据集读取类
    '''
    def __init__(self, root_dir, label_dir): # 定义类里面的全局变量,观察到labels就是文件夹的名称
        self.root_dir = root_dir # 根目录,这里是训练集的文件夹路径
        self.label_dir = label_dir # 数据集所在文件夹名字,这里是标签名
        self.path = os.path.join(self.root_dir, self.label_dir) # 获取图片所在文件,这里文件名就是标签
        self.image_path_list = os.listdir(self.path) # 获取所有图片的名字组成列表

    def __getitem__(self, index): # 获取某一个图片,index为索引
        image_name = self.image_path_list[index]
        image_item_path = os.path.join(self.path, image_name)
        image = Image.open(image_item_path)
        label = self.label_dir
        return image, label

    def __len__(self): # 返回数据集的长度
        return len(self.image_path_list)

具体调试过程见代码文件

2. DataLoader

  • from torch.utils.data import DataLoader
  • 数据加载器:指定了怎样从Dataset中取数据
  • 参考官方文档:torch.utils.data — PyTorch 2.6 documentation
  • 参数:
    • dataset:要取数据的数据集
    • batch_size:每次取的数据数量
    • shuffle:顺序是否随机
    • num_workers:多进程,=0为主进程
    • drop_last:无法整除batch_size时是否舍去余项

文章作者: Nefelibata BIGR
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Nefelibata BIGR !
  目录