# ably.do/hft.py
#
# 239 lines
# 8.4 KiB
# Python

import json
import os
import re
from collections import defaultdict

import docx2txt
import PyPDF2
import pytesseract
import torch
import torch.nn as nn
from datasets import Dataset
from huggingface_hub import login
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
# NOTE(review): TORCH_USE_CUDA_DSA is set after `import torch`; it may have
# no effect at this point — confirm against the PyTorch build in use.
os.environ['TORCH_USE_CUDA_DSA'] = '1'
# Disable tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Access tokens must never be hard-coded in source: the literal token that used
# to be here is exposed wherever this file is shared and should be revoked.
# Read it from the environment instead; when it is unset, skip login
# (presumably the public base model used below needs no authentication).
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
class SourceMapper:
    """Bidirectional mapping between source labels and dense integer ids.

    Ids are assigned in registration order starting at 0. A falsy source
    (``None`` or ``""``) is never registered and maps to -1.
    """

    def __init__(self):
        # label -> id. A plain dict: the previous defaultdict silently
        # inserted unknown labels on lookup without updating idx_to_source.
        self.source_to_idx = {}
        # id -> label (inverse of source_to_idx)
        self.idx_to_source = {}

    def add_source(self, source):
        """Register *source* with the next free id; no-op if falsy or already known."""
        if source and source not in self.source_to_idx:
            idx = len(self.source_to_idx)
            self.source_to_idx[source] = idx
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        """Return the id of *source*, registering it first if needed; -1 for a falsy source."""
        if not source:
            return -1
        # Registering through add_source keeps both maps in sync, so
        # get_source() can always resolve an id handed out here.
        self.add_source(source)
        return self.source_to_idx[source]

    def get_source(self, idx):
        """Return the label registered under *idx*, or "Unknown"."""
        return self.idx_to_source.get(idx, "Unknown")
def load_file_catalog(catalog_path):
    """Parse the UTF-8 JSON catalog at *catalog_path* and return its content."""
    with open(catalog_path, encoding='utf-8') as handle:
        catalog = json.load(handle)
    return catalog
def identify_legal_document(filename, file_catalog):
    """Look up *filename* in *file_catalog*; unknown names get the "own work" label."""
    fallback_label = "Opracowanie własne"
    return file_catalog.get(filename, fallback_label)
def extract_text_from_file(file_path):
    """Extract plain text from *file_path*, choosing a reader by extension.

    The extension is matched case-insensitively:
      * .txt / .md  -> read as UTF-8 text
      * .pdf        -> PyPDF2 text layer of every page, concatenated
      * .doc / .docx -> docx2txt
      * .jpg / .jpeg / .png / .bmp / .tiff -> Tesseract OCR
    Any other extension returns "" without opening the file.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext in ('.txt', '.md'):
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    if ext == '.pdf':
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            # extract_text() can yield None for pages without a text layer
            # (e.g. scanned pages) in some PyPDF2 versions; treat those as
            # empty instead of crashing. join() also avoids quadratic +=.
            return "".join(page.extract_text() or "" for page in reader.pages)
    if ext in ('.doc', '.docx'):
        # NOTE(review): docx2txt targets the zip-based .docx format; legacy
        # binary .doc files will likely fail here — confirm.
        return docx2txt.process(file_path)
    if ext in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff'):
        # Close the image file handle once OCR is done.
        with Image.open(file_path) as image:
            return pytesseract.image_to_string(image)
    return ""
# "Art. <number>" followed by a dot or whitespace. The capturing group keeps
# the headings in re.split's output so they can be paired with their bodies.
_ARTICLE_HEADER_RE = re.compile(r'(Art\.\s+\d+[\.\s])')

# Label returned by identify_legal_document for files that are not a
# catalogued legal act ("own work").
_OWN_WORK = "Opracowanie własne"


def _split_articles(text):
    """Split *text* on "Art. N" headings into (heading, body) pairs; text before the first heading is dropped."""
    parts = _ARTICLE_HEADER_RE.split(text)
    # With one capturing group the result is [preamble, head1, body1, head2, body2, ...].
    return [(parts[i].strip(), parts[i + 1].strip()) for i in range(1, len(parts) - 1, 2)]


def _chunk_text(text, chunk_size):
    """Cut *text* into consecutive slices of at most *chunk_size* characters."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


def prepare_dataset(directory, catalog_path, source_mapper, chunk_size=512):
    """Build training rows from every readable file under *directory*.

    Files whose catalog entry (JSON at *catalog_path*, keyed by bare file
    name) is a document type other than "Opracowanie własne" are split into
    one row per "Art. N" heading. Each such row is tagged with the id of
    "<doc_type>, <heading>" registered in *source_mapper*. If such a file has
    no "Art. N" heading at all, its text is chunked and attributed to
    "<doc_type>" as a whole. All other files are cut into *chunk_size*-character
    chunks with source id -1 (no source).

    Args:
        directory: Root directory walked recursively.
        catalog_path: Path of the JSON file catalog.
        source_mapper: SourceMapper that receives every source label used.
        chunk_size: Characters per chunk for unstructured text (positive).

    Returns:
        List of {"text": str, "source_idx": int} dicts.
    """
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, _, files in os.walk(directory):
        for file in files:
            text = extract_text_from_file(os.path.join(root, file))
            if not text:
                continue
            doc_type = identify_legal_document(file, file_catalog)
            if doc_type == _OWN_WORK:
                source_idx = -1  # no source
            else:
                articles = _split_articles(text)
                if articles:
                    for number, content in articles:
                        source = f"{doc_type}, {number}"
                        source_mapper.add_source(source)
                        data.append({
                            "text": f"{number} {content}",
                            "source_idx": source_mapper.get_idx(source),
                        })
                    continue
                # A catalogued act without any "Art. N" heading previously
                # produced no rows at all; keep its text, attributed to the
                # act as a whole.
                source = f"{doc_type}"
                source_mapper.add_source(source)
                source_idx = source_mapper.get_idx(source)
            data.extend(
                {"text": chunk, "source_idx": source_idx}
                for chunk in _chunk_text(text, chunk_size)
            )
    return data
def tokenize_function(examples):
    """Tokenize a batch of dataset rows for causal-LM training.

    Uses the module-level ``tokenizer``. Every text is padded/truncated to 512
    tokens, ``labels`` are built from ``input_ids`` with padded positions
    masked out, and ``source_idx`` is carried through unchanged.

    Args:
        examples: Batched rows with "text" and "source_idx" columns.

    Returns:
        The tokenizer's encoding with added "labels" and "source_idx" entries.
    """
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )
    # Padding must not contribute to the loss. In this script pad_token is the
    # EOS token, so unmasked labels would train the model to emit EOS in every
    # padded slot. -100 is the ignore_index of the cross-entropy loss used by
    # Hugging Face causal-LM heads.
    labels = tokenized["input_ids"].clone()
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized
def custom_collate_fn(batch):
    """Stack per-example token sequences into batch tensors.

    "input_ids", "attention_mask" and "labels" are stacked along a new batch
    dimension. "source_idx" becomes a 1-D long tensor, with -1 for rows that
    lack the key.
    """
    def stacked(field):
        return torch.stack([torch.tensor(row[field]) for row in batch])

    collated = {field: stacked(field) for field in ("input_ids", "attention_mask", "labels")}
    collated["source_idx"] = torch.tensor(
        [row.get("source_idx", -1) for row in batch],
        dtype=torch.long,
    )
    return collated
class CustomModel(nn.Module):
    """Causal LM wrapper that adds a learned per-source embedding to every token.

    When ``source_idx`` is given, each example's source vector is repeated
    over the sequence and added to the token embeddings before they enter the
    base model. The table has 1000 rows; its last row is the zero "padding" row
    used for "no source" (-1) and for ids beyond the table.
    """

    def __init__(self, model_name, config):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        # padding_idx=-1 resolves to the last row (num_embeddings - 1). PyTorch
        # initialises that row to zeros and never updates it.
        self.source_embedding = nn.Embedding(
            num_embeddings=1000,
            embedding_dim=config.hidden_size,
            padding_idx=-1,
        )

    def _resolve_source_idx(self, source_idx):
        """Map raw source ids onto valid embedding rows.

        Negative ids (-1 = no source) go to the padding row. Previously they
        were clamped to 0, which made unsourced text share and train source 0's
        embedding. Ids past the end are clamped to the last row, as before.
        """
        last_row = self.source_embedding.num_embeddings - 1
        pad_row = self.source_embedding.padding_idx
        source_idx = torch.where(source_idx < 0, torch.full_like(source_idx, pad_row), source_idx)
        return torch.clamp(source_idx, 0, last_row)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        """Run the base model, optionally conditioning on ``source_idx``.

        With ``source_idx`` the base model is fed ``inputs_embeds`` (token
        embeddings + source embedding); without it, ``input_ids`` go through
        unchanged. Extra kwargs are forwarded to the base model.
        """
        if source_idx is None:
            return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
        source_idx = self._resolve_source_idx(source_idx)
        # One vector per example, repeated across all input_ids.size(1) positions.
        source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
        inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
        return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to the base model's ``generate`` (source embeddings are not applied)."""
        return self.base_model.generate(*args, **kwargs)
class CustomTrainer(Trainer):
    """Trainer whose loss computation forwards ``source_idx`` to the model."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run *model* on one batch and return its loss (plus outputs if requested).

        "labels" (required) and "source_idx" (optional) are popped from
        *inputs*, so the dict is mutated. Only input_ids, attention_mask,
        labels and source_idx reach the model.

        NOTE(review): any extra Trainer kwargs (e.g. num_items_in_batch in
        recent transformers versions) are dropped here; confirm the loss
        scaling under gradient accumulation is what you expect.
        """
        target_ids = inputs.pop("labels")
        src_ids = inputs.pop("source_idx", None)
        result = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=target_ids,
            source_idx=src_ids,
        )
        if not return_outputs:
            return result.loss
        return (result.loss, result)
# Component initialisation
source_mapper = SourceMapper()
model_name = "crumb/nano-mistral"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with EOS so fixed-length padding works (presumably the base tokenizer
# defines no pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

# Data preparation
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path, source_mapper)
if not data:
    # Fail fast: an empty dataset would only fail later, inside the Trainer.
    raise RuntimeError(
        f"No training examples were extracted from 'files' using catalog {catalog_path!r}."
    )
dataset = Dataset.from_list(data)
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)

# Model initialisation. Only the configuration is needed here, so read it
# directly instead of loading a full extra copy of the weights to get `.config`.
config = AutoConfig.from_pretrained(model_name)
model = CustomModel(model_name, config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    fp16=torch.cuda.is_available(),
    # logging_steps only applies with the "steps" strategy; the previous
    # logging_strategy="no" disabled logging and made logging_steps dead.
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="steps",
    save_steps=1000,
    report_to="none",
)

# Training
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=custom_collate_fn,
)
trainer.train()
# Test helpers
def _find_cited_sources(answer, source_mapper):
    """Return, sorted, the source labels registered in *source_mapper* that occur in *answer*.

    A match must not be immediately followed by a digit, so "..., Art. 1" is
    not reported for an answer that cites "..., Art. 10".
    """
    cited = {
        label
        for label in source_mapper.idx_to_source.values()
        if re.search(re.escape(label) + r"(?!\d)", answer)
    }
    return sorted(cited)


def generate_answer_with_source(question, model, tokenizer, source_mapper, max_length=200):
    """Generate an answer to *question* and list the known sources it mentions.

    Args:
        question: Prompt text; tokenized with truncation to 512 tokens.
        model: Model exposing ``generate()`` and ``parameters()`` (e.g. CustomModel).
        tokenizer: Tokenizer matching *model*.
        source_mapper: SourceMapper whose registered labels are searched for in the answer.
        max_length: Forwarded to ``generate(max_length=...)``. NOTE(review): this
            bounds prompt + generated tokens, not only new tokens.

    Returns:
        Dict with "question", "answer" (decoded first sequence) and "sources"
        (sorted matched labels, or ["Opracowanie własne"] when none matched).
    """
    device = next(model.parameters()).device
    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512).to(device)
    # Generate without dropout, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                # temperature/top_p only take effect when sampling is enabled;
                # without do_sample=True generate() uses greedy decoding.
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
    finally:
        model.train(was_training)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sources = _find_cited_sources(answer, source_mapper)
    return {
        "question": question,
        "answer": answer,
        "sources": sources if sources else ["Opracowanie własne"],
    }
# Save the model first, so that a failure in the demo below cannot lose the
# result of the training run.
save_directory = "./trained_model"
os.makedirs(save_directory, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_directory, "model.bin"))
tokenizer.save_pretrained(save_directory)
# The source embeddings are indexed by SourceMapper ids; persist the
# id -> source mapping so those rows stay interpretable after reloading.
with open(os.path.join(save_directory, "source_map.json"), "w", encoding="utf-8") as map_file:
    json.dump(source_mapper.idx_to_source, map_file, ensure_ascii=False, indent=2)

# Smoke test on a few sample questions
test_questions = [
    "Jaki jest wymiar urlopu wypoczynkowego?",
    "Jakie są zasady bezpieczeństwa na budowie?",
    "Wyjaśnij procedurę zwolnienia grupowego",
]
print("\n=== TEST MODELU ===")
for question in test_questions:
    result = generate_answer_with_source(question, model, tokenizer, source_mapper)
    print(f"\nPytanie: {result['question']}")
    print(f"Odpowiedź: {result['answer']}")
    print(f"Źródła: {', '.join(result['sources'])}")
    print("="*80)