fix errors
This commit is contained in:
25
main.py
25
main.py
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from modules.NER import NER
|
from modules.NER import NER
|
||||||
from modules.paraGenerator import ParaphraseGenerator
|
from modules.paraGenerator import ParaphraseGenerator
|
||||||
#from modules.validator import validator
|
from modules.validator import validator
|
||||||
|
|
||||||
ner = NER()
|
ner = NER()
|
||||||
pg = ParaphraseGenerator()
|
pg = ParaphraseGenerator()
|
||||||
@@ -10,22 +10,11 @@ pg = ParaphraseGenerator()
|
|||||||
srcText = 'Добрый день, я, Сидоров Иван Иванович. Прошу перевести сто тысяч рублей Якову Петру Игнатьевичу в Москву.'
|
srcText = 'Добрый день, я, Сидоров Иван Иванович. Прошу перевести сто тысяч рублей Якову Петру Игнатьевичу в Москву.'
|
||||||
|
|
||||||
def main(srcText):
|
def main(srcText):
|
||||||
# поиск сущностей
|
srcEntities = ner.extract_entities(srcText) # поиск сущностей
|
||||||
srcEntities = ner.extract_entities(srcText)
|
paraphrase = pg.generate(srcText, srcEntities) # генерация парафраза
|
||||||
print(srcEntities)
|
paraEntities = ner.extract_entities(paraphrase) # поиск сущностей в парафразе
|
||||||
|
return validator(srcText, paraphrase, srcEntities, paraEntities) # валидация
|
||||||
# генерация парафраза
|
|
||||||
paraphrase = pg.generate(srcText, srcEntities)
|
|
||||||
print(paraphrase)
|
|
||||||
|
|
||||||
# поиск сущностей в парафразе
|
|
||||||
paraEntities = ner.extract_entities(paraphrase)
|
|
||||||
print(paraEntities)
|
|
||||||
|
|
||||||
# Валидация
|
|
||||||
# return validator(srcText, paraphrase, srcEntities, paraEntities)
|
|
||||||
|
|
||||||
result = main(srcText)
|
result = main(srcText)
|
||||||
print(result)
|
print(f'ИСХОДНЫЙ ТЕКСТ: {srcText}')
|
||||||
|
print(f'СГЕНЕРИРОВАННЫЙ ТЕКСТ: {result}')
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,10 @@ class ParaphraseGenerator:
|
|||||||
Использование: ParaphraseGenerator([температура], [максимальное количество токенов])
|
Использование: ParaphraseGenerator([температура], [максимальное количество токенов])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, temperature=0.8, max_tokens=200):
|
def __init__(self, temperature=0.8, max_tokens=512):
|
||||||
self.client = OpenAI(base_url='http://127.0.0.1:8080', api_key='')
|
self.client = OpenAI(base_url='http://127.0.0.1:8080', api_key='')
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
def generateByPrompt(self, prompt):
|
def generateByPrompt(self, prompt):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# 3 модуль - Валидация и восстановление сущностей
|
# 3 модуль - Валидация и восстановление сущностей
|
||||||
|
|
||||||
from NER import NER
|
from modules.NER import NER
|
||||||
from paraGenerator import ParaphraseGenerator
|
from modules.paraGenerator import ParaphraseGenerator
|
||||||
from pymorphy3 import MorphAnalyzer
|
from pymorphy3 import MorphAnalyzer
|
||||||
|
|
||||||
ner = NER()
|
ner = NER()
|
||||||
@@ -45,9 +45,9 @@ def compare_entities(original, generated):
|
|||||||
lost.add(o)
|
lost.add(o)
|
||||||
return False, lost
|
return False, lost
|
||||||
|
|
||||||
def validator(srcText, srcEntities, paraEntities):
|
def validator(srcText, paraphrase, srcEntities, paraEntities):
|
||||||
"""
|
"""
|
||||||
Использование: validator(<исходный текст>, <сущности исходного текста>, <сущности перефразированного текста>)
|
Использование: validator(<исходный текст>, <парафразированный текст> <сущности исходного текста>, <сущности перефразированного текста>)
|
||||||
|
|
||||||
Возвращает:
|
Возвращает:
|
||||||
- Исходный текст если сущности сохранены
|
- Исходный текст если сущности сохранены
|
||||||
@@ -56,10 +56,14 @@ def validator(srcText, srcEntities, paraEntities):
|
|||||||
"""
|
"""
|
||||||
ce = compare_entities(srcEntities, paraEntities)
|
ce = compare_entities(srcEntities, paraEntities)
|
||||||
if ce[0]:
|
if ce[0]:
|
||||||
return srcText # если всё нормально, возвращаем текст в неизменном виде
|
return paraphrase # если всё нормально, возвращаем текст в неизменном виде
|
||||||
|
|
||||||
# даём 3 попытки на восстановление
|
print(f'Произошла потеря сущностей!')
|
||||||
for _ in range(3):
|
print(f'Исходные сущности: {srcEntities}')
|
||||||
|
print(f'Сгенерируемые сущности: {paraEntities}')
|
||||||
|
print(f'Потеря сущностей: {ce[1]}')
|
||||||
|
# даём 5 попыток на восстановление
|
||||||
|
for i in range(5):
|
||||||
regen_prompt = (
|
regen_prompt = (
|
||||||
f'При перефразировании текста "{srcText}" из списка элементов "{', '.join(entity for entity in srcEntities)}"'
|
f'При перефразировании текста "{srcText}" из списка элементов "{', '.join(entity for entity in srcEntities)}"'
|
||||||
f'были утеряны или изменены следующие важные элементы: "{', '.join(e for e in ce[1])}". '
|
f'были утеряны или изменены следующие важные элементы: "{', '.join(e for e in ce[1])}". '
|
||||||
@@ -68,6 +72,10 @@ def validator(srcText, srcEntities, paraEntities):
|
|||||||
newParaphrase = pg.generateByPrompt(regen_prompt)
|
newParaphrase = pg.generateByPrompt(regen_prompt)
|
||||||
paraEntities = ner.extract_entities(newParaphrase)
|
paraEntities = ner.extract_entities(newParaphrase)
|
||||||
ce = compare_entities(srcEntities, paraEntities)
|
ce = compare_entities(srcEntities, paraEntities)
|
||||||
|
print(f'Попытка восстановления: {i+1}')
|
||||||
|
print(f'Исходные сущности: {srcEntities}')
|
||||||
|
print(f'Сгенерируемые сущности: {paraEntities}')
|
||||||
|
print(f'Потеря сущностей: {ce[1]}')
|
||||||
if (ce[0]): # если сравнение дало True, выходим из цикла
|
if (ce[0]): # если сравнение дало True, выходим из цикла
|
||||||
return newParaphrase
|
return newParaphrase
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user