#!/usr/bin/python3
import os
import sys
import subprocess
import re

# SpamAssassin configuration file that this script parses and rewrites.
CUSTOM_CF_PATH = '/etc/mail/spamassassin/custom.cf'

def get_exim_conf():
    """Return the path of the Exim configuration file to work with.

    Two known locations are probed in order and the first one that exists
    wins. ``/etc/exim.conf`` is returned when neither exists; that fallback
    path itself is not checked.
    """
    candidates = (
        '/etc/exim4/exim4.conf.template',
        '/etc/exim/exim.conf',
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return '/etc/exim.conf'

# Resolved once when the script starts; the Exim helpers below read this path.
EXIM_CONF_PATH = get_exim_conf()
def get_exim_service():
    """Return the service name that matches the detected Exim layout.

    The result is ``'exim4'`` when the resolved configuration path contains
    the substring ``exim4`` and ``'exim'`` otherwise.
    """
    if 'exim4' in EXIM_CONF_PATH:
        return 'exim4'
    return 'exim'

def esc(text):
    """Return ``str(text)`` with XML-special characters escaped.

    ``&``, ``<``, ``>`` and ``"`` are replaced by their entity references;
    every other character is passed through unchanged.
    """
    entities = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;'}
    # One pass over the text, so entities produced here are never re-escaped.
    return str(text).translate(str.maketrans(entities))

def parse_config():
    """Load the current settings from custom.cf and the Exim configuration.

    Returns:
        A ``(config, extra_lines)`` tuple. ``config`` maps every form field
        to its value (strings, or lists for the whitelist and blacklist),
        starting from built-in defaults; ``sa_score_reject`` always comes
        from ``read_exim_sa_score_reject()``. ``extra_lines`` holds the
        custom.cf lines that do not map to a form field -- blank lines,
        comments, single-token lines and unrecognised directives -- with
        blank lines trimmed from both ends. ``loadplugin`` lines are
        dropped.
    """
    config = {
        'required_score': '5.0',
        'report_safe': '1',
        'use_bayes': '1',
        'bayes_auto_learn': '1',
        'ok_languages': '',
        'trusted_networks': '',
        'int_nets': '',
        'sa_score_reject': '50',
        'whitelist_from': [],
        'blacklist_from': [],
    }
    list_keys = ('whitelist_from', 'blacklist_from')
    extra_lines = []

    if os.path.exists(CUSTOM_CF_PATH):
        with open(CUSTOM_CF_PATH, 'r', encoding='utf-8', errors='replace') as cf:
            for raw in cf:
                text = raw.rstrip('\n')
                tokens = raw.strip().split(None, 1)
                # Blank lines, comments and bare keywords are kept verbatim.
                if len(tokens) < 2 or tokens[0].startswith('#'):
                    extra_lines.append(text)
                    continue
                key, value = tokens
                if key == 'internal_networks':
                    # The form field is named differently from the directive.
                    config['int_nets'] = value
                elif key in list_keys:
                    config[key].append(value)
                elif key in config:
                    config[key] = value
                elif key != 'loadplugin':
                    extra_lines.append(text)

    config['sa_score_reject'] = read_exim_sa_score_reject()

    # Trim blank lines from the beginning and the end of extra_lines.
    first, last = 0, len(extra_lines)
    while first < last and not extra_lines[first].strip():
        first += 1
    while last > first and not extra_lines[last - 1].strip():
        last -= 1
    return config, extra_lines[first:last]

def read_exim_sa_score_reject():
    """Return the SA_SCORE_REJECT value found in the Exim configuration.

    The file is scanned line by line and the digits of the first
    ``SA_SCORE_REJECT = <digits>`` occurrence are returned.

    Returns:
        The value as a string, or ``'50'`` when the file is missing or
        unreadable, or contains no such assignment.
    """
    pattern = re.compile(r'SA_SCORE_REJECT\s*=\s*(\d+)')
    try:
        with open(EXIM_CONF_PATH, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                match = pattern.search(line)
                if match:
                    return match.group(1)
    except OSError:
        # Missing or unreadable file: fall through to the default. Only I/O
        # errors are ignored; anything else is a real bug and propagates.
        pass
    return '50'

def write_exim_sa_score_reject(value):
    """Store *value* as SA_SCORE_REJECT in the Exim configuration file.

    Every existing ``SA_SCORE_REJECT = <digits>`` occurrence is rewritten.
    When the file contains no ``SA_SCORE_REJECT =`` at all, a definition is
    prepended as the first line. The file is rewritten only when its
    content actually changes.

    This is best-effort and never raises for I/O problems: a missing or
    unwritable file, or a value that is not a non-negative integer, leaves
    the file untouched.

    Args:
        value: New threshold; anything ``int()`` accepts after ``str()``.
    """
    try:
        number = int(str(value).strip())
    except ValueError:
        return
    if number < 0:
        return
    # Write canonical digits only (e.g. '+50' -> '50'), so the value stays
    # matchable by the \d+ patterns used to read and update it later.
    new_val = str(number)

    # surrogateescape and newline='' round-trip bytes that are not valid
    # UTF-8 and keep the existing line endings exactly as they are.
    io_args = {'encoding': 'utf-8', 'errors': 'surrogateescape', 'newline': ''}
    try:
        with open(EXIM_CONF_PATH, 'r', **io_args) as f:
            content = f.read()
        if re.search(r'SA_SCORE_REJECT\s*=', content):
            new_content = re.sub(
                r'SA_SCORE_REJECT\s*=\s*\d+',
                'SA_SCORE_REJECT = ' + new_val,
                content,
            )
        else:
            new_content = 'SA_SCORE_REJECT = ' + new_val + '\n' + content
        if new_content != content:
            with open(EXIM_CONF_PATH, 'w', **io_args) as f:
                f.write(new_content)
    except OSError:
        # Missing or unwritable configuration: nothing to update.
        pass

def print_form(config, extra, ok_msg=None, err_msg=None):
    """Print the complete ``sa_custom`` XML form document to stdout.

    The document contains the optional status message, the form layout, the
    current field values, the ``report_safe`` option list and the UI labels
    and hints.

    Args:
        config: Settings mapping with every key that ``parse_config`` fills.
        extra: Lines shown in the ``extra_config`` textarea.
        ok_msg: Optional success text, emitted inside an ``<ok>`` element.
        err_msg: Optional error text, emitted inside an ``<error>`` element.
    """
    bayes_on = config['use_bayes'].strip() == '1'
    auto_learn_on = config['bayes_auto_learn'].strip() == '1'
    emit = print  # lines are written as they are produced

    emit('<?xml version="1.0"?>')
    emit('<doc func="sa_custom">')
    if ok_msg:
        emit('<ok msg="yes"><msg name="body">' + esc(ok_msg) + '</msg></ok>')
    if err_msg:
        emit('<error><msg name="body">' + esc(err_msg) + '</msg></error>')

    # --- form layout ---
    text_input = '<field name="{0}"><input type="text" name="{0}" size="{1}"/></field>'
    checkbox = '<field name="{0}"><input type="checkbox" name="{0}" {1}/></field>'
    textarea = '<field name="{0}" fullwidth="yes"><textarea name="{0}" rows="{1}"/></field>'

    emit('<metadata type="form">')
    emit('<form>')
    for name, size in (('required_score', 5), ('sa_score_reject', 5)):
        emit(text_input.format(name, size))
    emit('<field name="report_safe"><select name="report_safe"/></field>')
    for name, size in (('ok_languages', 30), ('trusted_networks', 50), ('int_nets', 50)):
        emit(text_input.format(name, size))
    for name, flag in (('use_bayes', bayes_on), ('bayes_auto_learn', auto_learn_on)):
        emit(checkbox.format(name, 'checked="yes"' if flag else ''))
    for name, rows in (('whitelist_from', 6), ('blacklist_from', 6), ('extra_config', 10)):
        emit(textarea.format(name, rows))
    emit('<buttons>')
    for button in ('save', 'reset_bayes'):
        emit('<button name="{0}" type="ok"/>'.format(button))
    emit('</buttons>')
    emit('</form>')
    emit('</metadata>')

    # --- current values ---
    scalar_fields = (
        'required_score', 'sa_score_reject', 'report_safe',
        'ok_languages', 'trusted_networks', 'int_nets',
    )
    for name in scalar_fields:
        emit('<' + name + '>' + esc(config[name]) + '</' + name + '>')
    # Checkbox values are emitted only when the box is ticked.
    if bayes_on:
        emit('<use_bayes>on</use_bayes>')
    if auto_learn_on:
        emit('<bayes_auto_learn>on</bayes_auto_learn>')
    for name in ('whitelist_from', 'blacklist_from'):
        emit('<' + name + '>' + esc('\n'.join(config[name])) + '</' + name + '>')
    emit('<extra_config>' + esc('\n'.join(extra)) + '</extra_config>')

    # --- options of the report_safe select ---
    emit('<slist name="report_safe">')
    report_modes = (
        ('0', '0 - Add report to message body'),
        ('1', '1 - Attach original as MIME part (safest)'),
        ('2', '2 - Attach original as text MIME part'),
    )
    for key, label in report_modes:
        emit('<val key="' + key + '">' + label + '</val>')
    emit('</slist>')

    # --- labels and hints ---
    messages = (
        ('title', 'SpamAssassin Configuration'),
        ('required_score', 'Spam Threshold Score'),
        ('hint_required_score',
         'Messages with score above this value will be marked as spam. '
         'Lower = more aggressive filtering. Default: 5.0. '
         'Recommended range: 3.0 (strict) to 7.0 (loose).'),
        ('sa_score_reject', 'Exim SA_SCORE_REJECT'),
        ('hint_sa_score_reject',
         'Messages with score above this value will be rejected by Exim before delivery. '
         'Must be higher than required_score. Default: 50.'),
        ('report_safe', 'Report Format'),
        ('hint_report_safe',
         'How to handle spam messages: 0 = add report to body, '
         '1 = attach original as MIME part (safest), '
         '2 = attach original as text MIME part.'),
        ('ok_languages', 'Allowed Languages'),
        ('hint_ok_languages',
         'Only accept emails in these languages. '
         'ISO codes space-separated (e.g. ru en de). '
         'Empty = allow all. Requires TextCat plugin.'),
        ('trusted_networks', 'Trusted Networks'),
        ('hint_trusted_networks',
         'IP addresses or CIDR ranges that are fully trusted. '
         'Emails from these hosts skip all spam checks. '
         'Example: 192.168.1.0/24 10.0.0.1'),
        ('int_nets', 'Internal Networks'),
        ('hint_int_nets',
         'IP addresses or CIDR ranges considered internal. '
         'Must be a subnet of Trusted Networks. '
         'Example: 10.0.0.0/8 (if trusted is 10.0.0.0/8 or wider).'),
        ('use_bayes', 'Use Bayesian Filter'),
        ('hint_use_bayes',
         'Enable the self-learning Bayesian classifier. '
         'It analyzes words and phrases to improve spam detection over time. '
         'Recommended: ON.'),
        ('bayes_auto_learn', 'Auto-learn Bayes'),
        ('hint_bayes_auto_learn',
         'Automatically train the Bayesian filter on high-confidence emails '
         '(very high or very low spam score). Recommended: ON.'),
        ('whitelist_from', 'Whitelist'),
        ('hint_whitelist_from',
         'Email addresses or domains that will NEVER be marked as spam. '
         'One entry per line. '
         'Examples: *@yourcompany.com, partner@trusted-domain.com'),
        ('blacklist_from', 'Blacklist'),
        ('hint_blacklist_from',
         'Email addresses or domains that will ALWAYS be marked as spam. '
         'One entry per line. '
         'Examples: *@spam-domain.com, baduser@gmail.com'),
        ('extra_config', 'Additional Settings'),
        ('hint_extra_config',
         'Any additional SpamAssassin configuration directives '
         '(custom rules, score modifications, plugin settings, etc.). '
         'These lines will be appended to custom.cf as-is.'),
        ('msg_save', 'Save and Restart'),
        ('msg_reset_bayes', 'Reset Bayes DB'),
        ('hint_msg_reset_bayes',
         'Clear all Bayesian filter training data. '
         'Use this if the filter has learned incorrectly and needs a fresh start.'),
    )
    emit('<messages>')
    for name, text in messages:
        emit('<msg name="' + name + '">' + text + '</msg>')
    emit('</messages>')
    emit('</doc>')

def build_custom_cf(config, extra_lines):
    """Render the text of custom.cf from the settings and free-form lines.

    Args:
        config: Settings mapping; missing scalar keys fall back to defaults
            and missing or blank optional keys are omitted from the output.
        extra_lines: Additional directive lines appended after one blank
            separator line. Blank lines and lines starting with
            ``loadplugin`` are skipped, and trailing CR/LF characters (for
            example from CRLF form input) are removed. May be ``None``.

    Returns:
        The file content as a single newline-terminated string.
    """
    lines = []
    # TextCat is loaded only when a language filter is configured.
    if config.get('ok_languages', '').strip():
        lines.append('loadplugin Mail::SpamAssassin::Plugin::TextCat')

    scalar_defaults = (
        ('required_score', '5.0'),
        ('report_safe', '1'),
        ('use_bayes', '1'),
        ('bayes_auto_learn', '1'),
    )
    for key, default in scalar_defaults:
        lines.append(key + ' ' + config.get(key, default).strip())

    for key in ('ok_languages', 'trusted_networks'):
        val = config.get(key, '').strip()
        if val:
            lines.append(key + ' ' + val)

    # The form field 'int_nets' is written as the 'internal_networks' directive.
    int_val = config.get('int_nets', '').strip()
    if int_val:
        lines.append('internal_networks ' + int_val)

    for directive in ('whitelist_from', 'blacklist_from'):
        for addr in config.get(directive, []):
            if addr.strip():
                lines.append(directive + ' ' + addr.strip())

    # Filter first, so the blank separator is added only when at least one
    # extra line is really written.
    # NOTE(review): every user-supplied loadplugin line is dropped here, not
    # only the TextCat one generated above -- confirm this is intended.
    kept = [
        line.rstrip('\r\n')
        for line in (extra_lines or [])
        if line.strip() and not line.strip().startswith('loadplugin')
    ]
    if kept:
        lines.append('')
        lines.extend(kept)
    return '\n'.join(lines) + '\n'

def lint_config(config_content):
    """Check *config_content* with ``spamassassin --lint``.

    The candidate text is written to the live custom.cf for the duration of
    the lint run. Afterwards the previous file is always put back
    byte-for-byte (or the file is removed if it did not exist before), even
    when the lint command is missing, fails or times out.

    Args:
        config_content: Complete candidate text of custom.cf.

    Returns:
        ``(ok, message)``. ``ok`` is False when the lint output contains
        ``warn:`` or ``error:`` lines, and ``message`` then holds the last
        five of them. When the lint could not be run at all the result is
        ``(True, '')``, i.e. the check is skipped rather than failed.
    """
    try:
        try:
            with open(CUSTOM_CF_PATH, 'rb') as f:
                backup = f.read()
        except FileNotFoundError:
            backup = None
        try:
            with open(CUSTOM_CF_PATH, 'w') as f:
                f.write(config_content)
            result = subprocess.run(
                ['spamassassin', '--lint'],
                capture_output=True, text=True, timeout=30,
            )
        finally:
            # Restore the live configuration whatever happened above.
            if backup is not None:
                with open(CUSTOM_CF_PATH, 'wb') as f:
                    f.write(backup)
            elif os.path.exists(CUSTOM_CF_PATH):
                os.remove(CUSTOM_CF_PATH)
    except (OSError, subprocess.SubprocessError, UnicodeError):
        # Lint unavailable (no binary, timeout, I/O or encoding problem).
        return True, ''

    issues = []
    for line in result.stderr.split('\n') + result.stdout.split('\n'):
        lowered = line.lower()
        if line.strip() and ('warn:' in lowered or 'error:' in lowered):
            # Drop the leading "[pid] " style prefix when there is one.
            issues.append(line.split('] ', 1)[-1] if '] ' in line else line.strip())
    return not issues, '\n'.join(issues[-5:])

def save_custom_cf(config, extra_lines):
    """Overwrite custom.cf with the content rendered from the settings.

    Args:
        config: Settings mapping passed on to ``build_custom_cf``.
        extra_lines: Free-form lines passed on to ``build_custom_cf``.
    """
    # Render before opening: opening with 'w' truncates the file, so a
    # failure while rendering must not be able to wipe the live config.
    content = build_custom_cf(config, extra_lines)
    with open(CUSTOM_CF_PATH, 'w') as f:
        f.write(content)

def restart_svc(name):
    """Restart the system service *name*.

    ``systemctl restart <name>`` is tried first. ``service <name> restart``
    is only a fallback, used when the first command is missing, times out
    or exits with a non-zero status, so a service is not restarted twice.

    Args:
        name: Service name to restart.

    Returns:
        True when one of the commands exited with status 0, else False.
    """
    commands = (
        ['systemctl', 'restart', name],
        ['service', name, 'restart'],
    )
    for cmd in commands:
        try:
            # The command output is not used, so it is discarded.
            done = subprocess.run(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=30,
            )
        except (OSError, subprocess.SubprocessError):
            # Tool not installed or timed out: try the next one.
            continue
        if done.returncode == 0:
            return True
    return False

def restart_all():
    """Restart SpamAssassin, then the Exim service for the detected layout."""
    exim_service = get_exim_service()
    restart_svc('spamassassin')
    restart_svc(exim_service)

def reset_bayes():
    """Clear the Bayes database by running ``sa-learn --clear``.

    Best-effort: a missing ``sa-learn`` binary or a timeout (60 seconds)
    does not raise.

    Returns:
        True when the command exited with status 0, else False.
    """
    try:
        # The command output is not used, so it is discarded.
        done = subprocess.run(
            ['sa-learn', '--clear'],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=60,
        )
    except (OSError, subprocess.SubprocessError):
        return False
    return done.returncode == 0

# ========== MAIN ==========
# Request dispatch: save the submitted form, reset the Bayes database, or
# simply render the current settings. Any unexpected error is reported as an
# XML error document.
try:
    clicked = os.environ.get('PARAM_clicked_button', '')
    # NOTE(review): every other request value is read from a PARAM_*
    # variable; confirm that the panel really exports a bare `sok`.
    sok = os.environ.get('sok', '')
    is_save = (clicked == 'save') or (sok == 'ok' and clicked != 'reset_bayes')
    is_reset = (clicked == 'reset_bayes')

    if is_save:
        config, _ = parse_config()

        # Mandatory settings: an empty submission keeps the stored value.
        for field in ('required_score', 'report_safe', 'sa_score_reject'):
            val = os.environ.get('PARAM_' + field, '')
            if val:
                config[field] = val
        # Optional settings: a submitted empty value clears the setting
        # (the form documents "Empty = allow all" for ok_languages).
        for field in ('ok_languages', 'trusted_networks', 'int_nets'):
            if 'PARAM_' + field in os.environ:
                config[field] = os.environ['PARAM_' + field]

        # An unticked checkbox is anything other than the value 'on'.
        config['use_bayes'] = '1' if os.environ.get('PARAM_use_bayes') == 'on' else '0'
        config['bayes_auto_learn'] = '1' if os.environ.get('PARAM_bayes_auto_learn') == 'on' else '0'

        for field in ('whitelist_from', 'blacklist_from'):
            if 'PARAM_' + field in os.environ:
                submitted = os.environ['PARAM_' + field].split('\n')
                config[field] = [a.strip() for a in submitted if a.strip()]

        # Drop the CR of CRLF line endings so it never reaches custom.cf.
        extra_text = os.environ.get('PARAM_extra_config', '')
        new_extra = [line.rstrip('\r') for line in extra_text.split('\n')]

        # Validate the Exim threshold first: it is cheap, and the lint run
        # below temporarily rewrites the live custom.cf.
        try:
            sr = int(config['sa_score_reject'])
        except ValueError:
            sr = None
        if sr is None or not 1 <= sr <= 1000:
            failure = 'SA_SCORE_REJECT must be 1-1000'
        else:
            # Keep the canonical form (e.g. ' +50' -> '50').
            config['sa_score_reject'] = str(sr)
            ok, err = lint_config(build_custom_cf(config, new_extra))
            failure = None if ok else 'Lint failed:\n' + err

        if failure is None:
            save_custom_cf(config, new_extra)
            write_exim_sa_score_reject(config['sa_score_reject'])
            restart_all()
            config, extra = parse_config()
            print_form(config, extra, ok_msg='Saved and restarted')
        else:
            # Nothing was saved: show the stored settings with the error.
            config, extra = parse_config()
            print_form(config, extra, err_msg=failure)

    elif is_reset:
        reset_bayes()
        config, extra = parse_config()
        print_form(config, extra, ok_msg='Bayes reset')

    else:
        config, extra = parse_config()
        print_form(config, extra)

except Exception as e:
    print('<?xml version="1.0"?><doc func="sa_custom"><error><msg name="body">' + esc(str(e)) + '</msg></error></doc>')
