22 lines
479 B
Python
22 lines
479 B
Python
from torch.utils.data import Dataset
|
|
import torch
|
|
|
|
|
|
class MyDataset(Dataset):
    """Map-style dataset of token-id sequences, each truncated to ``max_len``.

    Each element of ``input_list`` is indexed, truncated by slicing to at
    most ``max_len`` entries, and converted with
    ``torch.tensor(..., dtype=torch.long)``.  For a flat sequence of ints
    this yields a 1-D tensor.  The elements are presumably token ids
    produced upstream by a tokenizer (TODO confirm against the caller).

    Args:
        input_list: Indexable collection (supports ``len`` and ``[]``) of
            integer sequences that support slicing.
        max_len: Maximum number of entries kept per item.  ``None`` keeps
            the whole sequence.

    Raises:
        ValueError: If ``max_len`` is negative.
    """

    def __init__(self, input_list, max_len=None):
        # A negative bound would make the slice in __getitem__ drop ids from
        # the *end* (``seq[:-k]``) rather than cap the length, so reject it.
        if max_len is not None and max_len < 0:
            raise ValueError(f"max_len must be non-negative or None, got {max_len}")
        self.input_list = input_list
        self.max_len = max_len

    def __getitem__(self, index):
        """Return item ``index``, truncated to ``max_len``, as a ``torch.long`` tensor."""
        # A ``None`` stop in the slice keeps the full sequence.
        input_ids = self.input_list[index][:self.max_len]
        return torch.tensor(input_ids, dtype=torch.long)

    def __len__(self):
        """Return the number of sequences in ``input_list``."""
        return len(self.input_list)