#!/usr/bin/env python
"""Wrapper for /sbin/modload

Usage: modload.py [-Mpath1:path2...] [-Adinvs] [-o output] files ...

  -A           automatically load required LKMs.
  -d           passed to /sbin/modload.
  -i           print dependencies for LKM files.
  -M path      colon-separated module search path.
  -n           don't run; just print what would be done.
  -o output    passed to /sbin/modload.
  -s           passed to /sbin/modload.
  -v           be verbose.
"""

import sys, getopt, os, re
import subprocess
from sys import stderr
from os.path import basename, dirname

def usage_and_exit(r):
  """Print a one-line usage summary to stderr, then exit.

  r -- exit status handed to sys.exit().
  """
  # Keep this option list in sync with the getopt spec used in main()
  # ("AdiM:nvso:"); -d used to be missing here.
  sys.stderr.write(
    "Usage: %(program)s [-Adinvs] [-M path1:path2...] [-o output] files...\n"
    % {"program": sys.argv[0]})
  sys.exit(r)


# Path of the real modload(8) binary invoked by run_modload().  Pointing it
# at /bin/echo is handy for testing: the arguments are printed instead of
# anything being loaded.
modload_prog = "/sbin/modload"

# Parsed command-line options keyed by flag (e.g. "-v"); filled in by main().
options = {}
# Global LkmPool instance; main() creates it before any lookups are made.
lkm_pool = None
# Directories searched for LKMs providing required features.  main()
# replaces this with the -M list or [dirname(first file), "/usr/lkm"].
module_path = ['.', "/usr/lkm"]

class LkmPool:
  """Registry of known LKMs, indexed by file basename and by provided feature.

  For both indexes the LKM registered first wins: a later LKM with the same
  basename is ignored, and a later provider of an already-known feature does
  not replace the earlier one.  Callers control precedence through the
  order in which they add files and scan directories.
  """

  def __init__(self):
    self.all_lkms = {}   # basename -> LoadableKernelModule
    self.features = {}   # feature name -> LoadableKernelModule providing it

  def add(self, lkm):
    """Register `lkm' unless an LKM with the same basename is already known."""
    if lkm.basename in self.all_lkms:
      return
    self.all_lkms[lkm.basename] = lkm
    for f in lkm.provide:
      # First provider wins, matching the basename rule above, so a module
      # named explicitly (or found earlier in the search path) is not
      # shadowed by one discovered later.
      self.features.setdefault(f, lkm)

  def find(self, req):
    """Return the LKM that provides feature `req', or None if none is known."""
    return self.features.get(req)

  def scan_dir(self, dir):
    """Register every non-directory `*.o' entry directly inside `dir'.

    Entries are visited in sorted order so the outcome is deterministic.
    A directory that cannot be listed is reported on stderr and skipped.
    """
    sys.stderr.write("Scanning %s\n" % dir)
    try:
      entries = os.listdir(dir)
    except OSError as e:
      # A missing search directory (e.g. no /usr/lkm) is not fatal.
      sys.stderr.write("Cannot scan %s: %s\n" % (dir, e))
      return
    for f in sorted(entries):
      path = os.path.join(dir, f)
      if os.path.splitext(f)[1] == ".o" and not os.path.isdir(path):
        self.add(LoadableKernelModule(path))

  def get(self, fname):
    """Return the LKM for file `fname', loading and registering it if new.

    Lookup is by basename, so an already registered LKM with the same
    basename is returned even if it came from a different directory.
    """
    base = basename(fname)
    if base not in self.all_lkms:
      # Go through add() so the module's features are registered as well.
      self.add(LoadableKernelModule(fname))
    return self.all_lkms[base]

class LoadableKernelModule:
  """An LKM object file plus the features it requires and provides.

  The feature names are read from the `.note.netbsd.lkm.require' and
  `.note.netbsd.lkm.provide' sections, which are located via `objdump -h'.
  """

  # Matches an `objdump -h' section line for one of the LKM note sections.
  # The dots are escaped so only the exact section names match.
  _SECTION_RE = re.compile(
    r"^\s*\d+\s+\.note\.netbsd\.lkm\.(require|provide)\s+")

  def read_dependencies(self, offset, size):
    """Return the NUL-separated names stored in [offset, offset + size).

    Trailing NUL padding is dropped; an empty or all-NUL region gives [].
    """
    # Object files are binary data, so read them in binary mode.
    with open(self.filename, "rb") as f:
      f.seek(offset)
      data = f.read(size)
    data = data.rstrip(b'\x00')
    if not data:   # needed because b"".split(b"\x00") returns [b""]
      return []
    return data.split(b'\x00')

  def __init__(self, fname):
    self.filename = fname
    self.basename = basename(fname)
    self.dirname = dirname(fname)
    self.follows = {}   # LKMs that must be loaded before this one
    # Two distinct lists, so the attributes never alias one object.
    self.require = []
    self.provide = []

    # Pass the file name as its own argv element (no shell involved), so
    # names containing spaces or shell metacharacters work.
    try:
      objdump = subprocess.Popen(["objdump", "-h", fname],
                                 stdout=subprocess.PIPE,
                                 universal_newlines=True)
    except OSError as e:
      # Best effort, like before: an LKM we cannot inspect has no deps.
      sys.stderr.write("Cannot run objdump on %s: %s\n" % (fname, e))
      return
    try:
      for line in objdump.stdout:
        m = self._SECTION_RE.search(line)
        if not m:
          continue
        # Columns after the section name: Size VMA LMA File-off Algn.
        fields = line[m.end():].split()
        if len(fields) < 4:
          continue
        names = self.read_dependencies(int(fields[3], 16),
                                       int(fields[0], 16))
        if m.group(1) == "require":
          self.require = names
        else:
          self.provide = names
    finally:
      objdump.stdout.close()
      objdump.wait()

  def __str__(self):
    return "LKM<" + self.filename + ">"

  def print_info(self):
    """Print the required, then the provided, features to stdout."""
    for title, names in (("requires:", self.require),
                         ("provides:", self.provide)):
      sys.stdout.write("%s %s\n" % (self.filename, title))
      for name in names:
        sys.stdout.write("\t%s\n" % name)

#
# Support for `modload.py -i files...'
#
def list_provide_and_require(fname):
  """Print the features that the LKM in file `fname' requires and provides."""
  lkm_pool.get(fname).print_info()

def resolve(files):
  """Find every LKM needed by those in `files' and return them in load order.

  Each named file is registered in the global lkm_pool.  The first time a
  required feature cannot be found, the module_path directories are
  scanned once.  Features still unresolved after that are reported on
  stderr (once each) and otherwise ignored.

  Returns a list of LoadableKernelModule objects in which each LKM comes
  after the LKMs providing the features it requires.  If the dependencies
  are circular, the LKMs involved are reported on stderr and appended in
  discovery order rather than looping forever.
  """
  to_load = []      # LKMs to load, in discovery order
  queued = set()    # same LKMs, for O(1) membership tests

  def enqueue(lkm):
    if lkm not in queued:
      queued.add(lkm)
      to_load.append(lkm)

  def scan_dirs(dirs):
    for d in dirs:
      lkm_pool.scan_dir(d or ".")   # "" comes from dirname("foo.o")

  for f in files:
    lkm = LoadableKernelModule(f)
    enqueue(lkm)
    lkm_pool.add(lkm)

  # Worklist pass: to_load grows while we walk it, so each newly found
  # dependency has its own requirements resolved in turn.
  dir_searched = False
  reported = set()
  i = 0
  while i < len(to_load):
    o = to_load[i]
    i += 1
    for d in o.require:
      dep = lkm_pool.find(d)
      if dep is None and not dir_searched:
        scan_dirs(module_path)
        dir_searched = True
        dep = lkm_pool.find(d)
      if dep is None:
        if d not in reported:
          reported.add(d)
          sys.stderr.write("No modules for feature %s\n" % d)
        continue
      # Record the edge even if dep is already queued (e.g. it was also
      # named on the command line); otherwise the load order is arbitrary.
      # Skip self-edges, which could never be satisfied.
      if dep is not o:
        o.follows[dep] = True
      enqueue(dep)

  # Topological sort: repeatedly emit LKMs none of whose dependencies are
  # still pending.
  order = []
  pending = set(to_load)
  while pending:
    progressed = False
    for lkm in to_load:
      if lkm in pending and not any(p in pending for p in lkm.follows):
        order.append(lkm)
        pending.remove(lkm)
        progressed = True
    if not progressed:
      stuck = [lkm for lkm in to_load if lkm in pending]
      sys.stderr.write("Circular dependency among: %s\n"
                       % " ".join([lkm.filename for lkm in stuck]))
      order.extend(stuck)
      break
  return order

def run_modload(file):
  """Invoke modload_prog on `file' and return its exit status.

  -o, -v and -d are forwarded with their values.  -s is added when either
  -s or -A was given.  With -n, the command line is only printed to
  stdout and 0 is returned.
  """
  forwarded = [flag + options[flag]
               for flag in ("-o", "-v", "-d") if flag in options]
  args = [modload_prog] + forwarded
  if "-s" in options or "-A" in options:
    args.append("-s")
  args.append(file)

  verbose = "-v" in options
  if "-n" in options:
    if verbose:
      sys.stdout.write("Loading %s ...\n" % file)
    sys.stdout.write(" ".join(args) + "\n")
    return 0

  if verbose:
    stderr.write("Loading %s...\n" % file)
  return os.spawnv(os.P_WAIT, modload_prog, args)


def main():
  """Parse the command line, then list, resolve and/or load the given LKMs.

  Exits with status 2 on a usage error, 0 after -i, and otherwise with the
  number of modload invocations that failed (in both plain and -A mode).
  """
  global lkm_pool
  global options
  global module_path

  try:
    opts, args = getopt.getopt(sys.argv[1:], "AdiM:nvso:")
  except getopt.GetoptError as errmsg:
    sys.stderr.write("%(program)s: %(errmsg)s\n"
                     % {"program": sys.argv[0], "errmsg": errmsg})
    usage_and_exit(2)

  lkm_pool = LkmPool()

  for o, a in opts:
    options[o] = a

  if not args:
    sys.stderr.write("Too few arguments\n")
    usage_and_exit(2)

  if "-i" in options:
    for f in args:
      list_provide_and_require(f)
    sys.exit(0)

  if "-M" in options:
    module_path = options["-M"].split(":")
  else:
    module_path = [dirname(args[0]), "/usr/lkm"]

  if "-A" in options:
    files = [lkm.filename for lkm in resolve(args)]
  else:
    files = args

  # The exit status reports failed loads in -A mode too.
  err = 0
  for f in files:
    if run_modload(f) != 0:
      err += 1
  sys.exit(err)

if __name__ == "__main__":
  main()
