This commit is contained in:
l.gabrysiak 2025-02-25 23:17:07 +01:00
parent 5f06f859a5
commit b1512778d3
5 changed files with 240 additions and 208 deletions

hft.py (382 changed lines)

@@ -1,8 +1,6 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
-from datasets import Dataset
 import re
 import json
 import PyPDF2
@@ -10,14 +8,26 @@ import docx2txt
 import pytesseract
 from PIL import Image
 from collections import defaultdict
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    Trainer,  # used by the LegalTrainer subclass defined in train()
+    DataCollatorForLanguageModeling
+)
+from datasets import Dataset
 from huggingface_hub import login
 # Configuration
+os.environ['TORCH_USE_CUDA_DSA'] = '1'  # enable device-side assertions for CUDA debugging
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="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")
+login(token="YOUR_HUGGINGFACE_TOKEN")
-class SourceMapper:
-    def __init__(self):
-        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
-        self.idx_to_source = {}
+class LegalAITrainer:
+    def __init__(self):
+        self.source_mapper = defaultdict(lambda: len(self.source_mapper))
+        self.idx_to_source = {}
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    class SourceMapper:
+        def __init__(self):
+            self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
+            self.idx_to_source = {}
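+            # defaultdict(lambda: len(self.source_to_idx)) hands out the next
+            # consecutive index the first time an unseen source is looked up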
@@ -33,189 +43,132 @@ class SourceMapper:
         def get_source(self, idx):
             return self.idx_to_source.get(idx, "Unknown")
-def load_file_catalog(catalog_path):
-    try:
-        with open(catalog_path, 'r', encoding='utf-8') as file:
-            return json.load(file)
-    except Exception as e:
-        print(f"Error loading the file catalog: {str(e)}")
-        return {}
-def identify_legal_document(filename, file_catalog):
-    base_name = os.path.splitext(filename)[0].lower()
-    return file_catalog.get(base_name, "Opracowanie własne")
-def extract_text_from_file(file_path):
-    try:
-        _, ext = os.path.splitext(file_path)
-        ext = ext.lower()
-        if ext in ['.txt', '.md']:
-            with open(file_path, 'r', encoding='utf-8') as file:
-                return file.read()
-        elif ext == '.pdf':
-            text = ""
-            try:
-                with open(file_path, 'rb') as file:
-                    reader = PyPDF2.PdfReader(file)
-                    for page in reader.pages:
-                        text += page.extract_text() or ""
-            except Exception as e:
-                print(f"PDF error: {str(e)}")
-            return text
-        elif ext in ['.doc', '.docx']:
-            return docx2txt.process(file_path)
-        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-            return pytesseract.image_to_string(Image.open(file_path))
-        else:
-            print(f"Unsupported file format: {ext}")
-            return ""
-    except Exception as e:
-        print(f"Text extraction error: {str(e)}")
-        return ""
-def prepare_dataset(directory, catalog_path, source_mapper):
-    file_catalog = load_file_catalog(catalog_path)
-    data = []
-    print(f"\n{'='*50}\nDATA DIAGNOSTICS\n{'='*50}")
-    for root, _, files in os.walk(directory):
-        for file in files:
-            file_path = os.path.join(root, file)
-            print(f"\nProcessing file: {file_path}")
-            try:
-                text = extract_text_from_file(file_path)
-                if not text.strip():
-                    print("Skipped - no text")
-                    continue
-                print(f"Text length: {len(text)} characters")
-                doc_type = identify_legal_document(file, file_catalog)
-                print(f"Detected document type: {doc_type}")
-                if doc_type != "Opracowanie własne":
-                    articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
-                    articles = [a.strip() for a in articles if a.strip()]
-                    print(f"Found {len(articles)} fragments")
-                    for i in range(0, len(articles)-1, 2):
-                        article_number = articles[i]
-                        article_content = articles[i+1]
-                        if len(article_content) < 50:
-                            continue
-                        source = f"{doc_type}, {article_number}"
-                        source_mapper.add_source(source)
-                        data.append({
-                            "text": f"{article_number} {article_content}",
-                            "source_idx": source_mapper.get_idx(source)
-                        })
-                else:
-                    clean_text = re.sub(r'\s+', ' ', text).strip()
-                    chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
-                    chunks = [c for c in chunks if c.strip()]
-                    for chunk in chunks:
-                        data.append({
-                            "text": chunk,
-                            "source_idx": -1
-                        })
-                    print(f"Added {len(chunks)} chunks")
-            except Exception as e:
-                print(f"Error while processing the file: {str(e)}")
-                continue
-    print(f"\nData preparation summary:")
-    print(f"Total number of examples: {len(data)}")
-    if data:
-        print("Sample entry:")
-        print(json.dumps(data[0], indent=2, ensure_ascii=False))
-    else:
-        print("NO DATA - check the diagnostics above")
-    return data
-class CustomModel(nn.Module):
+    class LegalModel(nn.Module):
-    def __init__(self, model_name, config):
-        super().__init__()
-        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-        self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
-        for param in self.base_model.parameters():
-            param.requires_grad = False
-        for param in self.base_model.get_output_embeddings().parameters():
-            param.requires_grad = True
+        def __init__(self, model_name, config):
+            super().__init__()
+            self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+            # padding_idx=-1 resolves to the last row, which stays zero during training
+            self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
+            self.confidence_layer = nn.Linear(config.hidden_size, 1)
+            # Freeze base model
+            for param in self.base_model.parameters():
+                param.requires_grad = False
+            # Trainable components
+            for layer in [self.source_embedding, self.confidence_layer]:
+                for param in layer.parameters():
+                    param.requires_grad = True
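+            # Only the source embeddings and the 1-unit confidence head receive
+            # gradients; the wrapped language model stays frozen throughout training.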
-    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        if source_idx is not None:
-            valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
-            source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
-            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
-            return self.base_model(
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                labels=labels,
-                **kwargs
-            )
-        return self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
-        )
+        def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None):
+            if source_idx is not None:
+                # note: clamp() also maps the -1 "no source" marker onto index 0
+                source_idx = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
+                source_embeds = self.source_embedding(source_idx).unsqueeze(1)
+                inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
+                outputs = self.base_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    output_hidden_states=True  # required below; otherwise hidden_states is None
+                )
+            else:
+                outputs = self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    output_hidden_states=True
+                )
+            confidence = torch.sigmoid(self.confidence_layer(outputs.hidden_states[-1].mean(dim=1)))
+            return {
+                "loss": outputs.loss,
+                "logits": outputs.logits,
+                "confidence": confidence,
+                "hidden_states": outputs.hidden_states
+            }
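+        # forward() returns "loss" for Trainer's backward pass plus "confidence",
+        # which the custom compute_loss defined in train() consumes.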
-    def generate(self, *args, **kwargs):
-        return self.base_model.generate(*args, **kwargs)
-class CustomDataCollator(DataCollatorForLanguageModeling):
-    def torch_call(self, examples):
-        # Process the basic fields
-        input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
-        attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
-        labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
-        batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels
-        }
-        # Add source_idx if present
-        if "source_idx" in examples[0]:
-            source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
-            batch["source_idx"] = source_idx
-        return batch
+    def load_file_catalog(self, catalog_path):
+        try:
+            with open(catalog_path, 'r', encoding='utf-8') as f:
+                return json.load(f)
+        except Exception as e:
+            print(f"Error loading the catalog: {str(e)}")
+            return {}
+    def extract_text(self, file_path):
+        ext = os.path.splitext(file_path)[1].lower()
+        try:
+            if ext in ['.txt', '.md']:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    return f.read()
+            elif ext == '.pdf':
+                text = ""
+                with open(file_path, 'rb') as f:
+                    reader = PyPDF2.PdfReader(f)
+                    for page in reader.pages:
+                        text += page.extract_text() or ""
+                return text
+            elif ext in ['.doc', '.docx']:
+                return docx2txt.process(file_path)
+            elif ext in ['.jpg', '.jpeg', '.png']:
+                return pytesseract.image_to_string(Image.open(file_path))
+            else:
+                return ""
+        except Exception as e:
+            print(f"Error processing {file_path}: {str(e)}")
+            return ""
-def main():
-    source_mapper = SourceMapper()
-    model_name = "crumb/nano-mistral"
+    def prepare_data(self, data_dir, catalog_path):
+        catalog = self.load_file_catalog(catalog_path)
+        data = []
+        source_mapper = self.SourceMapper()
+        for root, _, files in os.walk(data_dir):
+            for file in files:
+                file_path = os.path.join(root, file)
+                text = self.extract_text(file_path)
+                if not text:
+                    continue
+                # files absent from the catalog default to "Opracowanie własne" (own work)
+                doc_type = catalog.get(os.path.splitext(file)[0].lower(), "Opracowanie własne")
+                if doc_type != "Opracowanie własne":
+                    articles = re.split(r'(?i)(Art\.\s*\d+[a-z]*)', text)
+                    # stop at len - 1 so a trailing "Art. N" with no body cannot index past the end
+                    for i in range(1, len(articles) - 1, 2):
+                        art_num = articles[i].strip()
+                        content = articles[i+1].strip()
+                        if len(content) < 100:
+                            continue
+                        source = f"{doc_type}, {art_num}"
+                        source_mapper.add_source(source)
+                        data.append({
+                            "text": f"[LEGAL] {art_num} {content}",
+                            "source_idx": source_mapper.get_idx(source),
+                            "is_legal": 1
+                        })
+                else:
+                    chunks = [f"[GENERAL] {text[i:i+512]}" for i in range(0, len(text), 512)]
+                    for chunk in chunks:
+                        data.append({
+                            "text": chunk,
+                            "source_idx": -1,
+                            "is_legal": 0
+                        })
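+        # Dataset.from_dict expects columns, so the row dicts are transposed into
+        # {"text": [...], "source_idx": [...], "is_legal": [...]} below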
+        return Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}), source_mapper
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    tokenizer.pad_token = tokenizer.eos_token
-    # Data preparation
-    catalog_path = "file_catalog.json"
-    data = prepare_dataset("files", catalog_path, source_mapper)
-    if not data:
-        print("\nNo data for training!")
-        return
-    #dataset = Dataset.from_list(data)
-    dataset = Dataset.from_dict({k: [d[k] for d in data] for k in data[0]})
-    def tokenize_function(examples):
-        tokenized = tokenizer(
-            examples["text"],
-            truncation=True,
-            padding="max_length",
-            max_length=512,
-            return_tensors="pt"
-        )
+    def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
+        # Data preparation
+        dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+        # Tokenization
+        def tokenize_fn(examples):
+            tokenized = tokenizer(
+                examples["text"],
+                padding="max_length",
+                truncation=True,
+                max_length=512,
+                return_tensors="pt"
+            )
@@ -223,39 +176,118 @@ def main():
"input_ids": tokenized["input_ids"].squeeze(), "input_ids": tokenized["input_ids"].squeeze(),
"attention_mask": tokenized["attention_mask"].squeeze(), "attention_mask": tokenized["attention_mask"].squeeze(),
"labels": tokenized["input_ids"].squeeze().clone(), "labels": tokenized["input_ids"].squeeze().clone(),
"source_idx": examples["source_idx"] # Dodano bez konwersji do tensora "source_idx": examples["source_idx"]
} }
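+        # With batched=True, dataset.map feeds tokenize_fn whole batches; the labels
+        # are a copy of input_ids, which causal LMs shift internally for the loss.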
-    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
-    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
-    model.source_mapper = source_mapper
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-    training_args = TrainingArguments(
-        output_dir="./results",
-        num_train_epochs=3,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        learning_rate=2e-5,
-        fp16=torch.cuda.is_available(),
-        logging_steps=10,
-        save_strategy="steps",
-        save_steps=1000,
-        report_to="none",
-        remove_unused_columns=False
-    )
+        tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
+        # Drop the raw columns that the collator and forward() cannot handle
+        tokenized_dataset = tokenized_dataset.remove_columns(["text", "is_legal"])
+        # Model initialization
+        config = AutoModelForCausalLM.from_pretrained(model_name).config
+        model = self.LegalModel(model_name, config).to(self.device)
+        # Training configuration
+        training_args = TrainingArguments(
+            output_dir="./legal_ai_model",
+            num_train_epochs=5,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=4,
+            learning_rate=2e-5,
+            fp16=torch.cuda.is_available(),
+            logging_steps=50,
+            save_strategy="steps",
+            save_steps=500,
+            report_to="none",
+            remove_unused_columns=False
+        )
-    trainer = Trainer(
+        # Custom Trainer
+        class LegalTrainer(Trainer):
+            def compute_loss(self, model, inputs, return_outputs=False):
+                outputs = model(**inputs)
+                # forward() returns a plain dict, so index it rather than using attribute access
+                loss = outputs["loss"]
+                # Confidence loss: the head should output 1 for sourced (legal) examples
+                target_conf = (inputs["source_idx"] != -1).float()
+                # squeeze(-1) keeps the batch dimension so shapes match for BCELoss
+                conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(-1), target_conf)
+                total_loss = loss + 0.7*conf_loss
+                return (total_loss, outputs) if return_outputs else total_loss
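+            # Trainer hands compute_loss the collated batch; source_idx survives
+            # collation because remove_unused_columns=False keeps custom columns.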
-        model=model,
-        args=training_args,
-        train_dataset=tokenized_dataset,
-        data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
-    )
-    print("\nStarting training...")
-    trainer.train()
+        # Training
+        trainer = LegalTrainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_dataset,
+            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+        )
+        print("Starting training...")
+        trainer.train()
+        # Save the model: LegalModel is a plain nn.Module, so save its state_dict
+        # (save_pretrained exists only on PreTrainedModel subclasses)
+        os.makedirs("./trained_legal_ai", exist_ok=True)
+        torch.save(model.state_dict(), "./trained_legal_ai/model.pt")
+        tokenizer.save_pretrained("./trained_legal_ai")
+        with open("./trained_legal_ai/source_mapper.json", "w") as f:
+            json.dump(source_mapper.idx_to_source, f)
+        print("Training finished and model saved!")
+    def generate_response(self, prompt, confidence_threshold=0.65):
+        # Load the model: rebuild the module and restore the saved state_dict
+        config = AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config
+        model = self.LegalModel("crumb/nano-mistral", config)
+        model.load_state_dict(torch.load("./trained_legal_ai/model.pt", map_location=self.device))
+        tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
+        model.to(self.device)
+        model.eval()
+        # Load the source mapping
+        with open("./trained_legal_ai/source_mapper.json", "r") as f:
+            source_mapper = json.load(f)
+        # Prepare the input
+        inputs = tokenizer(
+            f"[PROMPT] {prompt} [RESPONSE]",
+            return_tensors="pt",
+            max_length=512,
+            truncation=True
+        ).to(self.device)
+        # Generation, delegated to the wrapped base model, which provides generate()
+        with torch.no_grad():
+            outputs = model.base_model.generate(
+                input_ids=inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                max_length=512,
+                do_sample=True,
+                temperature=0.7,
+                top_k=50,
+                pad_token_id=tokenizer.eos_token_id,
+                output_scores=True,
+                return_dict_in_generate=True
+            )
+        # Analyze the results
+        full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+        # Heuristic confidence: sigmoid of the final-step EOS logit (the trained
+        # confidence head is not consulted during generation)
+        confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()
+        # Extract and verify cited sources
+        citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
+        verified = [c for c in citations if any(c in s for s in source_mapper.values())]
+        if confidence < confidence_threshold or not verified:
+            # "I cannot give a definitive answer based on the available data."
+            return "Nie mogę udzielić jednoznacznej odpowiedzi na podstawie dostępnych danych."
+        else:
+            # "Confirmed sources: ..."
+            return f"{full_text}\n\nPotwierdzone źródła: {', '.join(verified)}"
 if __name__ == "__main__":
-    main()
+    legal_ai = LegalAITrainer()
+    # Stage 1: training
+    legal_ai.train(
+        model_name="crumb/nano-mistral",
+        data_dir="./legal_docs",
+        catalog_path="./catalog.json"
+    )
+    # Stage 2: testing
+    # "How many days of leave am I entitled to after 5 years of full-time work?"
+    test_prompt = "Ile dni urlopu przysługuje po 5 latach pracy w pełnym wymiarze?"
+    print(legal_ai.generate_response(test_prompt))
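+    # Example catalog.json shape (an assumption inferred from load_file_catalog:
+    # keys are lowercase file basenames, values are document titles):
+    #   {"kodeks_pracy": "Kodeks pracy", "kodeks_cywilny": "Kodeks cywilny"}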