@@ -17,14 +17,21 @@ def score_to_attention(score, symbol=1):
     score_attn = torch.nn.functional.softmax(score_norm, dim=1)
 
     return score_attn, score_norm, score_pred
 
 
+def compare_tensor(tensor1, tensor2):
+    if tensor1.shape != tensor2.shape:
+        print(f"[EmbeddingManager]shape error: {tensor1.shape} vs {tensor2.shape}")
+        return
+    if not torch.allclose(tensor1, tensor2):
+        print("[EmbeddingManager]value error: tensor1 not close to tensor2")
+
+
 class NLPFunction(object):
     """
     NLP Task
     """
 
-    def __init__(self, model):
+    def __init__(self, model, embedding_manager):
         self.model = model
+        self.embedding_manager = embedding_manager
 
     def base_string_similarity(self, text_dict):
         """
@@ -36,6 +43,13 @@ class NLPFunction(object):
             text_dict['text_a'],
             text_dict['text_b']
         )
+        # test embedding manager functions
+        text_emb1 = self.embedding_manager.get_embedding(text_dict['text_a'])
+        text_emb2 = self.embedding_manager.get_embedding(text_dict['text_b'])
+        score_function = self.model.score_functions['cos_sim']
+        score_tensor_new = score_function(text_emb1, text_emb2)
+        compare_tensor(score_tensor, score_tensor_new)
+
         response = {
             "score": score_tensor.squeeze().tolist()
        }
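
For reference, a minimal, self-contained sketch of the verification path this patch adds. The FakeEmbeddingManager stub and the cos_sim helper below are assumptions for illustration only; in the real code the embeddings come from EmbeddingManager.get_embedding and the score function from model.score_functions['cos_sim'].

import torch

def compare_tensor(tensor1, tensor2):
    # Same debug helper as in the patch: report shape or value mismatches.
    if tensor1.shape != tensor2.shape:
        print(f"[EmbeddingManager]shape error: {tensor1.shape} vs {tensor2.shape}")
        return
    if not torch.allclose(tensor1, tensor2):
        print("[EmbeddingManager]value error: tensor1 not close to tensor2")

class FakeEmbeddingManager:
    # Hypothetical stand-in: derives a per-text deterministic embedding
    # by seeding a generator from the text bytes.
    def get_embedding(self, text):
        generator = torch.Generator().manual_seed(sum(text.encode()))
        return torch.randn(1, 768, generator=generator)

def cos_sim(emb_a, emb_b):
    # Assumed shape contract for the 'cos_sim' score function:
    # cosine similarity of two (1, dim) embeddings as a (1, 1) tensor.
    return torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=-1).unsqueeze(0)

manager = FakeEmbeddingManager()
emb_a = manager.get_embedding("text_a goes here")
emb_b = manager.get_embedding("text_b goes here")
score_tensor_new = cos_sim(emb_a, emb_b)

# An identical copy passes silently; a perturbed copy triggers the value warning.
compare_tensor(score_tensor_new, score_tensor_new.clone())
compare_tensor(score_tensor_new, score_tensor_new + 1e-2)

One note on the comparison itself: torch.allclose defaults to rtol=1e-5 and atol=1e-8, which is strict for values produced by two different code paths. If the embedding-manager path is expected to differ from the original scorer only by floating-point noise, looser tolerances (e.g. torch.allclose(a, b, atol=1e-6)) may be the intended check.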