diff --git a/hft.py b/hft.py
index 0e14915..48c9019 100644
--- a/hft.py
+++ b/hft.py
@@ -3,19 +3,15 @@
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, default_data_collator
 from datasets import Dataset
-from PIL import Image
 import re
-import pytesseract
 import docx2txt
 import PyPDF2
 import json
 from collections import defaultdict
 from huggingface_hub import login
 
-# Environment configuration
+# Configuration
 os.environ['TORCH_USE_CUDA_DSA'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="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")
+login(token="YOUR_HF_TOKEN")  # Replace with your own token
 
 class SourceMapper:
     def __init__(self):
@@ -38,7 +34,8 @@ def load_file_catalog(catalog_path):
         return json.load(file)
 
 def identify_legal_document(filename, file_catalog):
-    return file_catalog.get(filename, "Opracowanie własne")
+    base_name = os.path.splitext(filename)[0]
+    return file_catalog.get(base_name, "Opracowanie własne")
 
 def extract_text_from_file(file_path):
     _, ext = os.path.splitext(file_path)
@@ -56,8 +53,6 @@
         return text
     elif ext in ['.doc', '.docx']:
         return docx2txt.process(file_path)
-    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-        return pytesseract.image_to_string(Image.open(file_path))
     else:
         return ""
 
@@ -73,10 +68,11 @@
             continue
 
         doc_type = identify_legal_document(file, file_catalog)
+
         if doc_type != "Opracowanie własne":
-            articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
+            articles = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text)
             for i in range(1, len(articles), 2):
-                article_number = articles[i].strip()
+                article_number = re.sub(r'#+\s*', '', articles[i].strip())
                 article_content = articles[i+1].strip() if i+1 < len(articles) else ""
                 source = f"{doc_type}, {article_number}"
                 source_mapper.add_source(source)
@@ -130,7 +126,6 @@
     data = prepare_dataset("files", catalog_path, source_mapper)
     dataset = Dataset.from_list(data)
 
-    # Tokenization
     def tokenize_function(examples):
         tokenized = tokenizer(
             examples["text"],
@@ -154,44 +149,53 @@
     # Training
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=3,
+        num_train_epochs=5,
         per_device_train_batch_size=2,
         gradient_accumulation_steps=4,
-        learning_rate=2e-5,
+        learning_rate=3e-5,
         fp16=torch.cuda.is_available(),
-        logging_steps=1,
+        logging_steps=10,
         save_strategy="steps",
         save_steps=1000,
-        report_to="none"
+        report_to="none",
+        weight_decay=0.01
     )
 
     trainer = CustomTrainer(
         model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
+        data_collator=default_data_collator  # collates the pre-tokenized features into batched tensors
     )
 
     print("Rozpoczęcie treningu...")
     trainer.train()
 
     # Testing
     def generate_answer(question):
-        inputs = tokenizer(question, return_tensors="pt").to(device)
+        inputs = tokenizer(
+            f"[PYTANIE PRAWNE] {question}",
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        ).to(device)
 
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=200,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            repetition_penalty=1.2,
-            no_repeat_ngram_size=2,
-            pad_token_id=tokenizer.eos_token_id
-        )
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=200,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.5,
+                no_repeat_ngram_size=3,
+                pad_token_id=tokenizer.eos_token_id
+            )
-        answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(question, "").strip()
+        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        answer = answer.split("[PYTANIE PRAWNE]")[-1].strip()
 
         sources = set()
-        for match in re.finditer(r'Art\.\s+\d+', answer):
+        for match in re.finditer(r'Art\.\s*\d+', answer):
             article_ref = match.group(0).strip()
             for idx, source in source_mapper.idx_to_source.items():
                 if article_ref in source:
@@ -203,7 +207,7 @@
             "sources": list(sources) if sources else ["Opracowanie własne"]
         }
 
-    # Example tests
+    # Tests
     test_questions = [
        "Jakie są zasady udzielania urlopu wypoczynkowego?",
        "Co mówi art. 154 kodeksu pracy?",
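
For reference, a minimal, self-contained sketch (not part of the patch) of what the new splitting logic in prepare_dataset does: re.split with the capturing group r'(#+\s*Art\.\s*\d+[\.\s]?)' keeps the matched markdown-style article headings in the result, so the list alternates between headings and article bodies, and re.sub(r'#+\s*', '', ...) strips the leading # marks. The sample string below is invented for illustration.

import re

# Invented markdown-style statute fragment, just to show the mechanics.
sample = "## Art. 1. Przepis pierwszy.\n## Art. 2. Przepis drugi."

parts = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', sample)
# The capturing group keeps the headings, so `parts` alternates:
# ['', '## Art. 1.', ' Przepis pierwszy.\n', '## Art. 2.', ' Przepis drugi.']

for i in range(1, len(parts), 2):
    article_number = re.sub(r'#+\s*', '', parts[i].strip())   # 'Art. 1.', 'Art. 2.'
    article_content = parts[i + 1].strip() if i + 1 < len(parts) else ""
    print(article_number, "->", article_content)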
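
On the training side, the data_collator given to Trainer has to turn a list of tokenized feature dicts into a single dict of batched tensors (an identity function would hand the model a plain Python list). A quick sketch of what transformers.default_data_collator produces, assuming tokenize_function pads every example to the same max_length; the feature values below are made up.

from transformers import default_data_collator

# Two fake pre-tokenized examples, already padded to a common length
# (hypothetical values standing in for the output of tokenize_function).
features = [
    {"input_ids": [5, 6, 7, 0], "attention_mask": [1, 1, 1, 0], "labels": [5, 6, 7, -100]},
    {"input_ids": [8, 9, 0, 0], "attention_mask": [1, 1, 0, 0], "labels": [8, 9, -100, -100]},
]

batch = default_data_collator(features)
print({k: v.shape for k, v in batch.items()})
# {'input_ids': torch.Size([2, 4]), 'attention_mask': torch.Size([2, 4]), 'labels': torch.Size([2, 4])}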