diff --git a/hft.py b/hft.py
index 4e77248..a6f6ca1 100644
--- a/hft.py
+++ b/hft.py
@@ -11,11 +11,12 @@ import pytesseract
 from PIL import Image
 from collections import defaultdict
 from huggingface_hub import login
+from torch.utils.data import DataLoader
 
 # Konfiguracja
 os.environ['TORCH_USE_CUDA_DSA'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="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")
+login(token="TWÓJ_TOKEN_HF")  # Replace with your own Hugging Face token
 
 class SourceMapper:
     def __init__(self):
@@ -97,15 +98,18 @@ def prepare_dataset(directory, catalog_path, source_mapper):
         print(f"Rozpoznany typ dokumentu: {doc_type}")
 
         if doc_type != "Opracowanie własne":
-            # Nowe wyrażenie regularne dla formatu "Art. XX."
-            articles = re.split(r'(Art\. \d+\.?)', text)
-            print(f"Znaleziono {len(articles)} fragmentów")
+            # Improved regular expression for the different "Art. N" formats
+            articles = re.split(r'(?i)(Art[^\S\n]*\.?[^\S\n]*\d+[^\S\n]*\.?)', text)
+            articles = [a.strip() for a in articles if a.strip()]
 
-            for i in range(1, len(articles), 2):
-                article_number = articles[i].strip()
-                article_content = articles[i+1].strip() if i+1 < len(articles) else ""
+            print(f"Znaleziono {len(articles)//2} artykułów")
+            start = 1 if articles and not re.match(r'(?i)Art\W*\d', articles[0]) else 0  # skip any preamble before the first "Art." marker
+            for i in range(start, len(articles)-1, 2):
+                article_number = articles[i]
+                article_content = articles[i+1]
 
-                if not article_content:
+                if len(article_content) < 50:
+                    print(f"Pominięto zbyt krótki artykuł: {article_number}")
                     continue
 
                 source = f"{doc_type}, {article_number}"
@@ -148,13 +152,37 @@ class CustomModel(nn.Module):
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
         self.source_embedding = nn.Embedding(1000, config.hidden_size, padding_idx=-1)
 
+        # Freeze the base model; only its output embeddings (and the separate source_embedding) remain trainable
+        for param in self.base_model.parameters():
+            param.requires_grad = False
+        for param in self.base_model.get_output_embeddings().parameters():
+            param.requires_grad = True
+
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         if source_idx is not None:
             valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
-            source_embeds = self.source_embedding(valid_indices).unsqueeze(1).expand(-1, input_ids.size(1), -1)
+
+            source_embeds = torch.nn.functional.normalize(
+                self.source_embedding(valid_indices),
+                p=2,
+                dim=-1
+            ).unsqueeze(1)
+
             inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
-            return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
-        return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
+
+            return self.base_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                labels=labels,
+                **kwargs
+            )
+
+        return self.base_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+            **kwargs
+        )
 
     def generate(self, *args, **kwargs):
         return self.base_model.generate(*args, **kwargs)
@@ -165,6 +193,71 @@ class CustomTrainer(Trainer):
         source_idx = inputs.pop("source_idx", None)
         outputs = model(**inputs, labels=labels, source_idx=source_idx)
         return (outputs.loss, outputs) if return_outputs else outputs.loss
+
+    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
+        val_questions = {
+            "art1": "Jakie są prawa pracownika według art. 1?",
+            "art2": "Kto jest pracownikiem według art. 2?",
+            "art3": "Jakie są obowiązki pracodawcy według art. 3?"
+        }
+
+        self.model.eval()
+        results = {}
+
+        for key, question in val_questions.items():
+            result = self.generate_answer(question)
+            results[key] = result
+
+        print("\nWyniki walidacji:")
+        for key, val in results.items():
+            print(f"\n{val_questions[key]}")
+            print(f"Odpowiedź: {val['answer'][:200]}...")
+            print(f"Źródła: {val['sources']}")
+
+        return {f"{metric_key_prefix}_loss": 0.0}
+
+    def generate_answer(self, question):
+        tokenizer = self.tokenizer
+        model = self.model
+        device = model.base_model.device
+
+        prompt = f"[PYTANIE PRAWNE] {question} [KONTEKST]"
+
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        ).to(device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=150,
+                temperature=0.3,
+                top_k=50,
+                top_p=0.95,
+                repetition_penalty=1.8,
+                num_beams=3,
+                no_repeat_ngram_size=4,
+                early_stopping=True,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        answer = answer.replace(prompt, "").strip()
+
+        sources = set()
+        for match in re.finditer(r'(?i)art\.?\s*\d+\.?', answer):
+            article_ref = match.group(0).strip().rstrip('.')
+            for source in self.model.source_mapper.idx_to_source.values():
+                if article_ref.lower() in source.lower():
+                    sources.add(source)
+
+        return {
+            "answer": answer,
+            "sources": list(sources) if sources else ["Opracowanie własne"]
+        }
 
 def main():
     # Inicjalizacja
@@ -211,20 +304,24 @@ def main():
     # Model
     config = AutoModelForCausalLM.from_pretrained(model_name).config
     model = CustomModel(model_name, config)
+    model.source_mapper = source_mapper  # attach the source mapping to the model
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
 
     # Trening
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=3,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        learning_rate=2e-5,
+        num_train_epochs=5,
+        per_device_train_batch_size=4,
+        gradient_accumulation_steps=2,
+        learning_rate=1e-5,
+        weight_decay=0.01,
+        warmup_ratio=0.1,
         fp16=torch.cuda.is_available(),
         logging_steps=10,
-        save_strategy="steps",
-        save_steps=500,
+        save_strategy="epoch",
+        evaluation_strategy="steps",
+        eval_steps=500,
         report_to="none",
         remove_unused_columns=False
     )
@@ -233,65 +330,12 @@ def main():
         model=model,
         args=training_args,
         train_dataset=tokenized_dataset,
-        data_collator=custom_collate_fn
+        data_collator=custom_collate_fn,
+        tokenizer=tokenizer
     )
     print("\nRozpoczęcie treningu...")
     trainer.train()
-
-    # Testowanie
-    def generate_answer(question):
-        model.eval()
-        prompt = f"[PYTANIE PRAWNE] {question}"
-
-        inputs = tokenizer(
-            prompt,
-            return_tensors="pt",
-            truncation=True,
-            max_length=512
-        ).to(device)
-
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=200,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True,
-                repetition_penalty=1.5,
-                no_repeat_ngram_size=3,
-                pad_token_id=tokenizer.eos_token_id
-            )
-
-        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        answer = answer.replace(prompt, "").strip()
-
-        sources = set()
-        for match in re.finditer(r'(?i)art\.?\s*\d+\.?', answer):
-            article_ref = match.group(0).strip().rstrip('.')
-            for source in source_mapper.idx_to_source.values():
-                if article_ref.lower() in source.lower():
-                    sources.add(source)
-
-        return {
-            "question": question,
-            "answer": answer,
-            "sources": list(sources) if sources else ["Opracowanie własne"]
-        }
-
-    # Testy
-    test_questions = [
-        "Jakie są prawa pracownika według art. 1?",
-        "Kto jest pracownikiem według art. 2?",
-        "Jakie są obowiązki pracodawcy według art. 3?"
-    ]
-
-    print("\n" + "="*50 + "\nWYNIKI TESTOW\n" + "="*50)
-    for question in test_questions:
-        result = generate_answer(question)
-        print(f"\nPYTANIE: {result['question']}")
-        print(f"ODPOWIEDŹ: {result['answer'][:500]}")
-        print(f"ŹRÓDŁA: {', '.join(result['sources'])}")
-        print("-"*80)
+    trainer.evaluate()
 
 if __name__ == "__main__":
     main()
\ No newline at end of file