This code works!

l.gabrysiak 2025-02-25 23:32:39 +01:00
parent 537e191d5f
commit a0aab164cb
1 changed file with 208 additions and 243 deletions

hft.py

@@ -1,296 +1,261 @@
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import Dataset
import re
import json
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from collections import defaultdict
from huggingface_hub import login

# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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")
class SourceMapper:
    def __init__(self):
        # Indices are handed out lazily: the defaultdict assigns len(dict) to each new source.
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        # -1 is the sentinel for "no source" (general, non-statute text).
        # Note: looking up an unregistered source also assigns it a new index (defaultdict side effect).
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
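
For illustration, a minimal sketch of how SourceMapper hands out indices (not part of the commit; the source strings below are invented):

# Illustrative sketch only; the sample sources are made up.
mapper = SourceMapper()
mapper.add_source("Kodeks cywilny, Art. 415.")
mapper.add_source("Kodeks karny, Art. 1.")
print(mapper.get_idx("Kodeks cywilny, Art. 415."))  # 0 - first source registered
print(mapper.get_idx(None))                         # -1 - sentinel for "no source"
print(mapper.get_source(1))                         # "Kodeks karny, Art. 1."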
def load_file_catalog(catalog_path):
    try:
        with open(catalog_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except Exception as e:
        print(f"Error loading file catalog: {str(e)}")
        return {}

def identify_legal_document(filename, file_catalog):
    # The catalog maps lowercase basenames (without extension) to document labels;
    # "Opracowanie własne" ("own material") is the fallback for uncatalogued files.
    base_name = os.path.splitext(filename)[0].lower()
    return file_catalog.get(base_name, "Opracowanie własne")
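
The catalog file is a flat JSON map from lowercase basenames to document labels. A hypothetical example of its contents, written here as the dict json.load would return:

# Hypothetical catalog contents; the keys and labels are invented.
file_catalog = {
    "kodeks_cywilny": "Kodeks cywilny",
    "kodeks_karny": "Kodeks karny",
}
print(identify_legal_document("Kodeks_Cywilny.pdf", file_catalog))  # "Kodeks cywilny"
print(identify_legal_document("notatki.txt", file_catalog))         # "Opracowanie własne"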
def extract_text_from_file(file_path):
    try:
        _, ext = os.path.splitext(file_path)
        ext = ext.lower()

        if ext in ['.txt', '.md']:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        elif ext == '.pdf':
            text = ""
            try:
                with open(file_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    for page in reader.pages:
                        text += page.extract_text() or ""
            except Exception as e:
                print(f"PDF error: {str(e)}")
            return text
        elif ext in ['.doc', '.docx']:
            return docx2txt.process(file_path)
        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
            # OCR for scanned documents.
            return pytesseract.image_to_string(Image.open(file_path))
        else:
            print(f"Unsupported file format: {ext}")
            return ""
    except Exception as e:
        print(f"Text extraction error: {str(e)}")
        return ""
def prepare_dataset(directory, catalog_path, source_mapper):
    file_catalog = load_file_catalog(catalog_path)
    data = []

    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            print(f"\nProcessing file: {file_path}")

            try:
                text = extract_text_from_file(file_path)
                if not text.strip():
                    print("Skipped - no text")
                    continue

                print(f"Text length: {len(text)} characters")
                doc_type = identify_legal_document(file, file_catalog)
                print(f"Recognized document type: {doc_type}")

                if doc_type != "Opracowanie własne":
                    # Split statutes on article headers; the capturing group keeps
                    # the "Art. N" markers in the result list.
                    articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
                    articles = [a.strip() for a in articles if a.strip()]

                    print(f"Found {len(articles)} fragments")

                    # Pairs (header, content); assumes the text begins with an article header.
                    for i in range(0, len(articles)-1, 2):
                        article_number = articles[i]
                        article_content = articles[i+1]

                        if len(article_content) < 50:
                            continue

                        source = f"{doc_type}, {article_number}"
                        source_mapper.add_source(source)

                        data.append({
                            "text": f"{article_number} {article_content}",
                            "source_idx": source_mapper.get_idx(source)
                        })
                else:
                    # Non-statute material: normalize whitespace and cut into 512-char chunks.
                    clean_text = re.sub(r'\s+', ' ', text).strip()
                    chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
                    chunks = [c for c in chunks if c.strip()]

                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source_idx": -1
                        })
                    print(f"Added {len(chunks)} chunks")
            except Exception as e:
                print(f"Error while processing file: {str(e)}")
                continue

    print("\nData preparation summary:")
    print(f"Total examples: {len(data)}")
    if data:
        print("Sample entry:")
        print(json.dumps(data[0], indent=2, ensure_ascii=False))
    else:
        print("NO DATA - check the diagnostics above")

    return data
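
To make the splitting concrete, a small sketch with invented sample text showing what the article regex produces; note that the even/odd pairing above assumes the text begins with an article header:

# Illustrative sketch only; the sample text is invented.
sample = "Art. 1. First provision text. Art. 2. Second provision text."
parts = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', sample)
parts = [p.strip() for p in parts if p.strip()]
# parts == ['Art. 1.', 'First provision text.', 'Art. 2.', 'Second provision text.']
# Even indices are headers, odd indices are contents - hence range(0, len(parts)-1, 2).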
"text": Value("string"),
"source_idx": Value("int32"), class CustomModel(nn.Module):
"is_legal": Value("int32") def __init__(self, model_name, config):
}) super().__init__()
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
return Dataset.from_dict({ for param in self.base_model.parameters():
"text": [d["text"] for d in data], param.requires_grad = False
"source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32), for param in self.base_model.get_output_embeddings().parameters():
"is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32) param.requires_grad = True
}, features=features), source_mapper
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"): if source_idx is not None:
dataset, source_mapper = self.prepare_data(data_dir, catalog_path) valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
tokenizer = AutoTokenizer.from_pretrained(model_name) source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
tokenizer.pad_token = tokenizer.eos_token inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
return self.base_model(
def tokenize_fn(examples): inputs_embeds=inputs_embeds,
tokenized = tokenizer( attention_mask=attention_mask,
examples["text"], labels=labels,
padding="max_length", **kwargs
truncation=True,
max_length=512,
return_tensors="pt"
) )
return { return self.base_model(
"input_ids": tokenized["input_ids"].squeeze().tolist(), input_ids=input_ids,
"attention_mask": tokenized["attention_mask"].squeeze().tolist(), attention_mask=attention_mask,
"labels": tokenized["input_ids"].squeeze().clone().tolist(), labels=labels,
"source_idx": examples["source_idx"] **kwargs
}
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
class CustomDataCollator(DataCollatorForLanguageModeling):
def torch_call(self, examples):
batch = super().torch_call(examples)
if "source_idx" in examples[0]:
batch["source_idx"] = torch.tensor(
[ex["source_idx"] for ex in examples],
dtype=torch.int32
)
return batch
config = AutoModelForCausalLM.from_pretrained(model_name).config
model = self.LegalModel(model_name, config).to(self.device)
training_args = TrainingArguments(
output_dir="./legal_ai_model",
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=2e-5,
fp16=torch.cuda.is_available(),
logging_steps=50,
save_strategy="steps",
save_steps=500,
report_to="none",
remove_unused_columns=False
) )
def generate(self, *args, **kwargs):
return self.base_model.generate(*args, **kwargs)
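
A self-contained toy sketch (invented sizes, plain PyTorch, no pretrained weights) of the embedding-injection trick used in forward() above, showing how one source vector per example is broadcast across the whole sequence:

# Illustrative sketch only; all sizes and values are made up.
hidden_size, vocab_size, num_sources = 8, 16, 10
token_embedding = nn.Embedding(vocab_size, hidden_size)
source_embedding = nn.Embedding(num_sources, hidden_size, padding_idx=-1)

input_ids = torch.randint(0, vocab_size, (2, 5))   # batch of 2, sequence length 5
source_idx = torch.tensor([3, -1])                 # -1 means "no source"

valid = torch.clamp(source_idx, 0, num_sources - 1)  # -1 becomes 0 (collides with source 0)
inputs_embeds = token_embedding(input_ids) + source_embedding(valid).unsqueeze(1)
print(inputs_embeds.shape)  # torch.Size([2, 5, 8]) - one source vector added at every position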
class CustomDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        # Process the basic fields
        input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
        attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
        labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

        # Add source_idx if present
        if "source_idx" in examples[0]:
            source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
            batch["source_idx"] = source_idx

        return batch
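
The collator does nothing but stack per-example lists into batch tensors; the same stacking in isolation, on two invented pre-tokenized examples:

# Illustrative sketch only; two fake examples as they come out of dataset.map.
examples = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3], "source_idx": 0},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 0], "labels": [4, 5, 6], "source_idx": -1},
]
input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
print(input_ids.shape, source_idx)  # torch.Size([2, 3]) tensor([ 0, -1])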
def main():
    source_mapper = SourceMapper()
    model_name = "crumb/nano-mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    catalog_path = "file_catalog.json"
    data = prepare_dataset("files", catalog_path, source_mapper)

    if not data:
        print("\nNo data to train on!")
        return

    # dataset = Dataset.from_list(data)
    dataset = Dataset.from_dict({k: [d[k] for d in data] for k in data[0]})

    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )
        return {
            "input_ids": tokenized["input_ids"].squeeze(),
            "attention_mask": tokenized["attention_mask"].squeeze(),
            "labels": tokenized["input_ids"].squeeze().clone(),
            "source_idx": examples["source_idx"]  # Added without converting to a tensor
        }

    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)

    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
    model.source_mapper = source_mapper
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        fp16=torch.cuda.is_available(),
        logging_steps=10,
        save_strategy="steps",
        save_steps=1000,
        report_to="none",
        remove_unused_columns=False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
    )

    print("\nStarting training...")
    trainer.train()
if __name__ == "__main__":
    main()
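
The script stops after training and ships no inference path. A hypothetical snippet that could be appended at the end of main() to sample from the trained model, reusing main()'s model, tokenizer and device (the prompt and generation parameters below are invented for illustration):

# Hypothetical inference sketch; not part of the commit.
prompt = "Jakie są przesłanki odpowiedzialności deliktowej?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=256,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))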