#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Milter that rejects mails submitted by SASL-authenticated users with forged
addresses in both
	* the envelope (MAIL FROM protocol stage)
	* the "From" header (DATA protocol stage)

Legitimate email addresses are found as 'mailPrimaryAddress' and
'mailAlternativeAddress' in LDAP.

INSTALLATION:
	1) This is intended to run on a UCS server and requires
	univention-mail-postfix to be installed.
	2) Install python-libmilter:
	# wget https://raw.githubusercontent.com/crustymonkey/python-libmilter/master/libmilter.py \
	-O /usr/local/lib/python2.7/site-packages/libmilter.py
	3) Install the milter (this file):
	# mkdir /usr/local/share/milters
	# cp [this file] /usr/local/share/milters/
	4) Install, enable and start the milter service:
	# cp no_forged_from_milter.service /etc/systemd/system/
	# systemctl daemon-reload
	# systemctl enable no_forged_from_milter.service
	# systemctl start no_forged_from_milter.service
	5) Check if installation worked:
	# service no_forged_from_milter status
	# grep no_forged_from /var/log/mail.log
	6) Configure Postfix to use the milter:
	# echo 'smtpd_milters = inet:localhost:5656' >> /etc/postfix/main.cf.local
	# ucr commit /etc/postfix/main.cf
	# service postfix restart

Copyright (c) 2018 Daniel Tröder, daniel@admin-box.com
Licensed under the MIT License (MIT)
SPDX short identifier: MIT
License text: https://opensource.org/licenses/MIT

python-libmilter is licensed under the GPLv3 (https://www.gnu.org/licenses/gpl-3.0.en.html)
"""
import os
import re
import signal
import sys
import syslog
import traceback
from email.utils import getaddresses

from ldap.filter import filter_format
import libmilter as lm
import univention.uldap


class NoForgedFromMilter(lm.ForkMixin, lm.MilterProtocol):
	"""
	Milter that rejects authenticated submissions whose envelope sender
	(MAIL FROM) or "From" header address is not one of the authenticated
	user's 'mailPrimaryAddress' / 'mailAlternativeAddress' values in LDAP.

	Mails without SASL authentication information ('auth_authen' macro)
	are passed through unchecked.
	"""
	# LDAP connection cached on the class, valid only in process _lo_pid.
	_lo = None
	# PID of the process that opened _lo (see get_lo()).
	_lo_pid = None

	def __init__(self, opts=0, protos=0):
		lm.MilterProtocol.__init__(self, opts, protos)
		lm.ForkMixin.__init__(self)
		self.clear_variables()

	@classmethod
	def log(cls, msg, level='INFO'):
		"""
		Send `msg` to syslog.

		:param str msg: message to log
		:param str level: syslog priority name without the 'LOG_' prefix,
			case-insensitive; 'ERROR' is accepted as an alias of 'ERR'.
			Unknown names fall back to LOG_INFO.
		"""
		level = level.upper()
		if level == 'ERROR':
			level = 'ERR'
		# getattr() raises AttributeError (not KeyError) for unknown names,
		# so the fallback must come from its default argument.
		syslog_level = getattr(syslog, 'LOG_{}'.format(level), syslog.LOG_INFO)
		syslog.syslog(syslog_level, msg)

	def clear_variables(self):
		"""Reset all per-connection state."""
		self.sasl_login_name = ''
		self.legitimate_addresses = []
		self.envelope_from = ''
		self.header_from = ''
		self.recipients = []

	@classmethod
	def get_lo(cls):
		"""
		Return an LDAP connection usable by the current process.

		The connection is cached on the class but only reused in the process
		that opened it. The milter runs under lm.ForkMixin/ForkFactory in
		forked child processes (presumably one per SMTP connection), and a
		connection inherited from the parent (e.g. the one opened at startup
		in run_milter()) must not be shared by concurrently running children.
		"""
		pid = os.getpid()
		if cls._lo is None or cls._lo_pid != pid:
			cls._lo = univention.uldap.getMachineConnection(
				ldap_master=False,
				secret_file='/etc/listfilter.secret'
			)
			cls._lo_pid = pid
		return cls._lo

	@classmethod
	def get_legitimate_addresses_for_username(cls, sasl_login_name):
		"""
		Look up the addresses a user is allowed to send from.

		:param str sasl_login_name: SASL login name, matched against 'uid'
		:return: the mailPrimaryAddress values followed by the
			mailAlternativeAddress values of the first matching
			univentionMail object; an empty list (and an error log entry)
			if no object or no address was found
		:rtype: list
		"""
		lo = cls.get_lo()
		ldap_attr = ['mailPrimaryAddress', 'mailAlternativeAddress']
		# filter_format() escapes the login name for use inside the filter.
		ldap_filter = filter_format('(&(uid=%s)(objectclass=univentionMail))', (sasl_login_name,))
		ldap_result = lo.search(filter=ldap_filter, attr=ldap_attr)
		addresses = []
		if ldap_result:
			attrs = ldap_result[0][1]
			addresses = list(attrs.get('mailPrimaryAddress', [])) + list(attrs.get('mailAlternativeAddress', []))
		if not addresses:
			cls.log('Found no email address for sasl_login_name={!r}.'.format(sasl_login_name), 'ERROR')
		return addresses

	def _load_legitimate_addresses(self):
		"""Fetch the authenticated user's addresses unless already cached."""
		if not self.legitimate_addresses:
			self.legitimate_addresses = self.get_legitimate_addresses_for_username(self.sasl_login_name)
		return self.legitimate_addresses

	@lm.noReply
	def connect(self, hostname, family, ip, port, cmdDict):
		# A new SMTP connection starts with fresh state.
		self.clear_variables()
		return lm.CONTINUE

	def mailFrom(self, frAddr, cmdDict):
		"""
		Check the envelope sender of authenticated submissions.

		Returns lm.REJECT if `frAddr` is not a legitimate address of the
		user named by cmdDict['auth_authen']; mails without that key are
		not authenticated submissions and get lm.CONTINUE.
		"""
		self.envelope_from = frAddr
		try:
			self.sasl_login_name = cmdDict['auth_authen']
		except KeyError:
			# not a submission
			return lm.CONTINUE
		self._load_legitimate_addresses()
		if self.envelope_from in self.legitimate_addresses:
			return lm.CONTINUE
		self.log('REJECT: envelope_from ({}) not in legitimate addresses ({}).'.format(
			self.envelope_from, ', '.join(self.legitimate_addresses)))
		return lm.REJECT

	def header(self, key, val, cmdDict):
		"""
		Check every address in the "From" header of authenticated submissions.

		A From header may list several mailboxes, so all of them must be
		legitimate: 'Boss <boss@example.com>, <me@example.com>' is rejected
		even though its last address belongs to the user. getaddresses()
		also copes with folded (multi-line) header values.
		Headers from which no address can be extracted are logged but not
		rejected.
		"""
		if not self.sasl_login_name or key.lower() != 'from':
			return lm.CONTINUE
		addresses = [addr for _name, addr in getaddresses([val]) if '@' in addr]
		if not addresses:
			self.log('Cannot parse header: {!r}: {!r}.'.format(key, val), 'ERROR')
			return lm.CONTINUE
		self._load_legitimate_addresses()
		for address in addresses:
			self.header_from = address
			if address not in self.legitimate_addresses:
				self.log('REJECT: {!r} in "From" header ({}) not in legitimate addresses ({!s}).'.format(
					address, val, ', '.join(self.legitimate_addresses)))
				return lm.REJECT
		return lm.CONTINUE

	def eob(self, cmdDict):
		# don't log unnecessarily
		return lm.CONTINUE

	def close(self):
		# don't log unnecessarily
		pass


def run_milter(socket_spec='inet:127.0.0.1:5656'):
	"""
	Start the milter and serve connections until stopped.

	:param str socket_spec: listening socket in libmilter notation
		(default: TCP port 5656 on 127.0.0.1).

	The SIGINT/SIGTERM handler closes the factory and exits with status 0;
	an exception raised by the factory is logged, the factory is closed and
	the process exits with status 3.
	"""
	syslog.openlog(ident="no_forged_from", logoption=syslog.LOG_PID, facility=syslog.LOG_MAIL)
	NoForgedFromMilter.log('Starting NoForgedFromMilter.')
	# test LDAP connection at startup, before accepting any mail
	NoForgedFromMilter.get_lo()

	# Ask the MTA to skip the stages this milter does not inspect.
	opts = lm.SMFIP_NOHELO | lm.SMFIP_NORCPT | lm.SMFIP_NOBODY | lm.SMFIP_NOEOH | lm.SMFIP_NODATA
	milter_factory = lm.ForkFactory(socket_spec, NoForgedFromMilter, opts)

	def sig_handler(num, frame):
		NoForgedFromMilter.log('Stopping NoForgedFromMilter.')
		milter_factory.close()
		sys.exit(0)
	signal.signal(signal.SIGINT, sig_handler)
	signal.signal(signal.SIGTERM, sig_handler)

	try:
		milter_factory.run()
	except Exception as exc:
		# Record the failure first, so an error raised by close() cannot hide it.
		NoForgedFromMilter.log('EXCEPTION OCCURRED: {}'.format(exc), 'ERROR')
		for line in traceback.format_exc().splitlines():
			NoForgedFromMilter.log(line, 'ERROR')
		print('EXCEPTION OCCURRED: {}'.format(exc))
		# sys.exc_traceback is a deprecated, Python-2-only global;
		# print_exc() reads the active exception via sys.exc_info().
		traceback.print_exc()
		milter_factory.close()
		sys.exit(3)


if __name__ == '__main__':
	# Run the milter only when this file is executed directly, not when imported.
	run_milter()
