diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index b3c89fc..b60be4c 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -447,6 +447,14 @@ class BatchGroupMeaningDataloader(Dataset): s += "--------------------------------------------------------\n" return s + def detection_collate(batch): + return batch[0] + + def dataloader(self, num_workers=1): + return DataLoader( + self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate + ) + if __name__ == "__main__":