ably.do/hft.py

import os
import re
import json
from collections import defaultdict

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import Dataset
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from huggingface_hub import login

# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token=os.environ["HF_TOKEN"])  # Supply your own token, e.g. via the HF_TOKEN environment variable


class SourceMapper:
    """Maps source strings (e.g. "Kodeks pracy, Art. 1.") to stable integer ids."""

    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        # Avoid the defaultdict silently minting a new id for an unregistered source
        if not source or source not in self.source_to_idx:
            return -1
        return self.source_to_idx[source]

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
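

# Usage sketch (illustrative values; ids are assigned in insertion order):
#   mapper = SourceMapper()
#   mapper.add_source("Kodeks pracy, Art. 1.")
#   mapper.get_idx("Kodeks pracy, Art. 1.")  # -> 0
#   mapper.get_source(0)                     # -> "Kodeks pracy, Art. 1."
#   mapper.get_idx(None)                     # -> -1 (no source)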


def load_file_catalog(catalog_path):
    try:
        with open(catalog_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except Exception as e:
        print(f"Error loading file catalog: {e}")
        return {}
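

# Assumed shape of file_catalog.json (inferred from how it is used below;
# the entries here are illustrative):
# {
#     "kodeks_pracy": "Kodeks pracy",
#     "notatki": "Opracowanie własne"
# }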


def identify_legal_document(filename, file_catalog):
    # Look the document type up by lowercase base filename;
    # "Opracowanie własne" ("own compilation") is the fallback for uncatalogued files.
    base_name = os.path.splitext(filename)[0].lower()
    return file_catalog.get(base_name, "Opracowanie własne")


def extract_text_from_file(file_path):
    try:
        _, ext = os.path.splitext(file_path)
        ext = ext.lower()
        if ext in ['.txt', '.md']:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        elif ext == '.pdf':
            text = ""
            try:
                with open(file_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    for page in reader.pages:
                        text += page.extract_text() or ""
            except Exception as e:
                print(f"PDF error: {e}")
            return text
        elif ext in ['.doc', '.docx']:
            # Note: docx2txt only parses .docx; legacy .doc files will fail here
            return docx2txt.process(file_path)
        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
            return pytesseract.image_to_string(Image.open(file_path))
        else:
            print(f"Unsupported file format: {ext}")
            return ""
    except Exception as e:
        print(f"Text extraction error: {e}")
        return ""


def prepare_dataset(directory, catalog_path, source_mapper):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    print(f"\n{'='*50}\nDATA DIAGNOSTICS\n{'='*50}")
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            print(f"\nProcessing file: {file_path}")
            try:
                text = extract_text_from_file(file_path)
                if not text.strip():
                    print("Skipped - no text")
                    continue
                print(f"Text length: {len(text)} characters")
                doc_type = identify_legal_document(file, file_catalog)
                print(f"Detected document type: {doc_type}")
                if doc_type != "Opracowanie własne":
                    # Regex that tolerates several "Art. N." header formats
                    header_re = re.compile(r'(?i)(Art[^\S\n]*\.?[^\S\n]*\d+[^\S\n]*\.?)')
                    articles = [a.strip() for a in header_re.split(text) if a.strip()]
                    # re.split with a capturing group yields [preamble, header, body, ...];
                    # drop any preamble so headers and bodies pair up correctly
                    if articles and not header_re.fullmatch(articles[0]):
                        articles = articles[1:]
                    print(f"Found {len(articles)//2} articles")
                    for i in range(0, len(articles) - 1, 2):
                        article_number = articles[i]
                        article_content = articles[i + 1]
                        if len(article_content) < 50:
                            print(f"Skipped too-short article: {article_number}")
                            continue
                        source = f"{doc_type}, {article_number}"
                        print(f"Added article: {source}")
                        source_mapper.add_source(source)
                        data.append({
                            "text": f"{article_number} {article_content}",
                            "source_idx": source_mapper.get_idx(source)
                        })
                else:
                    # Uncatalogued files are split into fixed 512-character chunks
                    clean_text = re.sub(r'\s+', ' ', text).strip()
                    chunks = [clean_text[i:i + 512] for i in range(0, len(clean_text), 512)]
                    chunks = [c for c in chunks if c.strip()]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source_idx": -1
                        })
                    print(f"Added {len(chunks)} chunks")
            except Exception as e:
                print(f"Error while processing file: {e}")
                continue
    print("\nDataset preparation summary:")
    print(f"Total number of examples: {len(data)}")
    if data:
        print("Sample entry:")
        print(json.dumps(data[0], indent=2, ensure_ascii=False))
    else:
        print("NO DATA - check the diagnostics above")
    return data


class CustomModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        # padding_idx=-1 resolves to the last row (index 999), which is zero-initialised
        # and excluded from gradients; forward() routes unknown sources there
        self.source_embedding = nn.Embedding(1000, config.hidden_size, padding_idx=-1)
        # Freeze the base model; only the output embeddings remain trainable
        for param in self.base_model.parameters():
            param.requires_grad = False
        for param in self.base_model.get_output_embeddings().parameters():
            param.requires_grad = True

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        if source_idx is not None:
            # Send source_idx == -1 to the zero padding row instead of clamping it
            # onto the real source stored at index 0
            padding_row = self.source_embedding.num_embeddings - 1
            valid_indices = torch.where(
                source_idx < 0,
                torch.full_like(source_idx, padding_row),
                torch.clamp(source_idx, max=padding_row)
            )
            source_embeds = torch.nn.functional.normalize(
                self.source_embedding(valid_indices),
                p=2,
                dim=-1
            ).unsqueeze(1)
            # Broadcast the per-example source vector across all token positions
            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            return self.base_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                **kwargs
            )
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

    def generate(self, *args, **kwargs):
        return self.base_model.generate(*args, **kwargs)
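

# Shape sketch of the source-embedding injection above (batch B, sequence T, hidden H):
#   token embeddings: (B, T, H)
#   source_embeds:    (B, 1, H) -> broadcasts over T, shifting every token of an
#                                  example by the same unit-norm source vector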


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Pop the custom columns so they are not forwarded twice
        labels = inputs.pop("labels")
        source_idx = inputs.pop("source_idx", None)
        outputs = model(**inputs, labels=labels, source_idx=source_idx)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        # Qualitative evaluation: probe a few fixed questions instead of computing
        # a metric (the questions stay in Polish, matching the training corpus)
        val_questions = {
            "art1": "Jakie są prawa pracownika według art. 1?",
            "art2": "Kto jest pracownikiem według art. 2?",
            "art3": "Jakie są obowiązki pracodawcy według art. 3?"
        }
        self.model.eval()
        results = {}
        for key, question in val_questions.items():
            results[key] = self.generate_answer(question)
        print("\nValidation results:")
        for key, val in results.items():
            print(f"\n{val_questions[key]}")
            print(f"Answer: {val['answer'][:200]}...")
            print(f"Sources: {val['sources']}")
        return {f"{metric_key_prefix}_loss": 0.0}
    def generate_answer(self, question):
        tokenizer = self.tokenizer
        model = self.model
        device = model.base_model.device
        prompt = f"[PYTANIE PRAWNE] {question} [KONTEKST]"  # "[LEGAL QUESTION] ... [CONTEXT]"
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)
        with torch.no_grad():
            # Beam search is deterministic here: without do_sample=True the
            # temperature/top_k/top_p knobs are ignored by generate()
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.3,
                top_k=50,
                top_p=0.95,
                repetition_penalty=1.8,
                num_beams=3,
                no_repeat_ngram_size=4,
                early_stopping=True,
                pad_token_id=tokenizer.eos_token_id
            )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = answer.replace(prompt, "").strip()
        # Map article references mentioned in the answer back to catalogued sources
        sources = set()
        for match in re.finditer(r'(?i)art\.?\s*\d+\.?', answer):
            article_ref = match.group(0).strip().rstrip('.')
            for source in self.model.source_mapper.idx_to_source.values():
                if article_ref.lower() in source.lower():
                    sources.add(source)
        return {
            "answer": answer,
            "sources": list(sources) if sources else ["Opracowanie własne"]
        }
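

# generate_answer returns a dict such as (values illustrative):
#   {"answer": "Pracodawca jest obowiązany...", "sources": ["Kodeks pracy, Art. 3."]}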


def main():
    # Initialisation
    source_mapper = SourceMapper()
    model_name = "crumb/nano-mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    catalog_path = "file_catalog.json"
    data = prepare_dataset("files", catalog_path, source_mapper)
    if not data:
        print("\nNo training data! Check the files in the 'files' directory and the diagnostics above.")
        return
    dataset = Dataset.from_list(data)
    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )
        # map() runs per example (batched=False), so drop the singleton batch dimension
        return {
            "input_ids": tokenized["input_ids"][0],
            "attention_mask": tokenized["attention_mask"][0],
            "labels": tokenized["input_ids"][0].clone(),
            "source_idx": examples["source_idx"]
        }

    tokenized_dataset = dataset.map(tokenize_function, batched=False)
    def custom_collate_fn(features):
        return {
            "input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in features]),
            "attention_mask": torch.stack([torch.tensor(f["attention_mask"]) for f in features]),
            "labels": torch.stack([torch.tensor(f["labels"]) for f in features]),
            "source_idx": torch.tensor([f["source_idx"] for f in features], dtype=torch.long)
        }
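
    # Resulting batch shapes (B = batch size, T = 512 after max-length padding):
    #   input_ids / attention_mask / labels: (B, T) long tensors
    #   source_idx: (B,), where -1 marks "no catalogued source"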
    # Model
    config = AutoConfig.from_pretrained(model_name)  # avoids loading the weights twice
    model = CustomModel(model_name, config)
    model.source_mapper = source_mapper  # attach the source mapping to the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        weight_decay=0.01,
        warmup_ratio=0.1,
        fp16=torch.cuda.is_available(),
        logging_steps=10,
        save_strategy="epoch",
        evaluation_strategy="steps",
        eval_steps=500,
        report_to="none",
        remove_unused_columns=False  # keep the custom source_idx column
    )
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        eval_dataset=tokenized_dataset,  # satisfies the eval-strategy check; the overridden evaluate() ignores it
        data_collator=custom_collate_fn,
        tokenizer=tokenizer
    )
    print("\nStarting training...")
    trainer.train()
    trainer.evaluate()
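
    # Illustrative follow-up query (the question text is an arbitrary Polish example)
    demo = trainer.generate_answer("Jakie są obowiązki pracodawcy?")
    print(f"\nAnswer: {demo['answer']}")
    print(f"Sources: {demo['sources']}")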


if __name__ == "__main__":
    main()