mod

commit fcb4d25d8f
parent 7c7391b608

hft.py | 64
@@ -3,19 +3,15 @@ import torch
 import torch.nn as nn
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
-from PIL import Image
 import re
-import pytesseract
 import docx2txt
 import PyPDF2
 import json
 from collections import defaultdict
 from huggingface_hub import login
 
-# Environment configuration
-os.environ['TORCH_USE_CUDA_DSA'] = '1'
+# Configuration
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="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")
+login(token="TWÓJ_TOKEN_HF")  # Replace with your own token
 
 class SourceMapper:
     def __init__(self):
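The new call swaps the hardcoded hf_ credential for a placeholder. A common alternative, noted here as an assumption rather than something this commit does, is to read the token from an environment variable so no secret lives in the file at all:

    import os
    from huggingface_hub import login

    # hypothetical HF_TOKEN variable; keeps the credential out of version control
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)
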
@@ -38,7 +34,8 @@ def load_file_catalog(catalog_path):
         return json.load(file)
 
 def identify_legal_document(filename, file_catalog):
-    return file_catalog.get(filename, "Opracowanie własne")
+    base_name = os.path.splitext(filename)[0]
+    return file_catalog.get(base_name, "Opracowanie własne")
 
 def extract_text_from_file(file_path):
     _, ext = os.path.splitext(file_path)
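The rewritten lookup strips the file extension before consulting the catalog, so "kodeks_pracy.pdf" and "kodeks_pracy.docx" both resolve to the same entry. A minimal illustration (the catalog entry is made up):

    import os

    file_catalog = {"kodeks_pracy": "Kodeks pracy"}  # hypothetical catalog entry

    filename = "kodeks_pracy.pdf"
    base_name = os.path.splitext(filename)[0]        # -> "kodeks_pracy"

    print(file_catalog.get(filename, "Opracowanie własne"))   # old lookup: misses, returns the default
    print(file_catalog.get(base_name, "Opracowanie własne"))  # new lookup: hits, returns "Kodeks pracy"
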
@@ -56,8 +53,6 @@ def extract_text_from_file(file_path):
             return text
     elif ext in ['.doc', '.docx']:
         return docx2txt.process(file_path)
-    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-        return pytesseract.image_to_string(Image.open(file_path))
     else:
         return ""
 
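With the OCR branch gone, the function dispatches only on PDF and Word extensions. The PDF branch sits outside this hunk; a sketch of how the whole function presumably reads after the change, assuming the PyPDF2 3.x PdfReader API:

    import os
    import PyPDF2
    import docx2txt

    def extract_text_from_file(file_path):
        _, ext = os.path.splitext(file_path)
        if ext == '.pdf':
            # assumption: PyPDF2 3.x API; older releases used PdfFileReader
            with open(file_path, 'rb') as f:
                reader = PyPDF2.PdfReader(f)
                return "".join(page.extract_text() or "" for page in reader.pages)
        elif ext in ['.doc', '.docx']:
            return docx2txt.process(file_path)
        else:
            return ""
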
@@ -73,10 +68,11 @@ def prepare_dataset(directory, catalog_path, source_mapper):
                 continue
 
             doc_type = identify_legal_document(file, file_catalog)
 
             if doc_type != "Opracowanie własne":
-                articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
+                articles = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text)
                 for i in range(1, len(articles), 2):
-                    article_number = articles[i].strip()
+                    article_number = re.sub(r'#+\s*', '', articles[i].strip())
                     article_content = articles[i+1].strip() if i+1 < len(articles) else ""
                     source = f"{doc_type}, {article_number}"
                     source_mapper.add_source(source)
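The new pattern anchors article headers to the markdown hashes the extracted text apparently carries, and because the group is capturing, re.split keeps each matched header in the result list; the loop then walks header/body pairs, and re.sub strips the hashes back off the stored article number. A small worked example with made-up input:

    import re

    text = "Preambuła ## Art. 1. Pierwsza treść ## Art. 2. Druga treść"  # hypothetical input
    parts = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text)
    # the capturing group keeps the delimiters:
    # ['Preambuła ', '## Art. 1.', ' Pierwsza treść ', '## Art. 2.', ' Druga treść']

    for i in range(1, len(parts), 2):
        number = re.sub(r'#+\s*', '', parts[i].strip())           # '## Art. 1.' -> 'Art. 1.'
        content = parts[i + 1].strip() if i + 1 < len(parts) else ""
        print(number, "->", content)
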
@@ -130,7 +126,6 @@ def main():
     data = prepare_dataset("files", catalog_path, source_mapper)
     dataset = Dataset.from_list(data)
 
-    # Tokenization
     def tokenize_function(examples):
         tokenized = tokenizer(
             examples["text"],
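The hunk cuts off inside tokenize_function. A typical completion for causal-LM fine-tuning is sketched below; the truncation/padding values and the labels line are assumptions (the real ones sit outside this hunk), and tokenizer and dataset come from the surrounding main():

    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,        # assumed settings; the actual ones
            max_length=512,         # are not shown in this hunk
            padding="max_length",
        )
        # causal LM: the model predicts the input shifted by one, so labels = input_ids
        tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
        return tokenized

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
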
@@ -154,44 +149,53 @@ def main():
     # Training
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=3,
+        num_train_epochs=5,
         per_device_train_batch_size=2,
         gradient_accumulation_steps=4,
-        learning_rate=2e-5,
+        learning_rate=3e-5,
         fp16=torch.cuda.is_available(),
-        logging_steps=1,
+        logging_steps=10,
         save_strategy="steps",
         save_steps=1000,
-        report_to="none"
+        report_to="none",
+        weight_decay=0.01
     )
 
     trainer = CustomTrainer(
         model=model,
         args=training_args,
         train_dataset=tokenized_dataset,
         data_collator=lambda x: x
     )
     print("Rozpoczęcie treningu...")
     trainer.train()
 
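The changes above raise epochs and learning rate and add weight decay; the batch settings are unchanged but worth spelling out, since gradients from 4 micro-batches of 2 examples accumulate before each optimizer step, for an effective batch of 8. The arithmetic in isolation, with values copied from the diff:

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=3e-5,
        weight_decay=0.01,   # decoupled weight decay in the default AdamW optimizer
        report_to="none",
    )
    # four micro-batches of two are summed before each optimizer step:
    print(args.per_device_train_batch_size * args.gradient_accumulation_steps)  # 8
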
     # Testing
     def generate_answer(question):
-        inputs = tokenizer(question, return_tensors="pt").to(device)
+        inputs = tokenizer(
+            f"[PYTANIE PRAWNE] {question}",
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        ).to(device)
 
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=200,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            repetition_penalty=1.2,
-            no_repeat_ngram_size=2,
-            pad_token_id=tokenizer.eos_token_id
-        )
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=200,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.5,
+                no_repeat_ngram_size=3,
+                pad_token_id=tokenizer.eos_token_id
+            )
 
-        answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(question, "").strip()
+        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        answer = answer.split("[PYTANIE PRAWNE]")[-1].strip()
 
         sources = set()
-        for match in re.finditer(r'Art\.\s+\d+', answer):
+        for match in re.finditer(r'Art\.\s*\d+', answer):
             article_ref = match.group(0).strip()
             for idx, source in source_mapper.idx_to_source.items():
                 if article_ref in source:
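Three things change on the inference path: generate now runs under torch.no_grad(), which skips building the autograd graph and saves memory; the [PYTANIE PRAWNE] tag gives a stable marker for trimming the decoded output, where str.replace(question, "") could miss because decoding does not always reproduce the question byte-for-byte; and the source-matching regex now uses \s* so references written without a space, like "Art.154", also match. A runnable sketch of the first two, using a small stand-in model (distilgpt2 here, purely for illustration, not the trained model):

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")   # stand-in model
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")

    question = "Co mówi art. 154 kodeksu pracy?"
    inputs = tokenizer(f"[PYTANIE PRAWNE] {question}",
                       return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():                                     # no autograd graph at inference
        outputs = model.generate(**inputs,
                                 max_new_tokens=50,
                                 do_sample=True, temperature=0.7, top_p=0.9,
                                 repetition_penalty=1.5, no_repeat_ngram_size=3,
                                 pad_token_id=tokenizer.eos_token_id)

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    answer = decoded.split("[PYTANIE PRAWNE]")[-1].strip()    # everything after the tag
    print(answer)
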
@@ -203,7 +207,7 @@ def main():
             "sources": list(sources) if sources else ["Opracowanie własne"]
         }
 
-    # Example tests
+    # Tests
     test_questions = [
         "Jakie są zasady udzielania urlopu wypoczynkowego?",
         "Co mówi art. 154 kodeksu pracy?",