torch中Dataset的构造与解读
Dataset的构造
要自定义自己的数据集,首先需要继承Dataset(torch.utils.data.Dataset)
类.
继承Dataset
类之后,必须重写三个方法:__init__(), __getitem__(), __len__()
class ModelNet40(Dataset):
def __init__(self, xxx):
...
def __getitem__(self, item):
...
def __len()__(self):
...
解读
单看上面的构造结构与三个需要重写的方法可能会一头雾水。我们详细分析其作用:
-
__init__的作用 __init__的作用与所有构造函数都一样,初始化一个类的实例。定义类的实际属性,如点云数据集中的
unseen, guassian_noise
等,是True
还是False
, 取出所有数据存储为成员变量等等。 -
__getitem__的作用 __getitem__的作用是,根据item的值取出数据。 item实际上就是索引值,会由Dataloader自动从0一直递增到__len__中取出的值。
-
__len__的作用 __len__的作用是,相当于返回整体数据data的shape[0], 即给item的递增指定一个范围。
例子
class ModelNet40(Dataset):
def __init__(self, num_points, partition='train', gaussian_noise=False, unseen=False, factor=4):
self.data, self.label = load_data(partition)
self.num_points = num_points
self.partition = partition
self.gaussian_noise = gaussian_noise
self.unseen = unseen
self.label = self.label.squeeze()
self.factor = factor
if self.unseen:
######## simulate testing on first 20 categories while training on last 20 categories
if self.partition == 'test':
self.data = self.data[self.label>=20]
self.label = self.label[self.label>=20]
elif self.partition == 'train':
self.data = self.data[self.label<20]
self.label = self.label[self.label<20]
def __getitem__(self, item):
pointcloud = self.data[item][:self.num_points] # 核心代码,就是用item取出的数据
if self.gaussian_noise:
pointcloud = jitter_pointcloud(pointcloud)
if self.partition != 'train':
np.random.seed(item)
anglex = np.random.uniform() * np.pi / self.factor
angley = np.random.uniform() * np.pi / self.factor
anglez = np.random.uniform() * np.pi / self.factor
cosx = np.cos(anglex)
cosy = np.cos(angley)
cosz = np.cos(anglez)
sinx = np.sin(anglex)
siny = np.sin(angley)
sinz = np.sin(anglez)
Rx = np.array([[1, 0, 0],
[0, cosx, -sinx],
[0, sinx, cosx]])
Ry = np.array([[cosy, 0, siny],
[0, 1, 0],
[-siny, 0, cosy]])
Rz = np.array([[cosz, -sinz, 0],
[sinz, cosz, 0],
[0, 0, 1]])
R_ab = Rx.dot(Ry).dot(Rz)
R_ba = R_ab.T
translation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5),
np.random.uniform(-0.5, 0.5)])
translation_ba = -R_ba.dot(translation_ab)
pointcloud1 = pointcloud.T
rotation_ab = Rotation.from_euler('zyx', [anglez, angley, anglex])
pointcloud2 = rotation_ab.apply(pointcloud1.T).T + np.expand_dims(translation_ab, axis=1)
euler_ab = np.asarray([anglez, angley, anglex])
euler_ba = -euler_ab[::-1]
pointcloud1 = np.random.permutation(pointcloud1.T).T
pointcloud2 = np.random.permutation(pointcloud2.T).T
print(item)
print(pointcloud1.shape)
return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \
translation_ab.astype('float32'), R_ba.astype('float32'), translation_ba.astype('float32'), \
euler_ab.astype('float32'), euler_ba.astype('float32')
def __len__(self):
return self.data.shape[0] # 给item一个范围
进一步理解其执行逻辑
if __name__ == '__main__':
dataset1 = ModelNet40(num_points=1024, partition='train', gaussian_noise=True)
dataloader = DataLoader(dataset1, batch_size=64, shuffle=False)
count = 0
for src_pointcloud, tgt_pointcloud, Rotation, translation, _, _, _, _ in dataloader:
print(src_pointcloud.shape)
count += 1
print(count)
首先需要说明的是,在ModelNet40中,getitem中会打印item的当前值。
如果执行这段代码,在shuffle=False
的情况下,其结果为:
item从0一直增加到__len__
返回的那个值-1, 也就是data的第一维(姑且称为batch维)。
在getitem中取出的pointcloud的shape为(3, 1024),只有两个axis.
而最后输出的count,也就是main函数中整个for循环执行的次数,会是__len__() / batch_size
.
比如len是9480,即self.data的shape为(9480, 2048, 3),那么item就会从0一直增加到9479. 在batch_size为64的情况下,for循环一共执行(即count为) 9480/64 = 148.125, 那么最终会执行149次。 也就是说,每次for循环实质上调用了getitem方法64次,最后在第一维上stack,使之shape变为(64, 3, 1024).