#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright Grégory Soutadé 2012

# This file is part of autojump2

# autojump2 is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# autojump2 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with autojump2.  If not, see <http://www.gnu.org/licenses/>.

# Original code has been written by Joel Schaerer (https://github.com/joelthelion/autojump)

from sys import argv, exit, stderr, stdout
import getopt
import os
from optparse import OptionParser
import shutil
import re

# Separator used in completion candidates ('<arg>__<n>__<path>') and in the
# '<arg>__<n>' selector accepted in jump mode (see the main code below).
COMPLETION_SEPARATOR = '__'
# Directory that holds the database: $AUTOJUMP2_DATA_DIR when set, otherwise
# the user's home directory (~).
CONFIG_DIR = os.environ.get("AUTOJUMP2_DATA_DIR", os.path.expanduser("~"))
# File name of the database inside CONFIG_DIR.
DATABASE_NAME = '.autojump2.dict'

def _walk_path(path):
    res = []
    try:
        for p in os.listdir(path):
            p = path + p
            if not os.path.islink(p) and not os.path.isdir(p): 
                continue
            res.append(p)
    except OSError:
        print >> stderr, "Error listing %s" % path

    return res

def walk_path(path):
    """Expand the '*' wildcards of ``path`` into concrete paths.

    The text before the first '*' is listed with _walk_path(); each entry
    found there replaces the '*' and the rest of ``path`` is appended.  When
    that rest still contains a '*', the expansion recurses on it.  A path
    without any '*' expands to an empty list.
    """
    head, star, tail = path.partition('*')
    if not star:
        return []

    expanded = []
    try:
        for entry in _walk_path(head):
            candidate = entry + tail
            if '*' in tail:
                expanded.extend(walk_path(candidate))
            else:
                expanded.append(candidate)
    except ValueError:
        # A ValueError raised during expansion stops it early; whatever was
        # collected so far is returned.
        pass

    return expanded

def add_path(l, h, path):
    """Insert ``path`` into the saved-path list ``l``.

    Backslash-escaped stars in ``path`` are turned into plain '*' wildcards.
    After insertion ``l`` is re-ordered in place: plain paths first, then
    generic paths (those containing '*'), each group sorted alphabetically.
    For a generic path, its expansions computed by walk_path() are stored in
    ``h[path]`` when there is at least one.

    Returns False (and changes nothing) if the path is already saved,
    True otherwise.
    """
    path = path.replace('\\*', '*')

    if path in l:
        return False

    l.append(path)
    # Generic paths must come after normal paths; slice assignment keeps the
    # caller's list object.
    l[:] = sorted(l, key=lambda p: ('*' in p, p))

    if '*' in path:
        # Expand the path that was just added (not some other list entry).
        paths = walk_path(path)
        if paths:
            h[path] = paths

    return True

def remove_path(l, h, path):
    """Drop ``path`` (with backslash-escaped stars unescaped) from ``l``.

    Any expansions stored for it in ``h`` are dropped too.
    Returns True if the path was saved, False if it was not.
    """
    key = path.replace('\\*', '*')
    if key not in l:
        return False

    l.remove(key)
    h.pop(key, None)
    return True

def modify_path(l, h, src, dest):
    """Replace the saved path ``src`` by ``dest``.

    Backslash-escaped stars are unescaped in both arguments.  ``src`` may
    equal ``dest``; this re-adds the path and so recomputes the expansions
    of a generic path.

    Returns False without touching ``l``/``h`` when ``src`` is not saved or
    when a different saved entry already equals ``dest``; True otherwise.
    """
    src = src.replace('\\*', '*')
    dest = dest.replace('\\*', '*')

    if src not in l:
        return False
    # Refuse before mutating anything: otherwise src would be removed and
    # add_path() would then fail, silently dropping src.
    if dest != src and dest in l:
        return False

    if not remove_path(l, h, src) or not add_path(l, h, dest):
        return False

    return True

def save_database(l, h, path=None):
    """Write the saved paths ``l`` and their expansions ``h`` to disk.

    File format: one saved path per line; when ``h`` holds expansions for a
    path, they follow it one per line, each prefixed by '>>> '.

    Before writing, the current database (if any) is copied to
    '<database>.bak'.  If writing the new content fails, that backup is
    copied back.

    :param l: ordered list of saved paths.
    :param h: dict mapping generic paths to their expanded directories.
    :param path: database file to write; defaults to
        CONFIG_DIR + os.sep + DATABASE_NAME.
    :returns: True when the database was written, False when writing failed.
    :raises IOError/OSError: if the backup cannot be created, the old
        database cannot be removed, or the new file cannot be opened.
    """
    if path is None:
        path = CONFIG_DIR + os.sep + DATABASE_NAME
    bak = path + '.bak'

    # Best effort: a stale backup is replaced just below anyway.
    if os.path.exists(bak):
        try:
            os.remove(bak)
        except OSError:
            pass

    # shutil.copy raises IOError on Python 2, which is not an OSError there.
    try:
        if os.path.exists(path):
            shutil.copy(path, bak)
    except (IOError, OSError) as e:
        stderr.write("Error while creating backup of autojump2 dic @ \'%s\'. (%s)\n" % (bak, e))
        raise

    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError as e:
        stderr.write("Error can't remove autojump2 database @ \'%s\'. (%s)\n" % (path, e))
        raise

    f = open(path, 'w')
    try:
        try:
            for p in l:
                f.write(p + '\n')
                for p2 in h.get(p, ()):
                    f.write('>>> ' + p2 + '\n')
        finally:
            # Close (and flush) before any restore, so no buffered data can
            # land on top of the restored backup.
            f.close()
    except (IOError, OSError) as e:
        stderr.write("Error write autojump2 database @ \'%s\'. (%s)\n" % (path, e))
        if os.path.exists(bak):
            shutil.copy(bak, path)
        return False

    return True
        
def open_database(path=None):
    """Load the database written by save_database().

    Lines starting with '>>> ' are expansions of the closest preceding saved
    path.  Every other non-blank line is a saved path.  Blank lines are
    ignored.

    :param path: database file to read; defaults to
        CONFIG_DIR + os.sep + DATABASE_NAME.
    :returns: (l, h): ``l`` is the list of saved paths in file order, and
        ``h`` maps a saved path to the list of its '>>> ' expansions.  Both
        are empty when the file does not exist.
    """
    l = []
    h = {}
    if path is None:
        path = CONFIG_DIR + os.sep + DATABASE_NAME
    prev_line = ''

    if not os.path.exists(path):
        return l, h

    with open(path) as f:
        for line in f:
            # Strip only the newline: a last line without one keeps every
            # character.
            line = line.rstrip('\n')
            if not line:
                continue
            if line.startswith('>>> ') and prev_line != '':
                h.setdefault(prev_line, []).append(line[4:])  # drop '>>> '
            else:
                l.append(line)
                prev_line = line

    return l, h

def list_database(l, h):
    """Print the saved paths to stderr.

    Each saved path is followed by its expansions from ``h``, sorted and
    prefixed by '>>> '.  ``h`` itself is left unchanged.

    Returns None when nothing is saved, True otherwise.
    """
    if not l:
        stderr.write('No path saved\n')
        return

    stderr.write('Saved paths :\n')
    for p in l:
        stderr.write('\t' + p + '\n')
        # sorted() rather than list.sort(): listing must not reorder h.
        for p2 in sorted(h.get(p, ())):
            stderr.write('\t>>> ' + p2 + '\n')

    return True

def path_matching(l, h, args, pos = -1):
    """Return the saved directories matching the search terms ``args``.

    The terms are used as regular-expression fragments (they are not
    escaped).  With args 'proj' and 'v2', two expressions are tried:
      * exp1 matches a path with a component containing all terms in order,
        e.g. /a/b/c/projXXXv2;
      * exp2 (only with several terms) matches the earlier terms before a
        separator and the last term in a later component, e.g. /a/b/proj/v2.

    Candidates are the plain entries of ``l`` and, for generic entries
    (containing '*'), their expansions from ``h`` in sorted order.  Each
    matching directory is reported once, in order of first appearance.

    :param pos: 1-based index selecting a single match, numbered like the
        completion list; -1 (default) returns every match.
    :returns: list of matching paths.  With ``pos`` set, a one-element list
        holding the selected match, or an empty list if ``pos`` is out of
        range.
    """
    if not args:
        return []

    # re.escape keeps the patterns valid whatever os.sep is.
    sep = re.escape(os.sep)
    segment = '[^' + sep + ']*'

    exprs = [re.compile('^.*' + sep +
                        ''.join(segment + a + segment for a in args))]
    if len(args) > 1:
        exp2 = '^.*' + sep
        for a in args[:-1]:
            exp2 += segment + a + '.*' + sep
        exp2 += segment + args[-1] + segment
        exprs.append(re.compile(exp2))

    res = []
    seen = set()
    for path in l:
        if '*' in path:
            # A generic path whose pattern matched nothing has no entry in h.
            # sorted() leaves the stored list untouched.
            candidates = sorted(h.get(path, ()))
        else:
            candidates = [path]
        for candidate in candidates:
            if candidate in seen:
                continue
            if any(e.match(candidate) for e in exprs):
                seen.add(candidate)
                res.append(candidate)

    if pos == -1:
        return res
    # Select from the de-duplicated list so the index agrees with the
    # numbering printed in completion mode.
    if 1 <= pos <= len(res):
        return [res[pos - 1]]
    return []

#################################### Main code ####################################

if __name__ == '__main__':
    # Command-line entry point.  One action is performed per call, chosen in
    # this order: --add, --remove, --modify, --update, --list, otherwise
    # jump (or completion with -c) using the positional arguments as search
    # terms.  Diagnostics go to stderr; only the selected path(s) are written
    # to stdout, presumably for a calling shell function to consume.
    usage = '%prog [options]\n'\
            'Navigate through your filesystems with recorded links'

    optparser = OptionParser(usage=usage)

    optparser.add_option('-a', '--add', dest='add',
                         help='Add a new path to the database',
                         metavar="path")
    optparser.add_option('-r', '--remove', dest='remove',
                         help='Remove a path from the database',
                         metavar="path")
    optparser.add_option('-m', '--modify', dest='modify', nargs=2,
                         help='Replace a saved path by another one',
                         metavar="path_src path_dest")
    optparser.add_option('-u', '--update', dest='update',
                         help='Update path with *',
                         metavar="path")
    optparser.add_option('-l', '--list', dest='list',
                         action="store_true",
                         help='List database')
    optparser.add_option('-c', '--completion', dest='completion',
                         action="store_true",
                         help='Use autojump\'s completion')
    optparser.add_option('-b', '--bash', dest='bash',
                         action="store_true",
                         help='Current shell is bash')

    (optlist, args) = optparser.parse_args(argv[1:])

    l, h = open_database()

    # Database management actions cannot be combined with completion mode.
    if optlist.completion and (optlist.add or optlist.remove or optlist.modify
                               or optlist.update or optlist.list):
        exit(1)

    if optlist.add:
        path = os.path.abspath(optlist.add)
        if os.path.isfile(path):
            stderr.write("Error, cannot add a file (%s) in database, directory needed\n" % path)
        elif add_path(l, h, path):
            save_database(l, h)
            stderr.write('>>> \'%s\' correctly added to database\n' % (path))
        else:
            stderr.write('Error \'%s\' already exists in database\n' % (path))

    elif optlist.remove:
        if remove_path(l, h, os.path.abspath(optlist.remove)):
            save_database(l, h)
            stderr.write('>>> \'%s\' correctly removed from database\n' % (optlist.remove))
        else:
            stderr.write('Error \'%s\' doesn\'t exists in database\n' % (optlist.remove))

    elif optlist.modify:
        # Saved paths are absolute (see --add), so normalize both sides.
        src = os.path.abspath(optlist.modify[0])
        dest = os.path.abspath(optlist.modify[1])
        if os.path.isfile(dest):
            stderr.write("Error, cannot add a file in database, directory needed\n")
        elif modify_path(l, h, src, dest):
            save_database(l, h)
            stderr.write('>>> \'%s\' is now \'%s\'\n' % (optlist.modify[0], dest))
        else:
            stderr.write('Error \'%s\' doesn\'t exists in database\n' % (optlist.modify[0]))

    elif optlist.update:
        # Re-adding a path recomputes the expansions of a generic path.
        path = os.path.abspath(optlist.update)
        if not modify_path(l, h, path, path):
            stderr.write('Error updating ' + optlist.update + '\n')
        else:
            save_database(l, h)
            stderr.write('>>> Database updated\n')

    elif optlist.list:
        list_database(l, h)

    else:
        # Do the hard work
        if not args:
            # Nothing to search for.
            exit(1)

        quotes = '"' if optlist.bash else ""

        if optlist.completion:
            m = re.search('^.*(' + COMPLETION_SEPARATOR + '.*)$', args[-1])
            if m: # Remove the last '__' and everything after it
                args[-1] = args[-1][:-len(m.group(1))]
            matches = path_matching(l, h, args)
            if len(matches) > 1:
                # Candidates look like '<arg>__<n>__<path>'; n is 1-based.
                stdout.write("\n".join("%s%s%d%s%s" % (args[-1], COMPLETION_SEPARATOR, n + 1,
                                                       COMPLETION_SEPARATOR, r)
                                       for n, r in enumerate(matches)) + "\n")
            elif len(matches) == 1:
                stdout.write(quotes + matches[0] + quotes + '\n')
        else:
            # '<arg>__<n>' selects the n-th completion candidate.
            m = re.search('^.*' + COMPLETION_SEPARATOR + '([0-9]+)$', args[-1])
            if m:
                args[-1] = args[-1][:-len(COMPLETION_SEPARATOR)-len(m.group(1))]
                matches = path_matching(l, h, args, int(m.group(1)))
            else:
                matches = path_matching(l, h, args)
            if matches:
                stdout.write(quotes + matches[0] + quotes + '\n')
        exit(0)
    exit(1)
