#!/usr/bin/python
#
#  Check SPF results and provide recommended action back to Postfix.
#
#  Tumgreyspf source
#  Copyright (c) 2004-2005, Sean Reifschneider, tummy.com, ltd.
#  <jafo@tummy.com>
#
#  pypolicyd-spf
#  Copyright (c) 2007, Scott Kitterman <scott@kitterman.com>
'''
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License version 2 as published 
    by the Free Software Foundation.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.'''

# Release version and date of this policy server.
__version__ = "0.4.1: August 12, 2007"

import syslog, os, sys, string, re, time, popen2, urllib, stat, errno, socket, spf
# Make the support module installed in /usr/local/lib/policy-spf importable.
sys.path.append('/usr/local/lib/policy-spf')
import policydspfsupp

# Send log messages to the mail facility, tagged with the program name and PID.
syslog.openlog(os.path.basename(sys.argv[0]), syslog.LOG_PID, syslog.LOG_MAIL)
# Install policydspfsupp's exception hook before any request is processed
# (presumably it logs uncaught exceptions -- see policydspfsupp to confirm).
policydspfsupp.setExceptHook()

#############################################
def cidrmatch(connectip, ipaddrs, n):
    """Return True if connectip is within the n-bit prefix of any address in ipaddrs.

    Derived from pyspf's cidrmatch.  Addresses are parsed with the standard
    socket module, so no pyspf internals are needed.

    connectip -- client address string, IPv4 dotted quad or IPv6.
    ipaddrs   -- list of address strings to compare against.  The list is
                 NOT modified (the previous version overwrote IPv6 entries
                 in place with their packed form).
    n         -- prefix length in bits applied to every comparison.

    Entries of the other address family than connectip are skipped; they
    used to reach the 128-bit converter with a 4-byte address and raise.
    A malformed entry simply does not match, and the remaining entries are
    still checked.  A malformed connectip or a negative n yields False.
    """
    import struct

    is_v6 = ':' in connectip
    if is_v6:
        width = 128
    else:
        width = 32

    def to_int(addr):
        # Convert an address of connectip's family to an unsigned integer.
        if is_v6:
            hi, lo = struct.unpack('!QQ', socket.inet_pton(socket.AF_INET6, addr))
            return (hi << 64) | lo
        return struct.unpack('!L', socket.inet_aton(addr))[0]

    full_mask = (1 << width) - 1
    try:
        netmask = ~(full_mask >> n) & full_mask
        network = netmask & to_int(connectip)
    except (socket.error, ValueError):
        return False

    for addr in ipaddrs:
        if (':' in addr) != is_v6:
            continue  # different address family: can never match
        try:
            if netmask & to_int(addr) == network:
                return True
        except (socket.error, ValueError):
            continue  # malformed entry never matches
    return False

def parse_cidr(cidr_ip):
    """Split CIDR notation into an (address, prefix_length) tuple.

    Accepts 'addr', 'addr/n', 'addr//n6' and 'addr/n//n6' (pyspf dual-cidr
    syntax); surrounding whitespace is ignored.

    For an IPv4 address the single-slash length is returned.  For an IPv6
    address the double-slash length is preferred, falling back to a
    single-slash length (so '2001:db8::/32' works).  When no length applies,
    the whole address is used: 32 for IPv4, 128 for IPv6.  The previous
    version returned None in those cases, which crashed callers that pass
    the length to int().  Derived from pyspf.
    """
    RE_DUAL_CIDR = re.compile(r'//(0|[1-9]\d*)$')
    RE_CIDR = re.compile(r'/(0|[1-9]\d*)$')
    # Config lists are comma separated and may carry spaces; the patterns
    # above are anchored at end of string, so strip first.
    cidr_ip = cidr_ip.strip()

    parts = RE_DUAL_CIDR.split(cidr_ip)
    if len(parts) == 3:
        cidr_ip, cidr6 = parts[0], int(parts[1])
    else:
        cidr6 = None

    parts = RE_CIDR.split(cidr_ip)
    if len(parts) == 3:
        cidr_ip, cidr = parts[0], int(parts[1])
    else:
        cidr = None

    if ':' not in cidr_ip:
        # IPv4 address.
        if cidr is None:
            cidr = 32
        return cidr_ip, cidr

    # IPv6 address.
    if cidr6 is None:
        cidr6 = cidr
    if cidr6 is None:
        cidr6 = 128
    return cidr_ip, cidr6

#############################################
def _client_details(data):
    """Return the 'client-ip=...; helo=...; envelope-from=...; receiver=...; '
    string embedded in headers and log lines.  Attributes missing from the
    request are shown as '<UNKNOWN>'."""
    return ('client-ip=%s; helo=%s; envelope-from=%s; receiver=%s; '
        % ( data.get('client_address', '<UNKNOWN>'),
            data.get('helo_name', '<UNKNOWN>'),
            data.get('sender', '<UNKNOWN>'),
            data.get('recipient', '<UNKNOWN>'),
            ))

def _ip_in_cidrs(ip, cidrs):
    """Return True if ip matches any CIDR string in cidrs.

    Entries are whitespace-stripped and empty entries ignored, so a config
    value such as '192.0.2.1, 10.0.0.0/8,' works.  If parse_cidr reports no
    prefix length (None), the entry is treated as a single host (/128 for
    IPv6, /32 for IPv4) instead of crashing in int(None).
    """
    for cidr in cidrs:
        cidr = cidr.strip()
        if not cidr:
            continue
        address, prefix = parse_cidr(cidr)
        if prefix is None:
            if ':' in address:
                prefix = 128
            else:
                prefix = 32
        # Pass a fresh list: cidrmatch may modify the list it is given.
        if cidrmatch(ip, [address], int(prefix)):
            return True
    return False

def spfcheck(data, instance_dict, configData):  #{{{1
    """Evaluate SPF for one Postfix policy request.

    data          -- dict of request attributes (client_address, helo_name,
                     sender, recipient, instance, ...).
    instance_dict -- results cached by Postfix 'instance' so that a
                     multi-recipient message is checked, and its header
                     prepended, only once.  Updated in place and returned.
    configData    -- settings: debugLevel, Whitelist, HELO_reject,
                     Mail_From_reject, TempError_Defer, PermError_Defer.

    Returns (action, text, instance_dict).  action is 'reject', 'defer',
    'prepend', 'dunno', None, or the string 'None'; text is the header or
    reason to hand back to Postfix.

    Each cached result is a list of:
        [0] SPF result (capitalized, e.g. 'Pass', 'Softfail')
        [1] SPF reason
        [2] Identity ('HELO' or 'Mail From')
        [3] Action based on local policy
        [4] Received-SPF header
    """
    debugLevel = configData.get('debugLevel', 0)
    ip = data.get('client_address')
    if ip is None:
        if debugLevel: syslog.syslog('spfcheck: No client address, exiting')
        return(( None, None, instance_dict ))

    # Do not check SPF for localhost addresses - add to skip addresses to
    # skip SPF for internal networks if desired.
    skip_addresses = ['127.0.0.0/8', '::ffff:127.0.0.0//104', '::1//128',]
    if _ip_in_cidrs(ip, skip_addresses):
        Header = ('X-Comment: SPF check N/A for local connections - '
            + _client_details(data))
        if debugLevel: syslog.syslog(Header)
        return (('prepend', Header, instance_dict ))

    # Whitelist designated IP addresses from SPF checks (e.g. secondary MX or
    # known forwarders).
    whitelist = configData.get('Whitelist')
    if whitelist and _ip_in_cidrs(ip, str(whitelist).split(',')):
        Header = ('X-Comment: SPF skipped for whitelisted relay - '
            + _client_details(data))
        if debugLevel: syslog.syslog(Header)
        return (('prepend', Header, instance_dict ))

    sender = data.get('sender')
    helo = data.get('helo_name')
    if not sender and not helo:
        if debugLevel: syslog.syslog('spfcheck: No sender or helo, exiting')
        return(( None, None, instance_dict ))

    instance = data.get('instance')
    # Postfix always provides instance; this fallback is only for testing.
    if not instance:
        import random
        instance = str(int(random.random()*100000))

    # Multi-recipient mail: reuse the first recipient's result so only one
    # header is prepended.
    if instance in instance_dict:
        cached_instance = instance_dict[instance]
        if cached_instance[3] == 'prepend':
            return(( 'dunno', 'Header already pre-pended', instance_dict ))
        return(( cached_instance[3], cached_instance[4], instance_dict ))

    # Header from the most recent check that produced one.  It stays None
    # if no check did; the old code then hit an unbound local.
    header = None
    helo_result = None

    # First do HELO check.
    if not helo:
        # No HELO name sent: use the sender's domain for the Mail From check.
        parts = sender.split('@', 1)
        if len(parts) < 2: helo = 'unknown'
        else: helo = parts[1]
    elif configData.get('HELO_reject') != 'No_Check':
        res = spf.check2(ip, 'postmaster@' + helo, helo)
        helo_result = [res[0], res[1], 'HELO']
        if debugLevel:
            syslog.syslog('spfcheck: pyspf result: "%s"' % str(helo_result))
        helo_result[0] = helo_result[0].lower().capitalize()
        result = helo_result[0]
        helo_reject = configData.get('HELO_reject')
        if helo_reject == 'Null' and sender:
            action = 'dunno'
        elif result == 'Temperror' and configData.get('TempError_Defer') == 'True':
            action = 'defer'
        elif result == 'Permerror' and configData.get('PermError_Defer') == 'True':
            action = 'reject'
        elif result == 'Permerror' and configData.get('PermError_Defer') == 'False':
            # Previously "helo_result('prepend')": called the list -> TypeError.
            action = 'prepend'
        elif helo_reject == 'Fail' and result == 'Fail':
            action = 'reject'
        elif helo_reject == 'SPF_Not_Pass' and result in ('Fail', 'Softfail', 'Neutral'):
            action = 'reject'
            helo_result[1] = 'HELO result rejected due to local policy (Not Pass/None)'
        else:
            action = 'prepend'
        helo_result.append(action)

        spfDetail = helo_result[2] + ' ' + _client_details(data)
        syslog.syslog(str(helo_result[1] + ":" + spfDetail))
        header = ('Received-SPF: ' + result + ' (' + helo_result[1] + ') '
            + spfDetail)
        helo_result.append(header)
        instance_dict[instance] = helo_result
        # Only act on the HELO result if it is authoritative.
        if action in ('reject', 'defer'):
            return(( action, header, instance_dict ))

    # Second do Mail From check.
    if not sender:
        # Null (or missing) sender: only the HELO result, if any, applies.
        # A missing sender used to be passed to spf.check2 as None.
        if helo_result is not None:
            return(( helo_result[3], header, instance_dict ))
    elif configData.get('Mail_From_reject') != 'No_Check':
        res = spf.check2(ip, sender, helo)
        mfrom_result = [res[0], res[1], 'Mail From']
        if debugLevel:
            syslog.syslog('spfcheck: pyspf result: "%s"' % str(mfrom_result))
        mfrom_result[0] = mfrom_result[0].lower().capitalize()
        result = mfrom_result[0]
        if result == 'Temperror' and configData.get('TempError_Defer') == 'True':
            action = 'defer'
        elif result == 'Permerror' and configData.get('PermError_Defer') == 'True':
            action = 'reject'
        elif result == 'Permerror' and configData.get('PermError_Defer') == 'False':
            action = 'prepend'
        elif configData.get('Mail_From_reject') == 'Fail' and result == 'Fail':
            action = 'reject'
        else:
            action = 'prepend'
        mfrom_result.append(action)
        # A 'None' Mail From result gets no header of its own; any HELO
        # header (and cached HELO result) is kept instead.
        if result != 'None':
            spfDetail = mfrom_result[2] + ' ' + _client_details(data)
            syslog.syslog(str(mfrom_result[1] + ":" + spfDetail))
            header = ('Received-SPF: ' + result + ' (' + mfrom_result[1] + ') '
                + spfDetail)
            mfrom_result.append(header)
            instance_dict[instance] = mfrom_result
        # Act on the Mail From result if it is authoritative.
        if action in ('reject', 'defer'):
            return(( action, header, instance_dict ))
        if header is not None:
            return(( 'prepend', header, instance_dict ))
    # NOTE(review): when Mail_From_reject is 'No_Check', a HELO 'prepend'
    # result is cached above but never returned.  Its header is therefore
    # never added, yet later recipients get 'Header already pre-pended'.
    # Confirm whether the HELO header should be returned here.
    return(( 'None', 'None', instance_dict ))

###################################################
#  load config file  {{{1
configFile = None
if len(sys.argv) > 1:
    if sys.argv[1] in ( '-?', '--help', '-h' ):
        # Same output as the former Python 2 print statement.
        sys.stdout.write('usage: policyd-spf [<configfilename>]\n')
        sys.exit(1)
    configFile = sys.argv[1]

configGlobal = policydspfsupp.processConfigFile(filename = configFile)

#  loop reading data  {{{1
debugLevel = configGlobal.get('debugLevel', 0)
if debugLevel >= 2: syslog.syslog('Starting')
# SPF results keyed by Postfix 'instance'; spfcheck() reads and updates it
# so a multi-recipient message is only checked once.
instance_dict = {}
# Attributes of the request currently being read ("name=value" lines).
data = {}
lineRx = re.compile(r'^\s*([^=\s]+)\s*=(.*)$')
# Reply line for each checker result; any other result means 'dunno'.
actionFormats = {
    'reject': 'action=550 %s\n\n',
    'prepend': 'action=prepend %s\n\n',
    'defer': 'action=defer_if_permit %s\n\n',
    'warn': 'action=warn %s\n\n',
    }
# Use readline() rather than "for line in sys.stdin": the Python 2 file
# iterator uses a read-ahead buffer and could block waiting for input that
# Postfix only sends after receiving our reply.
for line in iter(sys.stdin.readline, ''):
    line = line.rstrip()
    if debugLevel >= 4: syslog.syslog('Read line: "%s"' % line)

    #  end of entry  {{{2
    if not line:
        if debugLevel >= 4: syslog.syslog('Found the end of entry')
        configData = configGlobal
        if debugLevel >= 2: syslog.syslog('Config: %s' % str(configData))

        #  run the checkers  {{{3
        checkerValue, checkerReason, instance_dict = spfcheck(data,
                    instance_dict, configData)
        if configData.get('SPFSEEDONLY', 0):
            checkerValue = None
            checkerReason = None

        #  handle results  {{{3
        if checkerValue in actionFormats:
            sys.stdout.write(actionFormats[checkerValue] % checkerReason)
        else:
            sys.stdout.write('action=dunno\n\n')

        #  end of record  {{{3
        sys.stdout.flush()
        data = {}
        continue

    #  parse line  {{{2
    m = lineRx.match(line)
    if not m:
        syslog.syslog('ERROR: Could not match line "%s"' % line)
        continue

    #  save the string  {{{2
    key = m.group(1)
    value = m.group(2)
    # Keep these attributes' original case; lower-case all other values.
    if key not in [ 'protocol_state', 'protocol_name', 'queue_id' ]:
        value = value.lower()
    data[key] = value
