add validator
This commit is contained in:
74
modules/validator.py
Normal file
74
modules/validator.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# 3 модуль - Валидация и восстановление сущностей
|
||||
|
||||
from NER import NER
|
||||
from paraGenerator import ParaphraseGenerator
|
||||
from pymorphy3 import MorphAnalyzer
|
||||
|
||||
ner = NER()
|
||||
pg = ParaphraseGenerator()
|
||||
morph = MorphAnalyzer()
|
||||
|
||||
def compare_entities(original, generated):
|
||||
"""
|
||||
Сравнивает два списка сущностей. Новые сущности допускаются.
|
||||
|
||||
Возвращает:
|
||||
- True - всё на месте
|
||||
- False - что то потерялось
|
||||
"""
|
||||
def normalize(text):
|
||||
"""
|
||||
Приводит слово к нормальной форме. Пример: Ивану --> Иван
|
||||
|
||||
Возвращает:
|
||||
- Совпадают или нет сущности (bool)
|
||||
- Какие сущности потерялись (set)
|
||||
"""
|
||||
return morph.parse(text)[0].normal_form
|
||||
|
||||
if original.issubset(generated): # если original является подмножеством множества generated
|
||||
return True, set()
|
||||
|
||||
# если нет, проверяем дополнительно, вдруг кейс по типу "Иван" - "Ивану"
|
||||
orig_norm = {normalize(e) for e in original} # нормализуем списки
|
||||
gen_norm = {normalize(e) for e in generated}
|
||||
|
||||
if orig_norm.issubset(gen_norm):
|
||||
return True, set()
|
||||
|
||||
# если по прежнему false, то ищем потерянные сущности
|
||||
lost = set()
|
||||
for o in original:
|
||||
if (normalize(o) not in gen_norm):
|
||||
lost.add(o)
|
||||
return False, lost
|
||||
|
||||
def validator(srcText, srcEntities, paraEntities):
|
||||
"""
|
||||
Использование: validator(<исходный текст>, <сущности исходного текста>, <сущности перефразированного текста>)
|
||||
|
||||
Возвращает:
|
||||
- Исходный текст если сущности сохранены
|
||||
- Изменённый текст, если сущности не сохранены и были восстановлены
|
||||
- None, если сущности не удалось восстановить с трёх раз
|
||||
"""
|
||||
ce = compare_entities(srcEntities, paraEntities)
|
||||
if ce[0]:
|
||||
return srcText # если всё нормально, возвращаем текст в неизменном виде
|
||||
|
||||
regen_prompt = (
|
||||
f'При перефразировании текста "{srcText}" из списка элементов "{', '.join(entity for entity in srcEntities)}"'
|
||||
f'были утеряны или изменены следующие важные элементы: "{', '.join(e for e in ce[1])}". '
|
||||
'Перефразируй исходный текст заново, обратив особое внимание на сохранение этих элементов. Выведи только текст.'
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
newParaphrase = pg.generateByPrompt(regen_prompt)
|
||||
paraEntities = ner.extract_entities(newParaphrase)
|
||||
if (compare_entities(srcEntities, paraEntities)):
|
||||
return newParaphrase
|
||||
return None
|
||||
|
||||
# a = set(['a','b', 'c', 'd'])
|
||||
# b = set(['a','b'])
|
||||
# validator('123', '123', a, b)
|
||||
Reference in New Issue
Block a user