Back

torch中Dataset的构造与解读

torch中Dataset的构造与解读

Dataset的构造

要自定义自己的数据集,首先需要继承Dataset(torch.utils.data.Dataset)类.

继承Dataset类之后,必须重写三个方法:__init__(), __getitem__(), __len__()

class ModelNet40(Dataset):
    def __init__(self, xxx):
        ...
    
    def __getitem__(self, item):
        ...

    def __len()__(self):
        ...

解读

单看上面的构造结构与三个需要重写的方法可能会一头雾水。我们详细分析其作用:

  1. __init__的作用 __init__的作用与所有构造函数都一样,初始化一个类的实例。定义类的实际属性,如点云数据集中的unseen, guassian_noise等,是True还是False, 取出所有数据存储为成员变量等等。

  2. __getitem__的作用 __getitem__的作用是,根据item的值取出数据。 item实际上就是索引值,会由Dataloader自动从0一直递增到__len__中取出的值。

  3. __len__的作用 __len__的作用是,相当于返回整体数据data的shape[0], 即给item的递增指定一个范围。

例子

class ModelNet40(Dataset):
    def __init__(self, num_points, partition='train', gaussian_noise=False, unseen=False, factor=4):
        self.data, self.label = load_data(partition)
        self.num_points = num_points
        self.partition = partition
        self.gaussian_noise = gaussian_noise
        self.unseen = unseen
        self.label = self.label.squeeze()
        self.factor = factor
        if self.unseen:
            ######## simulate testing on first 20 categories while training on last 20 categories
            if self.partition == 'test':
                self.data = self.data[self.label>=20]
                self.label = self.label[self.label>=20]
            elif self.partition == 'train':
                self.data = self.data[self.label<20]
                self.label = self.label[self.label<20]

    def __getitem__(self, item):
        pointcloud = self.data[item][:self.num_points]          # 核心代码,就是用item取出的数据
        if self.gaussian_noise:
            pointcloud = jitter_pointcloud(pointcloud)
        if self.partition != 'train':
            np.random.seed(item)
        anglex = np.random.uniform() * np.pi / self.factor
        angley = np.random.uniform() * np.pi / self.factor
        anglez = np.random.uniform() * np.pi / self.factor

        cosx = np.cos(anglex)
        cosy = np.cos(angley)
        cosz = np.cos(anglez)
        sinx = np.sin(anglex)
        siny = np.sin(angley)
        sinz = np.sin(anglez)
        Rx = np.array([[1, 0, 0],
                        [0, cosx, -sinx],
                        [0, sinx, cosx]])
        Ry = np.array([[cosy, 0, siny],
                        [0, 1, 0],
                        [-siny, 0, cosy]])
        Rz = np.array([[cosz, -sinz, 0],
                        [sinz, cosz, 0],
                        [0, 0, 1]])
        R_ab = Rx.dot(Ry).dot(Rz)
        R_ba = R_ab.T
        translation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5),
                                   np.random.uniform(-0.5, 0.5)])
        translation_ba = -R_ba.dot(translation_ab)

        pointcloud1 = pointcloud.T

        rotation_ab = Rotation.from_euler('zyx', [anglez, angley, anglex])
        pointcloud2 = rotation_ab.apply(pointcloud1.T).T + np.expand_dims(translation_ab, axis=1)

        euler_ab = np.asarray([anglez, angley, anglex])
        euler_ba = -euler_ab[::-1]

        pointcloud1 = np.random.permutation(pointcloud1.T).T
        pointcloud2 = np.random.permutation(pointcloud2.T).T
        print(item)
        print(pointcloud1.shape)
        return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \
               translation_ab.astype('float32'), R_ba.astype('float32'), translation_ba.astype('float32'), \
               euler_ab.astype('float32'), euler_ba.astype('float32')

    def __len__(self):
        return self.data.shape[0]       # 给item一个范围

进一步理解其执行逻辑

if __name__ == '__main__':
    dataset1 = ModelNet40(num_points=1024, partition='train', gaussian_noise=True)
    dataloader = DataLoader(dataset1, batch_size=64, shuffle=False)
    count = 0
    for src_pointcloud, tgt_pointcloud, Rotation, translation, _, _, _, _ in dataloader:
        print(src_pointcloud.shape)
        count += 1
    print(count)

首先需要说明的是,在ModelNet40中,getitem中会打印item的当前值。

如果执行这段代码,在shuffle=False的情况下,其结果为:

item从0一直增加到__len__返回的那个值-1, 也就是data的第一维(姑且称为batch维)。

在getitem中取出的pointcloud的shape为(3, 1024),只有两个axis.

而最后输出的count,也就是main函数中整个for循环执行的次数,会是__len__() / batch_size.

比如len是9480,即self.data的shape为(9480, 2048, 3),那么item就会从0一直增加到9479. 在batch_size为64的情况下,for循环一共执行(即count为) 9480/64 = 148.125, 那么最终会执行149次。 也就是说,每次for循环实质上调用了getitem方法64次,最后在第一维上stack,使之shape变为(64, 3, 1024).

Licensed under CC BY-NC-SA 4.0
comments powered by Disqus
Built with Hugo
Theme Stack designed by Jimmy