diff --git a/hft.py b/hft.py index a947664..acf1678 100644 --- a/hft.py +++ b/hft.py @@ -189,22 +189,6 @@ trainer = CustomTrainer( ) trainer.train() -# Funkcja generująca odpowiedź -def generate_answer(question, max_length=200): - model.eval() - inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512) - - with torch.no_grad(): - outputs = model.generate( - **inputs, - max_length=max_length, - num_return_sequences=1, - return_dict_in_generate=True - ) - - answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) - return answer - # Utwórz katalog do zapisu modelu save_directory = "./trained_model/ably.do/hse" os.makedirs(save_directory, exist_ok=True) @@ -227,11 +211,82 @@ with open(os.path.join(save_directory, "source_mapper.json"), 'w') as f: # 4. Zapisz konfigurację modelu (opcjonalnie, ale zalecane) model.base_model.config.save_pretrained(save_directory) -# Przeprowadź testy -test_questions = [ - "Ile dni urlopu przysługuje pracownikowi, który przepracował w pełnym wymiarze pracy 5 lat?" +# Funkcja generująca odpowiedź +def generate_answer_with_source(question, model, tokenizer, source_mapper, max_length=200): + device = next(model.parameters()).device + inputs = tokenizer( + question, + return_tensors="pt", + truncation=True, + max_length=512 + ).to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_length=max_length, + num_return_sequences=1, + return_dict_in_generate=True, + temperature=0.7, + top_p=0.9, + ) + + answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) + + # Ekstrakcja informacji o źródłach + article_matches = re.finditer(r'Art\.\s+\d+', answer) + sources = set() + + for match in article_matches: + article_ref = match.group(0).strip() + for idx, source in source_mapper.idx_to_source.items(): + if article_ref in source: + sources.add(source) + break + + return { + "question": question, + "answer": answer, + "sources": list(sources) if sources else ["Opracowanie własne"], + "num_tokens": len(outputs.sequences[0]) + } + + + +# Przykładowe testy +test_cases = [ + "Jaki jest wymiar urlopu wypoczynkowego?", + "Jakie są zasady bezpieczeństwa na budowie?", + "Wyjaśnij procedurę zwolnienia grupowego", + "Co reguluje ustawa o ochronie danych osobowych?", + "Jakie dokumenty są potrzebne do zawarcia umowy o pracę?" ] -for q in test_questions: - print(f"Pytanie: {q}") - print(f"Odpowiedź: {generate_answer(q)}\n{'='*50}") \ No newline at end of file +print("\n\n🔴 🔴 🔴 ROZPOCZĘCIE TESTOWANIA MODELU 🔴 🔴 🔴") +for case in test_cases: + result = generate_answer_with_source(case, model, tokenizer, source_mapper) + print(f"\n🔷 Pytanie: {result['question']}") + print(f"🔷 Odpowiedź ({result['num_tokens']} tokenów):") + print(result['answer']) + print(f"🔷 Źródła: {', '.join(result['sources'])}") + print("-"*80) + +# Funkcja generująca odpowiedź +def generate_answer(question, max_length=200): + model.eval() + inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512).to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_length=max_length, + num_return_sequences=1, + return_dict_in_generate=True + ) + + answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) + return answer + +# Utwórz katalog do zapisu modelu +save_directory = "./trained_model/ably.do/hse" +os.makedirs(save_directory, exist_ok=True) \ No newline at end of file