import os
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    default_data_collator,
)
from datasets import Dataset
import re
import json
from collections import defaultdict
from huggingface_hub import login
import PyPDF2  # added for PDF extraction
import docx2txt  # added for .doc/.docx extraction
import pytesseract  # added for OCR on scanned images
from PIL import Image  # added for OCR on scanned images
# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'  # only effective if PyTorch was built with device-side assertions
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token=os.environ["HF_TOKEN"])  # read the token from the environment; never hard-code it
class SourceMapper:
    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
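
# Illustrative usage (not part of the training flow): the mapper hands out
# consecutive indices on first sight of a source and resolves them back.
#   mapper = SourceMapper()
#   mapper.add_source("Kodeks pracy, Art. 154")
#   mapper.get_idx("Kodeks pracy, Art. 154")  # -> 0 (first source seen)
#   mapper.get_source(0)                      # -> "Kodeks pracy, Art. 154"
#   mapper.get_idx(None)                      # -> -1 (marker for "no source")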
def load_file_catalog(catalog_path):
    with open(catalog_path, 'r', encoding='utf-8') as file:
        return json.load(file)
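
# Assumed shape of file_catalog.json (the real file is not shown here):
# keys are file names without extension, values are the document titles used
# as source labels, e.g. {"kodeks_pracy": "Kodeks pracy"}. Files absent from
# the catalog fall back to "Opracowanie własne" ("own material") below.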
def identify_legal_document(filename, file_catalog):
    base_name = os.path.splitext(filename)[0]
    return file_catalog.get(base_name, "Opracowanie własne")
def extract_text_from_file(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext in ['.txt', '.md']:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    elif ext == '.pdf':
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            for page in reader.pages:
                text += page.extract_text() or ""  # extract_text() may return None
        return text
    elif ext in ['.doc', '.docx']:
        return docx2txt.process(file_path)
    elif ext in ['.png', '.jpg', '.jpeg', '.tiff']:
        # OCR for scanned documents; this is what pytesseract/PIL are imported for
        return pytesseract.image_to_string(Image.open(file_path))
    else:
        return ""
def prepare_dataset(directory, catalog_path, source_mapper):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            text = extract_text_from_file(file_path)
            if not text:
                continue
            doc_type = identify_legal_document(file, file_catalog)
            if doc_type != "Opracowanie własne":
                # Split on article headings ("Art. N"); the capture group keeps the
                # headings, so items alternate heading/content after index 0.
                articles = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text)
                for i in range(1, len(articles), 2):
                    article_number = re.sub(r'#+\s*', '', articles[i].strip())
                    article_content = articles[i+1].strip() if i+1 < len(articles) else ""
                    source = f"{doc_type}, {article_number}"
                    source_mapper.add_source(source)
                    data.append({
                        "text": f"{article_number} {article_content}",
                        "source_idx": source_mapper.get_idx(source)
                    })
            else:
                # Untracked material: plain 512-character chunks with no source
                chunks = [text[i:i+512] for i in range(0, len(text), 512)]
                for chunk in chunks:
                    data.append({
                        "text": chunk,
                        "source_idx": -1
                    })
    return data
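
# Example records emitted above (illustrative values):
#   {"text": "Art. 154 <article body>", "source_idx": 0}   # catalogued act
#   {"text": "<512-character chunk>", "source_idx": -1}    # own material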
class CustomModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        # Up to 1000 distinct sources; padding_idx=-1 resolves to the last slot,
        # which stays a zero vector and receives no gradient.
        self.source_embedding = nn.Embedding(1000, config.hidden_size, padding_idx=-1)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        if source_idx is not None:
            # Route the "no source" marker (-1) to the zero padding slot instead of
            # clamping it onto index 0, which belongs to a real source.
            pad_idx = self.source_embedding.num_embeddings - 1
            source_idx = torch.where(source_idx < 0, torch.full_like(source_idx, pad_idx), source_idx)
            source_idx = torch.clamp(source_idx, 0, pad_idx)
            # (batch,) -> (batch, 1, hidden), broadcast over the sequence length
            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
        return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, *args, **kwargs):
        return self.base_model.generate(*args, **kwargs)
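
# Design note: conditioning is done by adding one learned vector per source to
# every token embedding, so the transformer blocks stay untouched and
# generate() above can simply run the unmodified base model.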
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Pop the extra source_idx column so only standard arguments reach the
        # base Trainer machinery; CustomModel.forward consumes it explicitly.
        labels = inputs.pop("labels")
        source_idx = inputs.pop("source_idx", None)
        outputs = model(**inputs, labels=labels, source_idx=source_idx)
        return (outputs.loss, outputs) if return_outputs else outputs.loss
def main():
    # Initialisation
    source_mapper = SourceMapper()
    model_name = "crumb/nano-mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    catalog_path = "file_catalog.json"
    data = prepare_dataset("files", catalog_path, source_mapper)
    dataset = Dataset.from_list(data)
    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )
        # Causal LM objective: the model predicts the input itself
        tokenized["labels"] = tokenized["input_ids"].clone()
        tokenized["source_idx"] = examples["source_idx"]
        return tokenized

    tokenized_dataset = dataset.map(
        tokenize_function, batched=True, batch_size=8,
        remove_columns=["text"]  # drop the raw string column before collation
    )
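
    # After mapping, each example holds (with max_length=512):
    #   input_ids: [512], attention_mask: [512], labels: [512], source_idx: int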
    # Model
    config = AutoConfig.from_pretrained(model_name)  # avoids loading the weights twice just for the config
    model = CustomModel(model_name, config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch size of 8 per device
        learning_rate=3e-5,
        fp16=torch.cuda.is_available(),
        logging_steps=10,
        save_strategy="steps",
        save_steps=1000,
        report_to="none",
        weight_decay=0.01
    )
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        # The original identity lambda returned a list of dicts, which Trainer
        # cannot feed to the model; default_data_collator stacks the features
        # into tensors and keeps the extra source_idx column.
        data_collator=default_data_collator
    )
print("Rozpoczęcie treningu...")
trainer.train()
    # Testing
    model.eval()  # switch out of training mode before sampling

    def generate_answer(question):
        inputs = tokenizer(
            f"[PYTANIE PRAWNE] {question}",
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.5,
                no_repeat_ngram_size=3,
                pad_token_id=tokenizer.eos_token_id
            )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = answer.split("[PYTANIE PRAWNE]")[-1].strip()
        # Attribute the answer to catalogued sources by the article numbers it cites.
        sources = set()
        for match in re.finditer(r'Art\.\s*\d+', answer):
            article_ref = match.group(0).strip()
            for idx, source in source_mapper.idx_to_source.items():
                # Word boundary prevents "Art. 15" from matching "Art. 154"
                if re.search(re.escape(article_ref) + r'\b', source):
                    sources.add(source)
        return {
            "question": question,
            "answer": answer,
            "sources": list(sources) if sources else ["Opracowanie własne"]
        }
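
    # Returned shape (illustrative):
    #   {"question": "...", "answer": "...", "sources": ["Kodeks pracy, Art. 154"]}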
    # Tests
    test_questions = [
        "Jakie są zasady udzielania urlopu wypoczynkowego?",
        "Co mówi art. 154 kodeksu pracy?",
        "Jakie są obowiązki pracodawcy w zakresie BHP?"
    ]
    print("\n" + "="*50 + "\nWYNIKI TESTOW\n" + "="*50)
    for question in test_questions:
        result = generate_answer(question)
        print(f"\nPYTANIE: {result['question']}")
        print(f"ODPOWIEDŹ: {result['answer'][:500]}")
        print(f"ŹRÓDŁA: {', '.join(result['sources'])}")
        print("-"*80)
if __name__ == "__main__":
main()