mod

parent 7c7391b608
commit fcb4d25d8f

hft.py (48 changed lines)
@@ -3,19 +3,15 @@ import torch
 import torch.nn as nn
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
-from PIL import Image
 import re
-import pytesseract
-import docx2txt
-import PyPDF2
 import json
 from collections import defaultdict
 from huggingface_hub import login

-# Environment configuration
+# Configuration
 os.environ['TORCH_USE_CUDA_DSA'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="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")
+login(token="TWÓJ_TOKEN_HF")  # Replace with your own token

 class SourceMapper:
     def __init__(self):
@@ -38,7 +34,8 @@ def load_file_catalog(catalog_path):
         return json.load(file)

 def identify_legal_document(filename, file_catalog):
-    return file_catalog.get(filename, "Opracowanie własne")
+    base_name = os.path.splitext(filename)[0]
+    return file_catalog.get(base_name, "Opracowanie własne")

 def extract_text_from_file(file_path):
     _, ext = os.path.splitext(file_path)
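
(For reference: a standalone sketch of the new lookup behaviour, assuming the file catalog is keyed by bare file names without extensions -- an assumption, since the catalog contents are not shown in this diff; the entry and file names below are made up.)

import os

def identify_legal_document(filename, file_catalog):
    # Dropping the extension lets "kodeks_pracy.pdf" match a catalog keyed by "kodeks_pracy"
    base_name = os.path.splitext(filename)[0]
    return file_catalog.get(base_name, "Opracowanie własne")

catalog = {"kodeks_pracy": "Kodeks pracy"}                    # hypothetical catalog entry
print(identify_legal_document("kodeks_pracy.pdf", catalog))   # -> Kodeks pracy
print(identify_legal_document("notatki.txt", catalog))        # -> Opracowanie własne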
@@ -56,8 +53,6 @@ def extract_text_from_file(file_path):
         return text
     elif ext in ['.doc', '.docx']:
         return docx2txt.process(file_path)
-    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-        return pytesseract.image_to_string(Image.open(file_path))
     else:
         return ""

@@ -73,10 +68,11 @@ def prepare_dataset(directory, catalog_path, source_mapper):
                continue

            doc_type = identify_legal_document(file, file_catalog)
+
            if doc_type != "Opracowanie własne":
-                articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
+                articles = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text)
                for i in range(1, len(articles), 2):
-                    article_number = articles[i].strip()
+                    article_number = re.sub(r'#+\s*', '', articles[i].strip())
                    article_content = articles[i+1].strip() if i+1 < len(articles) else ""
                    source = f"{doc_type}, {article_number}"
                    source_mapper.add_source(source)
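
(For context: a minimal standalone sketch of what the updated splitting regex does with markdown-style article headings; the sample text is invented and only illustrates the capture group plus the re.sub clean-up used above.)

import re

text = "## Art. 1. Przepisy ogólne...\n## Art. 2. Dalsze przepisy..."   # hypothetical markdown-converted text

# The capture group keeps each "Art. N" heading as its own list element
articles = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text)

for i in range(1, len(articles), 2):
    # Strip the leading markdown hashes from the captured heading
    article_number = re.sub(r'#+\s*', '', articles[i].strip())
    article_content = articles[i + 1].strip() if i + 1 < len(articles) else ""
    print(article_number, "->", article_content)

# Prints:
# Art. 1. -> Przepisy ogólne...
# Art. 2. -> Dalsze przepisy...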
@@ -130,7 +126,6 @@ def main():
    data = prepare_dataset("files", catalog_path, source_mapper)
    dataset = Dataset.from_list(data)

-    # Tokenization
    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
@@ -154,44 +149,53 @@ def main():
    # Training
    training_args = TrainingArguments(
        output_dir="./results",
-        num_train_epochs=3,
+        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
-        learning_rate=2e-5,
+        learning_rate=3e-5,
        fp16=torch.cuda.is_available(),
-        logging_steps=1,
+        logging_steps=10,
        save_strategy="steps",
        save_steps=1000,
-        report_to="none"
+        report_to="none",
+        weight_decay=0.01
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
+        data_collator=lambda x: x
    )
    print("Rozpoczęcie treningu...")
    trainer.train()

    # Testing
    def generate_answer(question):
-        inputs = tokenizer(question, return_tensors="pt").to(device)
+        inputs = tokenizer(
+            f"[PYTANIE PRAWNE] {question}",
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        ).to(device)

+        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
-                repetition_penalty=1.2,
-                no_repeat_ngram_size=2,
+                repetition_penalty=1.5,
+                no_repeat_ngram_size=3,
                pad_token_id=tokenizer.eos_token_id
            )

-        answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(question, "").strip()
+        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        answer = answer.split("[PYTANIE PRAWNE]")[-1].strip()

        sources = set()
-        for match in re.finditer(r'Art\.\s+\d+', answer):
+        for match in re.finditer(r'Art\.\s*\d+', answer):
            article_ref = match.group(0).strip()
            for idx, source in source_mapper.idx_to_source.items():
                if article_ref in source:
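
(For context: a string-level sketch of the prompt round trip that the rewritten generate_answer relies on, with no model loaded; the decoded text below is a stand-in for real model output. Unlike the old .replace(question, "") call, splitting on the tag keeps the echoed question in the returned answer.)

TAG = "[PYTANIE PRAWNE]"

def build_prompt(question: str) -> str:
    # Mirrors the f"[PYTANIE PRAWNE] {question}" prefix passed to the tokenizer
    return f"{TAG} {question}"

def strip_prompt(decoded: str) -> str:
    # Mirrors answer.split("[PYTANIE PRAWNE]")[-1].strip() applied after decoding
    return decoded.split(TAG)[-1].strip()

prompt = build_prompt("Co mówi art. 154 kodeksu pracy?")
decoded = prompt + " Art. 154 Kodeksu pracy określa wymiar urlopu..."   # fake model output
print(strip_prompt(decoded))
# -> Co mówi art. 154 kodeksu pracy? Art. 154 Kodeksu pracy określa wymiar urlopu...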
@@ -203,7 +207,7 @@ def main():
            "sources": list(sources) if sources else ["Opracowanie własne"]
        }

-    # Example tests
+    # Tests
    test_questions = [
        "Jakie są zasady udzielania urlopu wypoczynkowego?",
        "Co mówi art. 154 kodeksu pracy?",