This code works!

l.gabrysiak 2025-02-25 23:32:39 +01:00
parent 537e191d5f
commit a0aab164cb
1 changed file with 208 additions and 243 deletions

hft.py

@@ -1,31 +1,22 @@
 import os
 import torch
 import torch.nn as nn
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
+from datasets import Dataset
 import re
 import json
-import numpy as np
 import PyPDF2
 import docx2txt
 import pytesseract
 from PIL import Image
 from collections import defaultdict
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    TrainingArguments,
-    Trainer,
-    DataCollatorForLanguageModeling
-)
-from datasets import Dataset, Features, Value
 from huggingface_hub import login
+# Configuration
+os.environ['TORCH_USE_CUDA_DSA'] = '1'  # enable CUDA device-side assertions for easier debugging
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 login(token=os.environ["HF_TOKEN"])  # read the token from the environment; never hard-code secrets
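+# SourceMapper hands out a stable integer id per legal source: the defaultdict
+# assigns len(source_to_idx) as the next free index on first lookup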
-class LegalAITrainer:
-    def __init__(self):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class SourceMapper:
     def __init__(self):
         self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
@@ -42,255 +33,229 @@ class LegalAITrainer:
     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
-    class LegalModel(nn.Module):
-        def __init__(self, model_name, config):
-            super().__init__()
-            self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-            self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
-            self.confidence_layer = nn.Linear(config.hidden_size, 1)
-            for param in self.base_model.parameters():
-                param.requires_grad = False
-            for layer in [self.source_embedding, self.confidence_layer]:
-                for param in layer.parameters():
-                    param.requires_grad = True
-        def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None):
-            if source_idx is not None:
-                source_idx = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
-                source_embeds = self.source_embedding(source_idx).unsqueeze(1)
-                inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
-                outputs = self.base_model(
-                    inputs_embeds=inputs_embeds,
-                    attention_mask=attention_mask,
-                    labels=labels
-                )
-            else:
-                outputs = self.base_model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    labels=labels
-                )
-            confidence = torch.sigmoid(self.confidence_layer(outputs.hidden_states[-1].mean(dim=1)))
-            return {
-                "loss": outputs.loss,
-                "logits": outputs.logits,
-                "confidence": confidence,
-                "hidden_states": outputs.hidden_states
-            }
-    def load_file_catalog(self, catalog_path):
-        try:
-            with open(catalog_path, 'r', encoding='utf-8') as f:
-                return json.load(f)
-        except Exception as e:
-            print(f"Błąd ładowania katalogu: {str(e)}")
-            return {}
-    def extract_text(self, file_path):
-        ext = os.path.splitext(file_path)[1].lower()
-        try:
-            if ext in ['.txt', '.md']:
-                with open(file_path, 'r', encoding='utf-8') as f:
-                    return f.read()
-            elif ext == '.pdf':
-                text = ""
-                with open(file_path, 'rb') as f:
-                    reader = PyPDF2.PdfReader(f)
-                    for page in reader.pages:
-                        text += page.extract_text() or ""
-                return text
-            elif ext in ['.doc', '.docx']:
-                return docx2txt.process(file_path)
-            elif ext in ['.jpg', '.jpeg', '.png']:
-                return pytesseract.image_to_string(Image.open(file_path))
-            else:
-                return ""
-        except Exception as e:
-            print(f"Błąd przetwarzania {file_path}: {str(e)}")
-            return ""
+def load_file_catalog(catalog_path):
+    try:
+        with open(catalog_path, 'r', encoding='utf-8') as file:
+            return json.load(file)
+    except Exception as e:
+        print(f"Błąd wczytywania katalogu plików: {str(e)}")
+        return {}
+def identify_legal_document(filename, file_catalog):
+    base_name = os.path.splitext(filename)[0].lower()
+    return file_catalog.get(base_name, "Opracowanie własne")
+def extract_text_from_file(file_path):
+    try:
+        _, ext = os.path.splitext(file_path)
+        ext = ext.lower()
+        if ext in ['.txt', '.md']:
+            with open(file_path, 'r', encoding='utf-8') as file:
+                return file.read()
+        elif ext == '.pdf':
+            text = ""
+            try:
+                with open(file_path, 'rb') as file:
+                    reader = PyPDF2.PdfReader(file)
+                    for page in reader.pages:
+                        text += page.extract_text() or ""
+            except Exception as e:
+                print(f"Błąd PDF: {str(e)}")
+            return text
+        elif ext in ['.doc', '.docx']:
+            return docx2txt.process(file_path)
+        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
+            return pytesseract.image_to_string(Image.open(file_path))
+        else:
+            print(f"Nieobsługiwany format pliku: {ext}")
+            return ""
+    except Exception as e:
+        print(f"Błąd ekstrakcji tekstu: {str(e)}")
+        return ""
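+# prepare_dataset walks the data directory, extracts text with the helpers
+# above, splits statute files into per-article examples and chunks everything
+# else into 512-character windows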
-    def prepare_data(self, data_dir, catalog_path):
-        catalog = self.load_file_catalog(catalog_path)
-        data = []
-        source_mapper = self.SourceMapper()
-        for root, _, files in os.walk(data_dir):
-            for file in files:
-                file_path = os.path.join(root, file)
-                text = self.extract_text(file_path)
-                if not text:
-                    continue
-                doc_type = catalog.get(os.path.splitext(file)[0].lower(), "Opracowanie własne")
-                if doc_type != "Opracowanie własne":
-                    articles = re.split(r'(?i)(Art\.\s*\d+[a-z]*)', text)
-                    for i in range(1, len(articles), 2):
-                        art_num = articles[i].strip()
-                        content = articles[i+1].strip()
-                        if len(content) < 100:
-                            continue
-                        source = f"{doc_type}, {art_num}"
-                        source_mapper.add_source(source)
-                        data.append({
-                            "text": f"[LEGAL] {art_num} {content}",
-                            "source_idx": source_mapper.get_idx(source),
-                            "is_legal": 1
-                        })
-                else:
-                    chunks = [f"[GENERAL] {text[i:i+512]}" for i in range(0, len(text), 512)]
-                    for chunk in chunks:
-                        data.append({
-                            "text": chunk,
-                            "source_idx": -1,
-                            "is_legal": 0
-                        })
-        features = Features({
-            "text": Value("string"),
-            "source_idx": Value("int32"),
-            "is_legal": Value("int32")
-        })
-        return Dataset.from_dict({
-            "text": [d["text"] for d in data],
-            "source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32),
-            "is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32)
-        }, features=features), source_mapper
-    def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
-        dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
+def prepare_dataset(directory, catalog_path, source_mapper):
+    file_catalog = load_file_catalog(catalog_path)
+    data = []
+    print(f"\n{'='*50}\nDIAGNOSTYKA DANYCH\n{'='*50}")
+    for root, _, files in os.walk(directory):
+        for file in files:
+            file_path = os.path.join(root, file)
+            print(f"\nPrzetwarzanie pliku: {file_path}")
+            try:
+                text = extract_text_from_file(file_path)
+                if not text.strip():
+                    print("Pominięto - brak tekstu")
+                    continue
+                print(f"Długość tekstu: {len(text)} znaków")
+                doc_type = identify_legal_document(file, file_catalog)
+                print(f"Rozpoznany typ dokumentu: {doc_type}")
+                if doc_type != "Opracowanie własne":
+                    articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
+                    articles = [a.strip() for a in articles if a.strip()]
+                    print(f"Znaleziono {len(articles)} fragmentów")
+                    for i in range(0, len(articles)-1, 2):
+                        article_number = articles[i]
+                        article_content = articles[i+1]
+                        if len(article_content) < 50:
+                            continue
+                        source = f"{doc_type}, {article_number}"
+                        source_mapper.add_source(source)
+                        data.append({
+                            "text": f"{article_number} {article_content}",
+                            "source_idx": source_mapper.get_idx(source)
+                        })
+                else:
+                    clean_text = re.sub(r'\s+', ' ', text).strip()
+                    chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
+                    chunks = [c for c in chunks if c.strip()]
+                    for chunk in chunks:
+                        data.append({
+                            "text": chunk,
+                            "source_idx": -1
+                        })
+                    print(f"Dodano {len(chunks)} chunków")
+            except Exception as e:
+                print(f"Błąd podczas przetwarzania pliku: {str(e)}")
+                continue
+    print(f"\nPodsumowanie przygotowania danych:")
+    print(f"Łączna liczba przykładów: {len(data)}")
+    if data:
+        print("Przykładowy wpis:")
+        print(json.dumps(data[0], indent=2, ensure_ascii=False))
+    else:
+        print("BRAK DANYCH - sprawdź diagnostykę powyżej")
+    return data
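+# CustomModel freezes the pretrained backbone except for its output embeddings
+# and adds a trainable per-source embedding onto the token embeddings, so each
+# example carries a learned signal for the statute it came from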
+class CustomModel(nn.Module):
+    def __init__(self, model_name, config):
+        super().__init__()
+        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+        self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
+        for param in self.base_model.parameters():
+            param.requires_grad = False
+        for param in self.base_model.get_output_embeddings().parameters():
+            param.requires_grad = True
+    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
+        if source_idx is not None:
+            valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
+            source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
+            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
+            return self.base_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                labels=labels,
+                **kwargs
+            )
+        return self.base_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+            **kwargs
+        )
+    def generate(self, *args, **kwargs):
+        return self.base_model.generate(*args, **kwargs)
+class CustomDataCollator(DataCollatorForLanguageModeling):
+    def torch_call(self, examples):
+        # Process the basic fields
+        input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
+        attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
+        labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
+        batch = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels
+        }
+        # Add source_idx if it is present
+        if "source_idx" in examples[0]:
+            source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
+            batch["source_idx"] = source_idx
+        return batch
+def main():
+    source_mapper = SourceMapper()
+    model_name = "crumb/nano-mistral"
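+    # causal-LM tokenizers often ship without a pad token, so EOS is reused below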
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
-        def tokenize_fn(examples):
+    # Data preparation
+    catalog_path = "file_catalog.json"
+    data = prepare_dataset("files", catalog_path, source_mapper)
+    if not data:
+        print("\nBrak danych do treningu!")
+        return
+    # dataset = Dataset.from_list(data)
+    dataset = Dataset.from_dict({k: [d[k] for d in data] for k in data[0]})
+    def tokenize_function(examples):
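+        # fixed 512-token window; labels duplicate input_ids for the standard
+        # causal-LM objective, and source_idx is carried through unchanged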
         tokenized = tokenizer(
             examples["text"],
-            padding="max_length",
             truncation=True,
+            padding="max_length",
             max_length=512,
             return_tensors="pt"
         )
         return {
-            "input_ids": tokenized["input_ids"].squeeze().tolist(),
-            "attention_mask": tokenized["attention_mask"].squeeze().tolist(),
-            "labels": tokenized["input_ids"].squeeze().clone().tolist(),
-            "source_idx": examples["source_idx"]
+            "input_ids": tokenized["input_ids"].squeeze(),
+            "attention_mask": tokenized["attention_mask"].squeeze(),
+            "labels": tokenized["input_ids"].squeeze().clone(),
+            "source_idx": examples["source_idx"]  # added without converting to a tensor
         }
-        tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
+    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
-        class CustomDataCollator(DataCollatorForLanguageModeling):
-            def torch_call(self, examples):
-                batch = super().torch_call(examples)
-                if "source_idx" in examples[0]:
-                    batch["source_idx"] = torch.tensor(
-                        [ex["source_idx"] for ex in examples],
-                        dtype=torch.int32
-                    )
-                return batch
-        config = AutoModelForCausalLM.from_pretrained(model_name).config
-        model = self.LegalModel(model_name, config).to(self.device)
+    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
+    model.source_mapper = source_mapper
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
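+    # effective batch size: 2 per device x 4 gradient-accumulation steps = 8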
-        training_args = TrainingArguments(
-            output_dir="./legal_ai_model",
-            num_train_epochs=3,
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            learning_rate=2e-5,
-            fp16=torch.cuda.is_available(),
-            logging_steps=50,
-            save_strategy="steps",
-            save_steps=500,
-            report_to="none",
-            remove_unused_columns=False
-        )
+    training_args = TrainingArguments(
+        output_dir="./results",
+        num_train_epochs=3,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=4,
+        learning_rate=2e-5,
+        fp16=torch.cuda.is_available(),
+        logging_steps=10,
+        save_strategy="steps",
+        save_steps=1000,
+        report_to="none",
+        remove_unused_columns=False
+    )
-        class LegalTrainer(Trainer):
-            def compute_loss(self, model, inputs, return_outputs=False):
-                outputs = model(**inputs)
-                loss = outputs["loss"]
-                target_conf = (inputs["source_idx"] != -1).float()
-                conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(), target_conf)
-                total_loss = loss + 0.7 * conf_loss
-                return (total_loss, outputs) if return_outputs else total_loss
-        trainer = LegalTrainer(
-            model=model,
-            args=training_args,
-            train_dataset=tokenized_dataset,
-            data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
-        )
-        print("Rozpoczęcie treningu...")
-        trainer.train()
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset,
+        data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
+    )
+    print("\nRozpoczęcie treningu...")
+    trainer.train()
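+    # NOTE: this revision stops after training; persisting the run would be a
+    # follow-up, e.g. trainer.save_model("./results") and tokenizer.save_pretrained("./results")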
-        model.save_pretrained("./trained_legal_ai")
-        tokenizer.save_pretrained("./trained_legal_ai")
-        with open("./trained_legal_ai/source_mapper.json", "w") as f:
-            json.dump(source_mapper.idx_to_source, f)
-        print("Trening zakończony!")
-    def generate_response(self, prompt, confidence_threshold=0.65):
-        model = self.LegalModel.from_pretrained(
-            "./trained_legal_ai",
-            config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config
-        ).to(self.device)
-        tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
-        with open("./trained_legal_ai/source_mapper.json", "r") as f:
-            source_mapper = json.load(f)
-        inputs = tokenizer(
-            f"[PROMPT] {prompt} [RESPONSE]",
-            return_tensors="pt",
-            max_length=512,
-            truncation=True
-        ).to(self.device)
-        with torch.no_grad():
-            outputs = model.generate(
-                input_ids=inputs.input_ids,
-                attention_mask=inputs.attention_mask,
-                max_length=512,
-                do_sample=True,
-                temperature=0.7,
-                top_k=50,
-                pad_token_id=tokenizer.eos_token_id,
-                output_scores=True,
-                return_dict_in_generate=True
-            )
-        full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
-        confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()
-        citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
-        verified = [c for c in citations if any(c in s for s in source_mapper.values())]
-        if confidence < confidence_threshold or not verified:
-            return "Nie mogę udzielić jednoznacznej odpowiedzi na podstawie dostępnych danych."
-        else:
-            return f"{full_text}\n\nPotwierdzone źródła: {', '.join(verified)}"
 if __name__ == "__main__":
-    legal_ai = LegalAITrainer()
-    legal_ai.train(
-        model_name="crumb/nano-mistral",
-        data_dir="./legal_docs",
-        catalog_path="./catalog.json"
-    )
-    test_prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"
-    print(legal_ai.generate_response(test_prompt))
+    main()