# dataset.py
import random
from functools import partial

from datasets import IterableDataset, interleave_datasets, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size, is_initialized
  6. def encode(examples, tokenizer, max_length=512):
  7. # Random choice a 512 token window for each example
  8. texts = []
  9. for text in examples["text"]:
  10. if len(text) <= max_length:
  11. texts.append(text)
  12. else:
  13. start = random.randint(0, len(text) - max_length)
  14. texts.append(text[start : start + max_length])
  15. data = tokenizer(
  16. texts,
  17. truncation=True,
  18. padding="max_length",
  19. max_length=max_length,
  20. return_tensors="pt",
  21. )
  22. data["labels"] = data["input_ids"].clone()
  23. data["labels"][data["attention_mask"] == 0] = -100
  24. return data
  25. def build_dataset(tokenizer, max_length=512):
  26. en_dataset = load_dataset("uonlp/CulturaX", "en", split="train", streaming=True)
  27. ja_dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
  28. zh_dataset = load_dataset("uonlp/CulturaX", "zh", split="train", streaming=True)
  29. multilingual_dataset: IterableDataset = interleave_datasets(
  30. [en_dataset, ja_dataset, zh_dataset], probabilities=[0.4, 0.3, 0.3], seed=42
  31. )
  32. # DDP
  33. if is_initialized():
  34. multilingual_dataset = split_dataset_by_node(
  35. multilingual_dataset,
  36. rank=get_rank(),
  37. world_size=get_world_size(),
  38. )
  39. multilingual_dataset = multilingual_dataset.shuffle(seed=42, buffer_size=10000)
  40. multilingual_dataset = multilingual_dataset.map(
  41. partial(encode, tokenizer=tokenizer, max_length=max_length),
  42. batched=True,
  43. remove_columns=multilingual_dataset.column_names,
  44. )
  45. return multilingual_dataset
if __name__ == "__main__":
    # Smoke test: materialize the first 16 examples from the stream.
    # BUG(review): build_dataset() has a required `tokenizer` parameter, so
    # this call raises TypeError — pass a real tokenizer before running
    # this module directly.
    dataset = build_dataset()
    print(list(dataset.take(16)))