l.gabrysiak 2025-02-25 23:17:07 +01:00
parent 5f06f859a5
commit b1512778d3
5 changed files with 240 additions and 208 deletions

hft.py

import os
import torch
import torch.nn as nn
import re
import json
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from collections import defaultdict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,  # needed by the LegalTrainer subclass defined in train()
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from huggingface_hub import login

# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="TWÓJ_TOKEN_HUGGINGFACE")  # placeholder: your Hugging Face token
class LegalAITrainer:
    def __init__(self):
        self.source_mapper = defaultdict(lambda: len(self.source_mapper))
        self.idx_to_source = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class SourceMapper:
        def __init__(self):
            self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
            self.idx_to_source = {}

        def add_source(self, source):
            if source and source not in self.source_to_idx:
                idx = self.source_to_idx[source]
                self.idx_to_source[idx] = source

        def get_idx(self, source):
            return self.source_to_idx[source] if source else -1

        def get_source(self, idx):
            return self.idx_to_source.get(idx, "Unknown")
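    # Hypothetical usage sketch: the defaultdict hands out indices on first
    # access, so add_source() must run before get_idx() to keep idx_to_source
    # in sync with source_to_idx.
    #   sm = LegalAITrainer.SourceMapper()
    #   sm.add_source("Kodeks pracy, Art. 154")
    #   sm.get_idx("Kodeks pracy, Art. 154")   # -> 0
    #   sm.get_source(0)                        # -> "Kodeks pracy, Art. 154"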
    class LegalModel(nn.Module):
        def __init__(self, model_name, config):
            super().__init__()
            self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
            self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
            self.confidence_layer = nn.Linear(config.hidden_size, 1)

            # Freeze base model
            for param in self.base_model.parameters():
                param.requires_grad = False

            # Trainable components
            for layer in [self.source_embedding, self.confidence_layer]:
                for param in layer.parameters():
                    param.requires_grad = True

        # **kwargs absorbs extra dataset columns (e.g. is_legal) that may be
        # passed through; they are not forwarded to the base model
        def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
            if source_idx is not None:
                source_idx = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings - 1)
                source_embeds = self.source_embedding(source_idx).unsqueeze(1)
                inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
                outputs = self.base_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True  # the confidence head below needs hidden states
                )
            else:
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True
                )
            # Confidence score from the mean of the final hidden layer
            confidence = torch.sigmoid(self.confidence_layer(outputs.hidden_states[-1].mean(dim=1)))
            return {
                "loss": outputs.loss,
                "logits": outputs.logits,
                "confidence": confidence,
                "hidden_states": outputs.hidden_states
            }

        def generate(self, *args, **kwargs):
            # Delegate generation to the underlying causal LM
            return self.base_model.generate(*args, **kwargs)
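    # Shape note for LegalModel.forward (assuming batch B, sequence length T,
    # hidden size H): token embeddings are [B, T, H] and source_embeds is
    # [B, H] unsqueezed to [B, 1, H], so the addition broadcasts one source
    # vector across every token position of its sequence.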
    def load_file_catalog(self, catalog_path):
        try:
            with open(catalog_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading catalog: {str(e)}")
            return {}
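    # The catalog is assumed to map lower-cased file basenames to document
    # titles, e.g. (hypothetical contents of catalog.json):
    #   {
    #     "kodeks_pracy": "Kodeks pracy",
    #     "kodeks_cywilny": "Kodeks cywilny"
    #   }
    # Files whose basename is missing fall back to "Opracowanie własne"
    # ("own compilation") in prepare_data() below.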
    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext in ['.txt', '.md']:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
            elif ext == '.pdf':
                text = ""
                with open(file_path, 'rb') as f:
                    reader = PyPDF2.PdfReader(f)
                    for page in reader.pages:
                        text += page.extract_text() or ""
                return text
            elif ext in ['.doc', '.docx']:
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return pytesseract.image_to_string(Image.open(file_path))
            else:
                return ""
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""
    def prepare_data(self, data_dir, catalog_path):
        catalog = self.load_file_catalog(catalog_path)
        data = []
        source_mapper = self.SourceMapper()

        for root, _, files in os.walk(data_dir):
            for file in files:
                file_path = os.path.join(root, file)
                text = self.extract_text(file_path)
                if not text:
                    continue

                # "Opracowanie własne" ("own compilation") marks uncatalogued files
                doc_type = catalog.get(os.path.splitext(file)[0].lower(), "Opracowanie własne")

                if doc_type != "Opracowanie własne":
                    # Split on article headers; the capturing group keeps each
                    # header at an odd index with its body right after it
                    articles = re.split(r'(?i)(Art\.\s*\d+[a-z]*)', text)
                    for i in range(1, len(articles), 2):
                        art_num = articles[i].strip()
                        content = articles[i + 1].strip()
                        if len(content) < 100:
                            continue
                        source = f"{doc_type}, {art_num}"
                        source_mapper.add_source(source)
                        data.append({
                            "text": f"[LEGAL] {art_num} {content}",
                            "source_idx": source_mapper.get_idx(source),
                            "is_legal": 1
                        })
                else:
                    chunks = [f"[GENERAL] {text[i:i+512]}" for i in range(0, len(text), 512)]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source_idx": -1,
                            "is_legal": 0
                        })

        if not data:
            raise ValueError("No training data found - check data_dir and catalog_path")
        return Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}), source_mapper
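    # A minimal sketch of how the split above behaves (the loop relies on
    # headers landing at indices 1, 3, ... and bodies at 2, 4, ...):
    #   re.split(r'(?i)(Art\.\s*\d+[a-z]*)', "Wstęp Art. 1 tekst Art. 2a tekst")
    #   -> ['Wstęp ', 'Art. 1', ' tekst ', 'Art. 2a', ' tekst']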
    def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
        # Prepare data
        dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        # Tokenization
        def tokenize_fn(examples):
            tokenized = tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            return {
                "input_ids": tokenized["input_ids"].squeeze(),
                "attention_mask": tokenized["attention_mask"].squeeze(),
                "labels": tokenized["input_ids"].squeeze().clone(),
                "source_idx": examples["source_idx"]
            }

        tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)

        # Model initialisation (the full model is loaded once just for its config)
        config = AutoModelForCausalLM.from_pretrained(model_name).config
        model = self.LegalModel(model_name, config).to(self.device)

        # Training configuration
        training_args = TrainingArguments(
            output_dir="./legal_ai_model",
            num_train_epochs=5,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-5,
            fp16=torch.cuda.is_available(),
            logging_steps=50,
            save_strategy="steps",
            save_steps=500,
            report_to="none",
            remove_unused_columns=False  # keep source_idx (and is_legal) for compute_loss
        )
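        # With per_device_train_batch_size=2 and gradient_accumulation_steps=4,
        # each optimiser step sees an effective batch of 2 * 4 = 8 sequences.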
        # Custom Trainer with an auxiliary confidence loss
        class LegalTrainer(Trainer):
            def compute_loss(self, model, inputs, return_outputs=False):
                outputs = model(**inputs)
                loss = outputs["loss"]  # LegalModel.forward returns a plain dict

                # Confidence loss: the head should predict whether the sample
                # came from a catalogued legal source (source_idx != -1)
                target_conf = (inputs["source_idx"] != -1).float()
                conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(), target_conf)

                total_loss = loss + 0.7 * conf_loss
                return (total_loss, outputs) if return_outputs else total_loss

        # Training - note the stock collator must pass source_idx through to
        # the model for the loss above to work
        trainer = LegalTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        )
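        # Worked example of the combined objective (illustrative numbers):
        # with a language-modelling loss of 2.0 and a confidence BCE of 0.5,
        # total_loss = 2.0 + 0.7 * 0.5 = 2.35, so the auxiliary confidence
        # term is down-weighted relative to the LM loss.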
print("Rozpoczęcie treningu...")
trainer.train()
batch = { # Zapisz model
"input_ids": input_ids, model.save_pretrained("./trained_legal_ai")
"attention_mask": attention_mask, tokenizer.save_pretrained("./trained_legal_ai")
"labels": labels with open("./trained_legal_ai/source_mapper.json", "w") as f:
} json.dump(source_mapper.idx_to_source, f)
print("Trening zakończony i model zapisany!")
    def generate_response(self, prompt, confidence_threshold=0.65):
        # Load the trained model; LegalModel is a plain nn.Module, so rebuild
        # it and load the state dict saved by train() above
        config = AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config
        model = self.LegalModel("crumb/nano-mistral", config)
        model.load_state_dict(torch.load("./trained_legal_ai/legal_model.pt", map_location=self.device))
        model.to(self.device)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")

        # Load the source mapping
        with open("./trained_legal_ai/source_mapper.json", "r") as f:
            source_mapper = json.load(f)

        # Prepare the input
        inputs = tokenizer(
            f"[PROMPT] {prompt} [RESPONSE]",
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(self.device)
        # Generation
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                pad_token_id=tokenizer.eos_token_id,
                output_scores=True,
                return_dict_in_generate=True
            )

        # Analyse the output: the sigmoid of the final-step EOS logit serves
        # as a rough confidence proxy
        full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()

        # Extract cited articles and verify them against the known sources
        citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
        verified = [c for c in citations if any(c in s for s in source_mapper.values())]
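        # Sketch of the verification step (hypothetical values): if the model
        # emits "... zgodnie z Art. 154 ..." and source_mapper contains
        # "Kodeks pracy, Art. 154", then citations == ["Art. 154"] and the
        # substring check marks it verified; citations absent from every
        # stored source remain unconfirmed.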
        if confidence < confidence_threshold or not verified:
            # Polish: "I cannot give a definitive answer based on the available data."
            return "Nie mogę udzielić jednoznacznej odpowiedzi na podstawie dostępnych danych."
        else:
            # "Potwierdzone źródła" = "Confirmed sources"
            return f"{full_text}\n\nPotwierdzone źródła: {', '.join(verified)}"
if __name__ == "__main__":
    legal_ai = LegalAITrainer()

    # Stage 1: training
    legal_ai.train(
        model_name="crumb/nano-mistral",
        data_dir="./legal_docs",
        catalog_path="./catalog.json"
    )

    # Stage 2: testing
    # Polish: "How many days of leave are due after 5 years of full-time work?"
    test_prompt = "Ile dni urlopu przysługuje po 5 latach pracy w pełnym wymiarze?"
    print(legal_ai.generate_response(test_prompt))