
How to save a torchtext Dataset?

I'm working with text and use torchtext.data.Dataset. Creating the dataset takes a considerable amount of time. For just running the program this is still acceptable, but I would like to debug the torch code for the neural network, and when Python is started in debug mode, dataset creation takes roughly 20 minutes (!!). That's just to get a working environment where I can step through the neural network code in the debugger.

I would like to save the Dataset, for example with pickle. This sample code is taken from here, but I removed everything that is not necessary for this example:

import pickle

from torchtext import data
from fastai.nlp import *

PATH = 'data/aclImdb/'

TRN_PATH = 'train/all/'
VAL_PATH = 'test/all/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'

TEXT = data.Field(lower=True, tokenize="spacy")

bs = 64
bptt = 70

FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)

with open("md.pkl", "wb") as file:
    pickle.dump(md, file)

To run the code, you need the aclImdb dataset, which can be downloaded from here. Extract it into a data/ folder next to this code snippet. The code fails on the last line, where pickle is used:

Traceback (most recent call last):
  File "/home/lhk/programming/fastai_sandbox/lesson4-imdb2.py", line 27, in <module>
    pickle.dump(md, file)
TypeError: 'generator' object is not callable

The samples from fastai often use dill instead of pickle. But that doesn't work for me either.

asked Oct 16 '22 10:10 by lhk


2 Answers

I came up with the following functions for myself:

import dill
from pathlib import Path

import torch
from torchtext.data import Dataset

def save_dataset(dataset, path):
    if not isinstance(path, Path):
        path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    torch.save(dataset.examples, path/"examples.pkl", pickle_module=dill)
    torch.save(dataset.fields, path/"fields.pkl", pickle_module=dill)

def load_dataset(path):
    if not isinstance(path, Path):
        path = Path(path)
    examples = torch.load(path/"examples.pkl", pickle_module=dill)
    fields = torch.load(path/"fields.pkl", pickle_module=dill)
    return Dataset(examples, fields)
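Since the goal in the question is to avoid rebuilding the dataset on every debug run, the save and load functions above could be wrapped in a small cache helper. The following is only a sketch: `load_or_build` and its parameters are hypothetical names, not part of torchtext, and it defaults to the standard pickle module so that it runs on its own.

```python
import pickle
from pathlib import Path


def _pickle_save(obj, path):
    # Default saver: write obj to a single file with the standard pickle module.
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def _pickle_load(path):
    # Default loader: read back what _pickle_save wrote.
    with open(path, "rb") as f:
        return pickle.load(f)


def load_or_build(cache_path, build, save=_pickle_save, load=_pickle_load):
    """Return the cached object if cache_path exists; otherwise build and cache it.

    build: zero-argument callable that performs the slow construction.
    save:  callable (obj, path) that writes the object to disk.
    load:  callable (path) that reads it back.
    """
    cache_path = Path(cache_path)
    if cache_path.exists():
        return load(cache_path)
    obj = build()
    # Make sure the parent directory exists before saving.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    save(obj, cache_path)
    return obj
```

With the functions from this answer, a call might look like load_or_build("cache/imdb", build_dataset, save=save_dataset, load=load_dataset), where build_dataset is a hypothetical function that does the slow construction. The helper only checks that the path exists, so an interrupted save can leave a stale cache that has to be deleted by hand.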

Note that the loaded object may differ slightly from the one you saved: for example, if you save a TabularDataset, load_dataset returns an instance of the base class Dataset. This is unlikely to affect the data pipeline, but it may require extra care in tests. If you use a custom tokenizer, it must be serializable as well (e.g. no lambda functions).
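To check a tokenizer before building a large dataset around it, a quick serializability test can help. The sketch below uses the standard-library pickle, which is stricter than dill (dill can serialize many lambdas), so treat it as a conservative check rather than a statement about what dill accepts. `is_picklable` is a hypothetical helper name.

```python
import pickle


def is_picklable(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# A tokenizer that can be looked up by name (here the built-in str.split)
# is stored as a reference and pickles without trouble.
named_tokenizer = str.split

# A lambda has no importable name, so the standard pickle module rejects it.
lambda_tokenizer = lambda text: text.split()

print(is_picklable(named_tokenizer))
print(is_picklable(lambda_tokenizer))
```

A tokenizer that passes this check is a safe choice with either serializer; one that fails it may still work with dill, but that is worth confirming on a small dataset first.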

answered Nov 01 '22 14:11 by Nikita


You can use dill instead of pickle. It works for me. You can save a torchtext Field like this:

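import dill
from torchtext import data
# `tokenizer` is your own tokenization function, defined elsewhere
TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=200, batch_first=True)
with open("model/TEXT.Field", "wb") as f:
    dill.dump(TEXT, f)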

And load a Field like this:

with open("model/TEXT.Field", "rb") as f:
    TEXT = dill.load(f)

Official support for this is under development; you can follow https://github.com/pytorch/text/issues/451 and https://github.com/pytorch/text/issues/73.

answered Nov 01 '22 15:11 by chj