conftest.py 350 B

1234567891011121314151617
  1. import pytest
  2. import torch
  3. @pytest.fixture
  4. def sample_embedding() -> torch.Tensor:
  5. return torch.tensor(
  6. [
  7. [1.0, 0.0, 0.0],
  8. [0.0, 1.0, 0.0],
  9. [0.0, 0.0, 1.0],
  10. [1.0, 0.0, 0.0],
  11. [1.0, 0.0, 0.0],
  12. [1.0, 0.0, 0.0],
  13. ],
  14. dtype=torch.float32,
  15. ) # (6, 3)