This code works!

l.gabrysiak 2025-02-25 23:32:39 +01:00
parent 537e191d5f
commit a0aab164cb
1 changed file with 208 additions and 243 deletions

hft.py

@@ -1,296 +1,261 @@
import os
import re
import json
import numpy as np
import torch
import torch.nn as nn
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from collections import defaultdict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from huggingface_hub import login

# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Read the token from the environment (HF_TOKEN); never hard-code credentials.
login(token=os.environ.get("HF_TOKEN"))

class SourceMapper:
    """Bidirectional mapping between source labels and integer indices."""

    def __init__(self):
        # defaultdict hands out the next free index on first access
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]  # first access registers the index
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
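
# Minimal usage sketch (illustrative only; the source label is hypothetical):
#   mapper = SourceMapper()
#   mapper.add_source("Kodeks cywilny, Art. 415")
#   mapper.get_idx("Kodeks cywilny, Art. 415")   # -> 0
#   mapper.get_source(0)                         # -> "Kodeks cywilny, Art. 415"
#   mapper.get_idx(None)                         # -> -1 (no source)
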
def load_file_catalog(catalog_path):
    try:
        with open(catalog_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except Exception as e:
        print(f"Error loading file catalog: {str(e)}")
        return {}

def identify_legal_document(filename, file_catalog):
    # Look up the document type by lowercase base name; "Opracowanie własne"
    # ("own elaboration") is the sentinel for uncatalogued material.
    base_name = os.path.splitext(filename)[0].lower()
    return file_catalog.get(base_name, "Opracowanie własne")
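
# Assumed shape of file_catalog.json (hypothetical example; keys are lowercase
# file base names, values are the document names used as source labels):
#   {
#       "kodeks_cywilny": "Kodeks cywilny",
#       "ustawa_rodo": "Ustawa o ochronie danych osobowych"
#   }
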
def extract_text_from_file(file_path):
    try:
        _, ext = os.path.splitext(file_path)
        ext = ext.lower()
        if ext in ['.txt', '.md']:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        elif ext == '.pdf':
            text = ""
            try:
                with open(file_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    for page in reader.pages:
                        # extract_text() may return None for scanned pages
                        text += page.extract_text() or ""
            except Exception as e:
                print(f"PDF error: {str(e)}")
            return text
        elif ext in ['.doc', '.docx']:
            return docx2txt.process(file_path)
        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
            return pytesseract.image_to_string(Image.open(file_path))
        else:
            print(f"Unsupported file format: {ext}")
            return ""
    except Exception as e:
        print(f"Text extraction error: {str(e)}")
        return ""
def prepare_dataset(directory, catalog_path, source_mapper):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    print(f"\n{'='*50}\nDATA DIAGNOSTICS\n{'='*50}")
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            print(f"\nProcessing file: {file_path}")
            try:
                text = extract_text_from_file(file_path)
                if not text.strip():
                    print("Skipped - no text")
                    continue
                print(f"Text length: {len(text)} characters")
                doc_type = identify_legal_document(file, file_catalog)
                print(f"Detected document type: {doc_type}")
                if doc_type != "Opracowanie własne":
                    # Split statutes on article headers such as "Art. 415"
                    articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
                    articles = [a.strip() for a in articles if a.strip()]
                    print(f"Found {len(articles)} fragments")
                    for i in range(0, len(articles) - 1, 2):
                        article_number = articles[i]
                        article_content = articles[i + 1]
                        if len(article_content) < 50:
                            continue  # skip fragments too short to be useful
                        source = f"{doc_type}, {article_number}"
                        source_mapper.add_source(source)
                        data.append({
                            "text": f"{article_number} {article_content}",
                            "source_idx": source_mapper.get_idx(source)
                        })
                else:
                    # Uncatalogued material: normalise whitespace and chunk it
                    clean_text = re.sub(r'\s+', ' ', text).strip()
                    chunks = [clean_text[i:i + 512] for i in range(0, len(clean_text), 512)]
                    chunks = [c for c in chunks if c.strip()]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source_idx": -1  # -1 marks "no source"
                        })
                    print(f"Added {len(chunks)} chunks")
            except Exception as e:
                print(f"Error while processing file: {str(e)}")
                continue
    print(f"\nData preparation summary:")
    print(f"Total number of examples: {len(data)}")
    if data:
        print("Sample entry:")
        print(json.dumps(data[0], indent=2, ensure_ascii=False))
    else:
        print("NO DATA - check the diagnostics above")
    return data
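
# How the article regex behaves on a toy string (illustrative):
#   re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', "Art. 1 Foo. Art. 2 Bar.")
#   -> ['', 'Art. 1 ', 'Foo. ', 'Art. 2 ', 'Bar.']
# After stripping empty pieces, even indices hold article numbers and odd
# indices hold their contents, which is why the loop above steps by 2.
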
class CustomModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        # padding_idx=-1 resolves to the last row (9999), which stays zero and
        # receives no gradient, so it acts as a neutral "no source" embedding.
        self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
        # Freeze the base model; only the output embeddings stay trainable
        for param in self.base_model.parameters():
            param.requires_grad = False
        for param in self.base_model.get_output_embeddings().parameters():
            param.requires_grad = True

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        if source_idx is not None:
            # Route "no source" (-1) to the padding row rather than aliasing it
            # with real source index 0, then clamp everything into range.
            valid_indices = source_idx.clone()
            valid_indices[valid_indices < 0] = self.source_embedding.padding_idx
            valid_indices = torch.clamp(valid_indices, 0, self.source_embedding.num_embeddings - 1)
            source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            return self.base_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                **kwargs
            )
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

    def generate(self, *args, **kwargs):
        return self.base_model.generate(*args, **kwargs)
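
# Shape sketch for CustomModel.forward (B = batch size, T = sequence length,
# H = config.hidden_size):
#   input_ids   (B, T)  -> token embeddings  (B, T, H)
#   source_idx  (B,)    -> source_embeds     (B, 1, H) after unsqueeze(1)
# Broadcasting the (B, 1, H) source vector over T adds the same per-document
# signal to every token position.
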
class CustomDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        # Process the basic fields
        input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
        attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
        labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
        # Add source_idx if present
        if "source_idx" in examples[0]:
            source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
            batch["source_idx"] = source_idx
        return batch
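
# The override above is needed because the stock DataCollatorForLanguageModeling
# does not know how to batch the extra "source_idx" field. A quick sanity
# check (illustrative):
#   batch = CustomDataCollator(tokenizer=tokenizer, mlm=False).torch_call(examples)
#   batch["source_idx"].shape   # -> torch.Size([len(examples)])
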
def main():
    source_mapper = SourceMapper()
    model_name = "crumb/nano-mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Causal LMs usually ship without a pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    catalog_path = "file_catalog.json"
    data = prepare_dataset("files", catalog_path, source_mapper)

    if not data:
        print("\nNo data for training!")
        return

    # dataset = Dataset.from_list(data)
    dataset = Dataset.from_dict({k: [d[k] for d in data] for k in data[0]})

    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )
        return {
            # No squeeze(): with a batched map the tensors are already
            # (batch, 512), and squeezing would break a final batch of size 1.
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
            "labels": tokenized["input_ids"].clone(),
            "source_idx": examples["source_idx"]  # passed through without tensor conversion
        }
    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)

    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
    model.source_mapper = source_mapper  # keep the mapping around for decoding indices later
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
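
    # After map() the dataset carries text, source_idx, input_ids,
    # attention_mask and labels; remove_unused_columns=False below keeps
    # source_idx alive so the collator can hand it to CustomModel.forward.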
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        fp16=torch.cuda.is_available(),
        logging_steps=10,
        save_strategy="steps",
        save_steps=1000,
        report_to="none",
        remove_unused_columns=False
    )
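
    # Effective batch size: per_device_train_batch_size (2) x
    # gradient_accumulation_steps (4) = 8 sequences per optimizer step.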
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
    )

    print("\nStarting training...")
    trainer.train()

if __name__ == "__main__":
    main()
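
# A hedged inference sketch for after training (illustrative only; assumes the
# objects built in main() are in scope and reuses the earlier test prompt):
#   model.eval()
#   prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"  # "What are the penalties for non-compliance with the GDPR?"
#   inputs = tokenizer(prompt, return_tensors="pt").to(device)
#   with torch.no_grad():
#       out = model.generate(**inputs, max_length=256, do_sample=True,
#                            temperature=0.7, pad_token_id=tokenizer.eos_token_id)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))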