# hft.py
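"""Fine-tunes a causal LM on legal documents, attaching per-source embeddings
and a confidence head, then generates answers whose article citations are
verified against the sources seen during training."""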

import os
import torch
import torch.nn as nn
import re
import json
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from collections import defaultdict
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset, Features, Value, Sequence
from huggingface_hub import login
# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Read the token from the environment (assumed HF_TOKEN variable) instead of
# hard-coding a secret into the source.
login(token=os.environ["HF_TOKEN"])
class LegalAITrainer:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class SourceMapper:
        """Maps citation sources to stable integer indices and back."""

        def __init__(self):
            self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
            self.idx_to_source = {}

        def add_source(self, source):
            if source and source not in self.source_to_idx:
                idx = self.source_to_idx[source]
                self.idx_to_source[idx] = source

        def get_idx(self, source):
            # .get() avoids the defaultdict silently minting a new index for
            # a source that was never registered via add_source().
            return self.source_to_idx.get(source, -1) if source else -1

        def get_source(self, idx):
            return self.idx_to_source.get(idx, "Unknown")
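    # Usage sketch (illustrative source string): add_source("Kodeks pracy, Art. 15")
    # registers the citation once; get_idx() then returns its stable index for
    # the embedding table, and get_source() inverts the mapping at answer time.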
    class LegalModel(nn.Module):
        def __init__(self, model_name, config):
            super().__init__()
            self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
            # padding_idx=-1 resolves to the last embedding row, reserved here
            # for "no source".
            self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
            self.confidence_layer = nn.Linear(config.hidden_size, 1)
            # Freeze the base model
            for param in self.base_model.parameters():
                param.requires_grad = False
            # Trainable components
            for layer in [self.source_embedding, self.confidence_layer]:
                for param in layer.parameters():
                    param.requires_grad = True

        def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None):
            if source_idx is not None:
                # Send "no source" (-1) to the padding row instead of clamping
                # it onto the real source at index 0.
                pad_idx = self.source_embedding.padding_idx
                source_idx = torch.where(
                    source_idx < 0,
                    torch.full_like(source_idx, pad_idx),
                    torch.clamp(source_idx, 0, self.source_embedding.num_embeddings - 1)
                )
                # source_idx is (batch, 1), so the embedding output already
                # broadcasts over the sequence; no extra unsqueeze is needed.
                source_embeds = self.source_embedding(source_idx)
                inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
                outputs = self.base_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True  # required by the confidence head
                )
            else:
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True
                )
            conf_logits = self.confidence_layer(outputs.hidden_states[-1].mean(dim=1))
            return {
                "loss": outputs.loss,
                "logits": outputs.logits,
                "confidence": torch.sigmoid(conf_logits),
                "confidence_logits": conf_logits,  # raw logits for an autocast-safe loss
                "hidden_states": outputs.hidden_states
            }
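    # Shape sketch (batch B, sequence length T): token embeddings are
    # (B, T, hidden) and the source embedding is (B, 1, hidden), so the same
    # source vector is added to every token of its document chunk.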
    def load_file_catalog(self, catalog_path):
        try:
            with open(catalog_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading catalog: {str(e)}")
            return {}
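    # Expected catalog.json shape (illustrative, not from the source): lowercase
    # file stems mapped to document types, e.g. {"kodeks_pracy": "Kodeks pracy"}.
    # Stems missing from the catalog fall back to "Opracowanie własne"
    # ("own material"), i.e. non-statute text.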
    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext in ['.txt', '.md']:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
            elif ext == '.pdf':
                text = ""
                with open(file_path, 'rb') as f:
                    reader = PyPDF2.PdfReader(f)
                    for page in reader.pages:
                        text += page.extract_text() or ""
                return text
            elif ext in ['.doc', '.docx']:
                # docx2txt only reads .docx; legacy .doc files will fall
                # through to the error handler below.
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return pytesseract.image_to_string(Image.open(file_path))
            else:
                return ""
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""
    def prepare_data(self, data_dir, catalog_path):
        catalog = self.load_file_catalog(catalog_path)
        data = []
        source_mapper = self.SourceMapper()
        for root, _, files in os.walk(data_dir):
            for file in files:
                file_path = os.path.join(root, file)
                text = self.extract_text(file_path)
                if not text:
                    continue
                doc_type = catalog.get(os.path.splitext(file)[0].lower(), "Opracowanie własne")
                if doc_type != "Opracowanie własne":
                    # Split on article headers, keeping the "Art. N" markers.
                    articles = re.split(r'(?i)(Art\.\s*\d+[a-z]*)', text)
                    for i in range(1, len(articles), 2):
                        art_num = articles[i].strip()
                        content = articles[i + 1].strip()
                        if len(content) < 100:
                            continue
                        source = f"{doc_type}, {art_num}"
                        source_mapper.add_source(source)
                        data.append({
                            "text": f"[LEGAL] {art_num} {content}",
                            "source_idx": source_mapper.get_idx(source),
                            "is_legal": 1
                        })
                else:
                    chunks = [f"[GENERAL] {text[i:i + 512]}" for i in range(0, len(text), 512)]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source_idx": -1,
                            "is_legal": 0
                        })
        features = Features({
            "text": Value("string"),
            "source_idx": Sequence(Value("int32")),
            "is_legal": Value("int32")
        })
        return Dataset.from_dict({
            "text": [d["text"] for d in data],
            "source_idx": [[d["source_idx"]] for d in data],  # list of lists, matching the Sequence feature
            "is_legal": [d["is_legal"] for d in data]
        }, features=features), source_mapper
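    # re.split with a capturing group alternates text and markers, e.g.
    # "wstęp Art. 15 treść Art. 16 treść" ->
    # ["wstęp ", "Art. 15", " treść ", "Art. 16", " treść"], which is why the
    # loop above reads article numbers from the odd indices.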
    def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
        dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        def tokenize_fn(examples):
            tokenized = tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=512
            )
            # The LM collator below (mlm=False) builds the labels, so only the
            # source index has to be carried through. It already arrives as a
            # length-1 sequence per example; rewrapping it would nest it twice.
            tokenized["source_idx"] = examples["source_idx"]
            return tokenized

        # Drop the raw columns so the collator only receives tensor-friendly
        # fields (the model's forward() accepts no "text" or "is_legal").
        tokenized_dataset = dataset.map(
            tokenize_fn,
            batched=True,
            batch_size=16,
            remove_columns=["text", "is_legal"]
        )
        # AutoConfig avoids downloading the full weights a second time just to
        # read the configuration.
        config = AutoConfig.from_pretrained(model_name)
        model = self.LegalModel(model_name, config).to(self.device)
        training_args = TrainingArguments(
            output_dir="./legal_ai_model",
            num_train_epochs=3,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-5,
            fp16=torch.cuda.is_available(),
            logging_steps=50,
            save_strategy="steps",
            save_steps=500,
            report_to="none",
            remove_unused_columns=False  # keep source_idx in the batches
        )
        class LegalTrainer(Trainer):
            # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that
            # newer transformers versions pass to compute_loss.
            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                outputs = model(**inputs)
                loss = outputs["loss"]
                # The confidence head should output 1.0 exactly when a real
                # source index is present; squeeze both sides to shape (B,).
                target_conf = (inputs["source_idx"].squeeze(-1) != -1).float()
                # BCEWithLogitsLoss on the raw logits is fp16/autocast-safe,
                # unlike BCELoss on sigmoid outputs.
                conf_loss = nn.BCEWithLogitsLoss()(outputs["confidence_logits"].squeeze(-1), target_conf)
                # The LM loss stays dominant; 0.7 weights the auxiliary term.
                total_loss = loss + 0.7 * conf_loss
                return (total_loss, outputs) if return_outputs else total_loss
        trainer = LegalTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        )
        print("Starting training...")
        trainer.train()
        # LegalModel is a plain nn.Module, so persist a state dict; it has no
        # save_pretrained() of its own.
        os.makedirs("./trained_legal_ai", exist_ok=True)
        torch.save(model.state_dict(), "./trained_legal_ai/pytorch_model.bin")
        tokenizer.save_pretrained("./trained_legal_ai")
        with open("./trained_legal_ai/source_mapper.json", "w") as f:
            json.dump(source_mapper.idx_to_source, f)
        print("Training finished!")
    def generate_response(self, prompt, confidence_threshold=0.65):
        config = AutoConfig.from_pretrained("crumb/nano-mistral")
        model = self.LegalModel("crumb/nano-mistral", config).to(self.device)
        # Restore the trained source-embedding and confidence-head weights.
        state_dict = torch.load("./trained_legal_ai/pytorch_model.bin", map_location=self.device)
        model.load_state_dict(state_dict)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
        with open("./trained_legal_ai/source_mapper.json", "r") as f:
            source_mapper = json.load(f)
        inputs = tokenizer(
            f"[PROMPT] {prompt} [RESPONSE]",
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            # Generation runs through the frozen base model; the wrapper
            # itself has no generate().
            outputs = model.base_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                pad_token_id=tokenizer.eos_token_id,
                output_scores=True,
                return_dict_in_generate=True
            )
        full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        # Rough heuristic: sigmoid of the final EOS logit stands in for the
        # trained confidence head at generation time.
        confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()
        citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
        # Keep only citations that match a source seen during training.
        verified = [c for c in citations if any(c in s for s in source_mapper.values())]
        if confidence < confidence_threshold or not verified:
            return "I cannot give a definitive answer based on the available data."
        else:
            return f"{full_text}\n\nVerified sources: {', '.join(verified)}"
if __name__ == "__main__":
    legal_ai = LegalAITrainer()
    # Training
    legal_ai.train(
        model_name="crumb/nano-mistral",
        data_dir="./legal_docs",
        catalog_path="./catalog.json"
    )
    # Test prompt (Polish): "What are the employer's occupational health and
    # safety obligations?"
    test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?"
    print(legal_ai.generate_response(test_prompt))