#!/usr/bin/python
# coding: utf-8
#
# Copyright 2004-2016 Univention GmbH
#
# http://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <http://www.gnu.org/licenses/>.

from __future__ import print_function

import os
import re
import sys
import ldap
import getpass
import argparse
import subprocess

import univention.config_registry as configRegistry


class Parser(argparse.ArgumentParser):
	"""Argument parser that extracts only the options this wrapper inspects.

	Arguments it does not know are tolerated, so the complete command line
	can still be handed to `ldbsearch` by the caller.
	"""

	def error(self, _message):
		"""Raise `ValueError` instead of printing usage and exiting.

		`ArgumentParser` would normally print a usage help and terminate
		the process on invalid arguments. Invalid arguments are meant to be
		passed on to `ldbsearch`, so this case is signalled with an
		exception the caller can handle.
		"""
		raise ValueError("Invalid Arguments")

	@classmethod
	def parse(cls, arguments):
		"""Return the namespace of recognised options found in `arguments`.

		Unrecognised arguments are silently dropped from the result.
		Raises `ValueError` (via `error`) if a recognised option is used
		incorrectly.
		"""
		note = ("NOTE: If no `-b`/`--basedn` is supplied, this tool will "
			"search in the standard base DN and the configuration "
			" base DN.\n See `ldbsearch --help` for further arguments.")

		# (flags, keyword arguments) in the order they are registered.
		option_table = (
			(("-b", "--basedn"), {}),
			(("-d", "--debuglevel"), {}),
			(("-k", "--kerberos"), {"choices": ["yes", "no"]}),
			(("-A", "--authentication-file"), {}),
			(("-P", "--machine-pass"), {"action": "store_true"}),
			(("-U", "--user"), {}),
			(("--password",), {}),
			(("--simple-bind-dn",), {}),
		)

		instance = cls(epilog=note)
		for flags, settings in option_table:
			instance.add_argument(*flags, **settings)

		# Index 0 is the namespace, index 1 the list of unknown arguments.
		return instance.parse_known_args(arguments)[0]


class LDBSearch(object):
	"""Run `ldbsearch` with a fixed argument list and parse its output."""

	# Univention Config Registry, loaded once when the class is defined and
	# shared by all instances via the class attribute.
	UCR = configRegistry.ConfigRegistry()
	UCR.load()

	def __init__(self, arguments, host=None):
		"""Remember the `ldbsearch` arguments and the host to query.

		`arguments` is the list of extra command line arguments. `host` is
		passed to `ldbsearch -H`; if it is falsy an `ldaps://` URL is built
		from the UCR variable `ldap/server/name`.
		"""
		self.host = host or self._build_host()
		self.arguments = arguments

	def _build_host(self):
		"""Return the `ldaps://` URL of the server named in UCR."""
		ldap_server_name = self.UCR.get("ldap/server/name")
		return "ldaps://{}".format(ldap_server_name)

	def _build_arguments(self):
		"""Return the complete argv list used to invoke `ldbsearch`."""
		args = ["ldbsearch", "--debug-stderr", "-H", self.host]
		args.extend(self.arguments)
		return args

	def debug_string(self, prefix="### Output of:"):
		"""Return `prefix` followed by the space-joined command line."""
		return "{} {}".format(prefix, " ".join(self._build_arguments()))

	def _oserror_handler(self, error):
		"""Report a failed `ldbsearch` call on stderr and exit with status 1."""
		print("Error:", error, file=sys.stderr)
		print(self.debug_string("While trying to run:"), file=sys.stderr)
		sys.exit(1)

	def search(self):
		"""Run `ldbsearch` and return the tuple `(entries, referrals)`.

		The output is split on blank lines; blocks starting with
		`# record ` are entries, blocks starting with `# Referral` are
		referrals, everything else is dropped. If `ldbsearch` cannot be
		started or exits with a non-zero status, the error is reported and
		the process exits.

		Nothing is printed here: presenting the records is left to the
		caller. Echoing the raw output would duplicate every record and
		would also write looked-up secrets (see `search_key`) to stdout.
		"""
		try:
			output = subprocess.check_output(self._build_arguments())
		except (OSError, subprocess.CalledProcessError) as error:
			self._oserror_handler(error)

		# `check_output` returns bytes on Python 3. On Python 2 the result
		# already is a `str`, so this branch is not taken there.
		if not isinstance(output, str):
			output = output.decode("utf-8", "replace")

		records = output.split("\n\n")
		entries = [record for record in records
			if record.startswith("# record ")]
		referrals = [record for record in records
			if record.startswith("# Referral")]
		return (entries, referrals)

	def search_key(self, key):
		"""Return the value of the first `key: value` line in any entry.

		Entries are examined in output order. Returns the empty string if
		no entry has a line for `key`.
		"""
		(entries, _referrals) = self.search()
		# Escape the key so that it is matched literally, even if it
		# contains characters that are special in regular expressions.
		pattern = re.compile("^{}: (.*)$".format(re.escape(key)), re.MULTILINE)
		for entry in entries:
			matches = pattern.findall(entry)
			if matches:
				return matches[0]
		return ""


def credentials_given(parsed_arguments):
	"""Tell whether the command line already carries usable credentials.

	The result is truthy if Kerberos was requested with `yes`, an
	authentication file, the machine-password flag or a password was given,
	or the user name contains a `%` (i.e. `user%password`). The value
	returned is that of the deciding operand, not necessarily a `bool`.
	"""
	args = parsed_arguments
	user_with_password = args.user and "%" in args.user
	return (
		args.kerberos == "yes" or
		args.authentication_file or
		args.machine_pass or
		args.password or
		user_with_password
	)


def account_given(parsed_arguments):
	"""Return the user name if one was given, else the simple bind DN.

	The result is falsy when neither option was supplied.
	"""
	user = parsed_arguments.user
	if user:
		return user
	return parsed_arguments.simple_bind_dn


def read_password(user=None):
	"""Ask for a password via `getpass` and return what was entered.

	If `user` is truthy it is shown in square brackets as part of the
	prompt.
	"""
	if user:
		suffix = " [{}]".format(user)
	else:
		suffix = ""
	return getpass.getpass("Password{}: ".format(suffix))


def file_readable(path):
	"""Return whether `os.access` reports `path` as readable."""
	readable = os.access(path, os.R_OK)
	return readable


def search_machine_secret():
	"""Look up the machine account secret in the local `secrets.ldb`.

	Returns the tuple `(hostname, secret)`; the secret is the empty string
	if no matching `secret` attribute is found.
	"""
	# currently the password in the secrets.ldb is set to machine.secret only
	# on provision host, so we need to look it up from the secrets.ldb
	host = LDBSearch.UCR.get("hostname")
	query = ["samAccountName={}$".format(host), "secret"]
	secrets_db = LDBSearch(query, host="/var/lib/samba/private/secrets.ldb")
	secret = secrets_db.search_key("secret")
	return (host, secret)


def allow_kerberos(parsed_arguments):
	"""Return False only if Kerberos was explicitly switched off with `no`."""
	explicitly_disabled = parsed_arguments.kerberos == "no"
	return not explicitly_disabled


def has_kerberos_ticket():
	"""Return True if `klist -t` exits with status 0, False otherwise.

	If `klist` cannot be started at all (for example because it is not
	installed), that is treated as "no ticket" so the caller can fall back
	to another authentication method instead of aborting with a traceback.
	"""
	try:
		# Whatever `klist` writes to stderr is sent to our stdout.
		return subprocess.call(["klist", "-t"], stderr=sys.stdout) == 0
	except OSError:
		return False


def configuration_base_dn():
	"""Return the Samba base DN from UCR prefixed with `CN=Configuration`."""
	base = LDBSearch.UCR.get("samba4/ldap/base")
	rdns = ldap.dn.str2dn(base)
	# Put the configuration RDN in front of the parsed base DN.
	rdns.insert(0, [("CN", "Configuration", ldap.AVA_STRING)])
	return ldap.dn.dn2str(rdns)


def build_authentication_argument(parsed_arguments):
	"""Return the extra `ldbsearch` argument needed to authenticate.

	Returns the empty string if the command line already carries
	credentials. Otherwise the first applicable method is used:

	1. an account was named: prompt for its password,
	2. `/etc/machine.secret` is readable: use the machine account,
	3. Kerberos is allowed and a ticket exists: use Kerberos,
	4. fall back to the current login name and prompt for a password.
	"""
	if credentials_given(parsed_arguments):
		return ""

	if account_given(parsed_arguments):
		return "--password={}".format(read_password(parsed_arguments.user))

	if file_readable("/etc/machine.secret"):
		machine, secret = search_machine_secret()
		return "--user={}$%{}".format(machine, secret)

	if allow_kerberos(parsed_arguments) and has_kerberos_ticket():
		return "--kerberos=yes"

	login = getpass.getuser()
	return "--user={}%{}".format(login, read_password(login))


def print_searches(*searches):
	"""Run every search and print the combined result on stdout.

	Entries of all searches are printed first, renumbered consecutively
	(the first line of each entry is replaced by a fresh `# record N`
	header). The collected referrals follow, then a three line summary
	with the number of records, entries and referrals.
	"""
	all_referrals = []
	printed = 0

	for ldbsearch in searches:
		found_entries, found_referrals = ldbsearch.search()
		all_referrals += found_referrals
		for offset, record in enumerate(found_entries, 1):
			header = "# record {}".format(printed + offset)
			body = record.split("\n")[1:]
			print("\n".join([header] + body))
			print()
		printed += len(found_entries)

	if all_referrals:
		print("\n\n".join(all_referrals))
		print()

	referral_count = len(all_referrals)
	print("# returned {} records".format(printed + referral_count))
	print("# {} entries".format(printed))
	print("# {} referrals".format(referral_count))


def main(arguments):
	"""Run the search(es) described by the command line `arguments`.

	If the arguments cannot be parsed they are handed to a single
	`ldbsearch` call unchanged. Otherwise an authentication argument is
	added where needed, and, unless a base DN was given, a second search
	below the configuration base DN is run as well. With a debug level
	set, the executed command lines are written to stderr at the end.
	"""
	try:
		parsed_arguments = Parser.parse(arguments)
	except ValueError:
		print_searches(LDBSearch(arguments))
		sys.exit()

	ldbsearch = arguments[:]
	# The authentication argument is empty when credentials were already
	# supplied; do not pass an empty argv element on to `ldbsearch`.
	authentication = build_authentication_argument(parsed_arguments)
	if authentication:
		ldbsearch.append(authentication)
	if credentials_given(parsed_arguments) and parsed_arguments.kerberos is None:
		ldbsearch.append("--kerberos=no")

	searches = [LDBSearch(ldbsearch)]

	if parsed_arguments.basedn is None:
		ldbsearch_conf = ldbsearch[:]
		ldbsearch_conf.append("--basedn={}".format(configuration_base_dn()))
		searches.append(LDBSearch(ldbsearch_conf))

	print_searches(*searches)

	if parsed_arguments.debuglevel is not None:
		for search in searches:
			print(search.debug_string(), file=sys.stderr)


if __name__ == "__main__":
	main(sys.argv[1:])
