l.gabrysiak 2025-02-25 21:17:17 +01:00
parent 7c7391b608
commit fcb4d25d8f
1 changed file with 34 additions and 30 deletions

hft.py

@@ -3,19 +3,15 @@ import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import Dataset
from PIL import Image
import re
import pytesseract
import docx2txt
import PyPDF2
import json
from collections import defaultdict
from huggingface_hub import login
# Environment configuration
# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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")
login(token="TWÓJ_TOKEN_HF") # Zastąp swoim tokenem
class SourceMapper:
def __init__(self):
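
Since the commit swaps the hardcoded token above for a placeholder, here is a minimal sketch of reading it from the environment instead; HF_TOKEN is an assumed variable name, not something the script defines.

import os
from huggingface_hub import login

# Read the token from the environment (HF_TOKEN is an assumed name)
# instead of hardcoding it in the source.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
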
@@ -38,7 +34,8 @@ def load_file_catalog(catalog_path):
return json.load(file)
def identify_legal_document(filename, file_catalog):
return file_catalog.get(filename, "Opracowanie własne")
base_name = os.path.splitext(filename)[0]
return file_catalog.get(base_name, "Opracowanie własne")
def extract_text_from_file(file_path):
_, ext = os.path.splitext(file_path)
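
The new lookup in identify_legal_document strips the file extension before consulting the catalog, so catalog keys can be stored without extensions; a quick illustration with a hypothetical filename:

import os

filename = "kodeks_pracy.pdf"              # hypothetical example, not from the catalog
base_name = os.path.splitext(filename)[0]
print(base_name)                           # prints "kodeks_pracy"
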
@@ -56,8 +53,6 @@ def extract_text_from_file(file_path):
return text
elif ext in ['.doc', '.docx']:
return docx2txt.process(file_path)
elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
return pytesseract.image_to_string(Image.open(file_path))
else:
return ""
@@ -73,10 +68,11 @@ def prepare_dataset(directory, catalog_path, source_mapper):
continue
doc_type = identify_legal_document(file, file_catalog)
if doc_type != "Opracowanie własne":
articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
articles = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text)
for i in range(1, len(articles), 2):
article_number = articles[i].strip()
article_number = re.sub(r'#+\s*', '', articles[i].strip())
article_content = articles[i+1].strip() if i+1 < len(articles) else ""
source = f"{doc_type}, {article_number}"
source_mapper.add_source(source)
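
Because the split pattern is wrapped in a capturing group, re.split keeps every matched "Art. N" header in the result list, which is why the loop walks the odd indices and the body sits at i+1; a small demo on hypothetical text:

import re

sample = "## Art. 1. Pierwszy przepis.\n## Art. 2. Drugi przepis."   # hypothetical input
parts = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', sample)
# parts alternates: [preamble, header, body, header, body, ...]
for i in range(1, len(parts), 2):
    number = re.sub(r'#+\s*', '', parts[i].strip())
    body = parts[i + 1].strip() if i + 1 < len(parts) else ""
    print(number, "->", body)
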
@@ -130,7 +126,6 @@ def main():
data = prepare_dataset("files", catalog_path, source_mapper)
dataset = Dataset.from_list(data)
# Tokenization
def tokenize_function(examples):
tokenized = tokenizer(
examples["text"],
@@ -154,44 +149,53 @@ def main():
# Training
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
num_train_epochs=5,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=2e-5,
learning_rate=3e-5,
fp16=torch.cuda.is_available(),
logging_steps=1,
logging_steps=10,
save_strategy="steps",
save_steps=1000,
report_to="none"
report_to="none",
weight_decay=0.01
)
trainer = CustomTrainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=lambda x: x
)
print("Rozpoczęcie treningu...")
trainer.train()
# Testing
def generate_answer(question):
inputs = tokenizer(question, return_tensors="pt").to(device)
inputs = tokenizer(
f"[PYTANIE PRAWNE] {question}",
return_tensors="pt",
truncation=True,
max_length=512
).to(device)
outputs = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.2,
no_repeat_ngram_size=2,
pad_token_id=tokenizer.eos_token_id
)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.5,
no_repeat_ngram_size=3,
pad_token_id=tokenizer.eos_token_id
)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(question, "").strip()
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
answer = answer.split("[PYTANIE PRAWNE]")[-1].strip()
sources = set()
for match in re.finditer(r'Art\.\s+\d+', answer):
for match in re.finditer(r'Art\.\s*\d+', answer):
article_ref = match.group(0).strip()
for idx, source in source_mapper.idx_to_source.items():
if article_ref in source:
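
A self-contained sketch of the new generation settings, using gpt2 purely as a stand-in for whatever checkpoint main() actually loads; it also trims the prompt by token count, an alternative to splitting the decoded text on the [PYTANIE PRAWNE] tag:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("gpt2")          # stand-in checkpoint
lm = AutoModelForCausalLM.from_pretrained("gpt2")
prompt = "[PYTANIE PRAWNE] Co mówi art. 154 kodeksu pracy?"
enc = tok(prompt, return_tensors="pt")
with torch.no_grad():
    out = lm.generate(
        **enc,
        max_new_tokens=40,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.5,
        no_repeat_ngram_size=3,
        pad_token_id=tok.eos_token_id,
    )
# Slicing off the prompt tokens removes both the tag and the question text.
answer = tok.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True).strip()
print(answer)
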
@@ -203,7 +207,7 @@ def main():
"sources": list(sources) if sources else ["Opracowanie własne"]
}
# Example tests
# Tests
test_questions = [
"Jakie są zasady udzielania urlopu wypoczynkowego?",
"Co mówi art. 154 kodeksu pracy?",