-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Expand file tree
/
Copy pathutils.py
More file actions
137 lines (100 loc) · 4.36 KB
/
utils.py
File metadata and controls
137 lines (100 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import json
import torch
import torch.distributed as dist
from typing import List, Union, Optional, Tuple, Mapping, Dict
def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
    """Serialize ``objects`` as JSON into the file at ``path`` (UTF-8, overwritten).

    Args:
        objects: The data to serialize. Must be a ``list`` when
            ``line_by_line`` is true.
        path: Destination file path; opened in write mode.
        line_by_line: When false, the whole object is written as a single
            document indented by 4 spaces. When true, each list element is
            written as compact JSON on its own line, each followed by a newline.

    Raises:
        AssertionError: If ``line_by_line`` is true and ``objects`` is not a list.
            The check runs before the file is opened, so nothing is written.
    """
    # Compact separators: no space after ',' or ':'.
    compact_separators = (',', ':')

    if line_by_line:
        assert isinstance(objects, list), 'Only list can be saved in line by line format'

    with open(path, 'w', encoding='utf-8') as out_file:
        if line_by_line:
            out_file.writelines(
                json.dumps(record, ensure_ascii=False, separators=compact_separators) + '\n'
                for record in objects
            )
        else:
            json.dump(objects, out_file, ensure_ascii=False, indent=4, separators=compact_separators)
def move_to_cuda(sample):
    """Recursively move every tensor found inside ``sample`` to CUDA.

    Tensors are moved with ``.cuda(non_blocking=True)``. Dicts, lists, tuples
    and other ``Mapping`` instances are walked recursively and rebuilt; any
    other value is returned unchanged.

    Args:
        sample: A sized container (``len(sample)`` must work).

    Returns:
        ``{}`` when ``sample`` is empty (whatever its type); otherwise a
        structure mirroring ``sample`` with tensors replaced by CUDA copies.
    """
    if not len(sample):
        return {}

    def _transfer(item):
        if torch.is_tensor(item):
            return item.cuda(non_blocking=True)
        # dict must be tested before Mapping: dict subclasses are rebuilt as
        # plain dicts, while other Mapping types keep their own class.
        if isinstance(item, dict):
            return {name: _transfer(child) for name, child in item.items()}
        if isinstance(item, list):
            return [_transfer(child) for child in item]
        if isinstance(item, tuple):
            return tuple(_transfer(child) for child in item)
        if isinstance(item, Mapping):
            rebuilt = {name: _transfer(child) for name, child in item.items()}
            return type(item)(rebuilt)
        return item

    return _transfer(sample)
def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Gather ``t`` from every process and concatenate the results along dim 0.

    Args:
        t: The local tensor, or ``None``.

    Returns:
        ``None`` when ``t`` is ``None``; otherwise the concatenation (dim 0) of
        one tensor per process, ordered by position in the gathered list.
    """
    if t is None:
        return None

    local = t.contiguous()
    world_size = dist.get_world_size()

    # One receive buffer per process, shaped like the local tensor.
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)

    # Put the original local tensor back into this rank's slot.
    # NOTE(review): presumably so gradients still flow through the local
    # shard, since the gathered copies are fresh buffers — confirm.
    gathered[dist.get_rank()] = local

    return torch.cat(gathered, dim=0)
@torch.no_grad()
def select_grouped_indices(scores: torch.Tensor,
                           group_size: int,
                           start: int = 0) -> torch.Tensor:
    """Build a ``(batch_size, group_size)`` index matrix of consecutive groups.

    Row ``i`` holds ``start + i * group_size + j`` for ``j`` in
    ``range(group_size)``. ``scores`` supplies only the batch size
    (``scores.shape[0]``) and the target device; its values are not read.

    Args:
        scores: A 2-D tensor whose second dimension is at least
            ``batch_size * group_size``.
        group_size: Number of consecutive indices per row.
        start: Offset added to every index.

    Returns:
        A ``torch.long`` tensor on ``scores.device``.
    """
    assert len(scores.shape) == 2
    batch_size = scores.shape[0]
    assert batch_size * group_size <= scores.shape[1]

    # (batch_size, 1) group offsets broadcast against (group_size,) positions.
    group_offsets = torch.arange(0, batch_size, dtype=torch.long).unsqueeze(-1) * group_size
    within_group = torch.arange(0, group_size, dtype=torch.long)
    indices = within_group + group_offsets
    indices += start

    return indices.to(scores.device)
def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute dot-product score matrices and per-query label indices.

    ``key`` must contain a whole multiple of ``query``'s row count. With
    ``n = key.shape[0] // query.shape[0]``, the label for query ``i`` is
    ``i * n``, i.e. the first key row of that query's group.

    Args:
        query: 2-D tensor, ``batch_size x dim``.
        key: 2-D tensor, ``(batch_size * n) x dim``.
        use_all_pairs: When false, only query-to-key scores are returned.
            When true, three ``batch_size x batch_size`` blocks are appended
            along the last dimension: labelled-key-to-query, query-to-query
            and labelled-key-to-labelled-key, each with its diagonal set to
            ``-inf``.

    Returns:
        ``(scores, labels)``, where ``labels`` is a ``torch.long`` tensor on
        ``query.device``.
    """
    num_keys, num_queries = key.shape[0], query.shape[0]
    assert num_keys % num_queries == 0, '{} % {} > 0'.format(key.shape[0], query.shape[0])

    train_n_passages = num_keys // num_queries
    labels = torch.arange(0, num_queries, dtype=torch.long, device=query.device) * train_n_passages

    # batch_size x (batch_size x n_psg)
    query_to_key = torch.mm(query, key.t())
    if not use_all_pairs:
        return query_to_key, labels

    # batch_size x dim: the key row each label points to
    labelled_keys = key.index_select(0, labels)
    assert query.shape[0] == labelled_keys.shape[0]

    def _off_diagonal_scores(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
        # batch_size x batch_size similarity with the diagonal masked out.
        block = torch.mm(left, right.t())
        block.fill_diagonal_(float('-inf'))
        return block

    score_blocks = [
        query_to_key,
        _off_diagonal_scores(labelled_keys, query),
        _off_diagonal_scores(query, query),
        _off_diagonal_scores(labelled_keys, labelled_keys),
    ]
    return torch.cat(score_blocks, dim=-1), labels
def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    """Return the entries of ``batch_dict`` whose key starts with ``prefix``.

    The prefix is removed from each kept key; values are passed through
    unchanged. Keys that do not start with ``prefix`` are dropped.
    """
    prefix_len = len(prefix)
    sliced = {}
    for name, value in batch_dict.items():
        if not name.startswith(prefix):
            continue
        sliced[name[prefix_len:]] = value
    return sliced
class AverageMeter(object):
    """Accumulates a weighted sum and count, and exposes their running average.

    Attributes are plain instance fields:
      * ``sum``   - weighted sum of all values passed to ``update``
      * ``count`` - total weight (sum of ``n``) seen so far
      * ``avg``   - ``sum / count``; ``0`` after ``reset`` until the first
                    update with a non-zero total weight
    """

    def __init__(self, name: str, round_digits: int = 3):
        """Create a meter.

        Args:
            name: Label used when the meter is formatted with ``str()``.
            round_digits: Number of digits ``avg`` is rounded to in ``str()``.
        """
        self.name = name
        self.round_digits = round_digits
        self.reset()

    def reset(self):
        """Clear the accumulated statistics back to zero."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The observed value (e.g. a per-batch mean).
            n: Weight of the observation (e.g. the batch size).
        """
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, n=0) on a fresh
        # meter), which would otherwise raise ZeroDivisionError; the previous
        # average is kept in that case.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        """Format as ``'<name>: <avg rounded to round_digits>'``."""
        return '{}: {}'.format(self.name, round(self.avg, self.round_digits))
if __name__ == '__main__':
    # Smoke run: 4 random queries with 3 keys each (dim 16); print the shape
    # of the resulting score matrix and the label indices.
    num_queries, keys_per_query, dim = 4, 3, 16
    query = torch.randn(num_queries, dim)
    key = torch.randn(num_queries * keys_per_query, dim)
    scores, labels = full_contrastive_scores_and_labels(query, key)
    print(scores.shape)
    print(labels)