| 1234567891011121314151617 |
- import pytest
- import torch
@pytest.fixture
def sample_embedding() -> torch.Tensor:
    """Return a fixed 6x3 float32 tensor of one-hot rows.

    Rows 0-2 are the three distinct one-hot vectors of length 3. Rows 3-5
    repeat row 0, so the vector ``[1, 0, 0]`` appears four times in total
    while the other two vectors appear once each.
    """
    return torch.tensor(
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
        ],
        dtype=torch.float32,
    )  # (6, 3)
|