#!/usr/bin/python
#
#  /var/lib/misc/setting_policy
#
import os
import sys
import traceback
import xml.dom.minidom
import syslog

BDEBUG = False


def log(msg):
    """Send *msg* to syslog, prefixed with ``setting_policy:``.

    Nothing is emitted unless the module-level BDEBUG flag is True.
    """
    if not BDEBUG:
        return
    syslog.syslog("setting_policy:%s" % msg)

def pin_cpu(domxml, cpu_list):
    """Restrict the guest's vCPUs to the host CPU list *cpu_list*.

    Sets the ``cpuset`` attribute of the first ``<vcpu>`` element in the
    minidom document *domxml* (callers pass strings such as "0-7,16-23").
    Raises IndexError when the document has no ``<vcpu>`` element.
    """
    vcpu_nodes = domxml.getElementsByTagName('vcpu')
    vcpu_nodes[0].setAttribute('cpuset', cpu_list)


def set_vcpu_share(domxml, share_value):
    """Set the guest's CPU share weight via ``<cputune><shares>``.

    :param domxml: minidom Document containing a ``<domain>`` element.
    :param share_value: the weight; converted with ``str()`` so ints and
        strings are both accepted.

    An existing ``<cputune>`` element is reused rather than appending a
    second one (mirroring pin_mem, which refuses to add a second
    ``<numatune>``), and any ``<shares>`` child already in it is replaced,
    so repeated calls leave exactly one value.  Raises IndexError if the
    document has no ``<domain>`` element.
    """
    domain = domxml.getElementsByTagName('domain')[0]

    existing = domain.getElementsByTagName('cputune')
    if existing:
        cputune = existing[0]
    else:
        cputune = domxml.createElement('cputune')
        domain.appendChild(cputune)

    # Snapshot the matching children first: removeChild mutates childNodes.
    stale = [child for child in cputune.childNodes
             if child.nodeType == child.ELEMENT_NODE
             and child.tagName == 'shares']
    for child in stale:
        cputune.removeChild(child)

    shares = domxml.createElement('shares')
    # createTextNode requires a string; str() lets callers pass an int.
    shares.appendChild(domxml.createTextNode(str(share_value)))
    cputune.appendChild(shares)

def pin_mem(domxml, node_id, mode):
    """Bind guest memory to host NUMA node(s) via ``<numatune><memory>``.

    :param domxml: minidom Document containing a ``<domain>`` element.
    :param node_id: string written to ``memory@nodeset`` (e.g. "0").
    :param mode: string written to ``memory@mode``; callers in this file
        pass 'strict' or 'preferred'.

    If the document already has a ``<numatune>`` element it is left alone
    and a warning line is written to stderr.  Raises IndexError if there is
    no ``<domain>`` element.
    """
    domain = domxml.getElementsByTagName('domain')[0]

    if domxml.getElementsByTagName('numatune'):
        # Newline-terminated so the warning does not run into later output.
        sys.stderr.write('numa: numa already exists in domain xml\n')
        return

    numatune = domxml.createElement('numatune')
    domain.appendChild(numatune)

    memory = domxml.createElement('memory')
    memory.setAttribute('mode', mode)
    memory.setAttribute('nodeset', node_id)
    numatune.appendChild(memory)

def get_bridge_info(domxml):
    """Return ``(bridge, mac)`` for the guest's first ``<interface>``.

    *bridge* is ``<source bridge=...>`` and *mac* is ``<mac address=...>``
    of that interface, each encoded to ASCII with non-ASCII characters
    dropped.  Raises IndexError if the interface, its ``<mac>`` or its
    ``<source>`` element is missing.
    """
    nic = domxml.getElementsByTagName('interface')[0]
    mac_attr = nic.getElementsByTagName('mac')[0].getAttribute("address")
    bridge_attr = nic.getElementsByTagName('source')[0].getAttribute("bridge")
    return tuple(value.encode('ascii', 'ignore')
                 for value in (bridge_attr, mac_attr))

def append_disk(domxml, source_file):
    """Attach block device *source_file* to the guest as virtio disk ``vda``.

    Appends to the first ``<devices>`` element a
    ``<disk device="disk" type="block">`` containing, in order:
    ``<source dev=source_file>``, ``<target bus="virtio" dev="vda">`` and
    ``<driver cache="none" error_policy="stop" io="native" name="qemu"
    type="raw">``.
    """
    def _element(tag, attrs):
        # Create <tag> and set its attributes in the given order.
        node = domxml.createElement(tag)
        for key, value in attrs:
            node.setAttribute(key, value)
        return node

    devices = domxml.getElementsByTagName('devices')[0]
    disk = _element('disk', [('device', 'disk'), ('type', 'block')])
    devices.appendChild(disk)

    children = (
        _element('source', [('dev', source_file)]),
        _element('target', [('bus', "virtio"), ('dev', "vda")]),
        _element('driver', [('cache', 'none'),
                            ('error_policy', 'stop'),
                            ('io', 'native'),
                            ('name', 'qemu'),
                            ('type', 'raw')]),
    )
    for child in children:
        disk.appendChild(child)


def create_vdisk(fun_name, tid, domxml):
    """Attach the guest's data disk ``/dev/<group>/<fun_name>-<tid>``.

    The device group is selected by *tid*:

    * 1-10                -> ``P421_0``
    * 13-22               -> ``P421_1``
    * 11, 12, 23, 24, 25  -> ``P42i``

    No group exists for any other *tid*.  Previously that produced a bogus
    ``/dev/None/...`` disk.  Now a warning is written to stderr and *domxml*
    is left unchanged.
    """
    if tid in range(1, 11):
        group_name = "P421_0"
    elif tid in range(13, 23):
        group_name = "P421_1"
    elif tid in (11, 12, 23, 24, 25):
        group_name = "P42i"
    else:
        sys.stderr.write('vdisk: no disk group for %s-%s\n' % (fun_name, tid))
        return

    append_disk(domxml, "/dev/%s/%s-%d" % (group_name, fun_name, tid))

def update_disk_io(domxml):
    """Set ``io="native"`` on the driver of every ``device="disk"`` entry.

    Walks the ``<disk>`` elements under the first ``<devices>`` element.
    Entries whose ``device`` attribute is anything other than 'disk' are
    left untouched.  A disk without a ``<driver>`` child is skipped (and
    logged) instead of raising IndexError.
    """
    devices = domxml.getElementsByTagName('devices')[0]
    disks = devices.getElementsByTagName('disk')
    log('find %d disks' % len(disks))
    for disk in disks:
        log('got disk_device')
        if disk.getAttribute("device") != 'disk':
            continue
        drivers = disk.getElementsByTagName('driver')
        if not drivers:
            log('disk has no driver element, skipping')
            continue
        log('set it to native')
        drivers[0].setAttribute('io', 'native')
         

# PCI passthrough
def pin_pci(domxml, bus, slot, function):
    """Pass host PCI device ``0000:<bus>:<slot>.<function>`` to the guest.

    *bus*, *slot* and *function* are hex strings without a ``0x`` prefix
    (e.g. "05", "10", "1").  Appends to ``<domain>/<devices>`` a
    ``<hostdev mode="subsystem" type="pci" managed="no">`` whose
    ``<source><address>`` carries domain 0x0000 plus the given values.
    """
    domain = domxml.getElementsByTagName('domain')[0]
    devices = domain.getElementsByTagName('devices')[0]

    hostdev = domxml.createElement('hostdev')
    for key, value in (('mode', 'subsystem'), ('type', 'pci'), ('managed', 'no')):
        hostdev.setAttribute(key, value)
    devices.appendChild(hostdev)

    source = domxml.createElement('source')
    hostdev.appendChild(source)

    address = domxml.createElement('address')
    address.setAttribute("domain", "0x0000")
    for key, value in (("bus", bus), ("slot", slot), ("function", function)):
        address.setAttribute(key, "0x" + value)
    source.appendChild(address)
     
def get_vm_name(domxml):
    """Return the text of the first ``<name>`` element under ``<domain>``.

    Only the first child node's data is returned.  An empty ``<name/>``
    raises AttributeError because its firstChild is None.
    """
    domain_node = domxml.getElementsByTagName('domain')[0]
    return domain_node.getElementsByTagName('name')[0].firstChild.data

def eth_pciname(interface_name):
    """Return the PCI device name backing network interface *interface_name*.

    Reads the ``/sys/class/net/<iface>/device`` symlink and returns the
    basename of its target.  Callers use this as an entry under
    ``/sys/bus/pci/devices/``.  Returns None when the link cannot be read
    (e.g. unknown interface or no ``device`` link).
    """
    try:
        target = os.readlink("/sys/class/net/%s/device" % interface_name)
    except OSError:
        # Only a filesystem failure means "no PCI device".  Any other
        # exception is a bug and must not be silently turned into None.
        return None
    return os.path.basename(target)


def virtfn_enabled(pci_name, vid):
    """Return the integer stored in VF *vid*'s ``enable`` sysfs attribute.

    Reads the first line of
    ``/sys/bus/pci/devices/<pci_name>/virtfn<vid>/enable``.  IOError/OSError
    (missing file) and ValueError (non-integer contents) propagate.
    """
    path = "/sys/bus/pci/devices/%s/virtfn%d/enable" % (pci_name, vid)
    # `with` closes the handle even when int() rejects the contents.
    with open(path, 'r') as handle:
        return int(handle.readline())

def find_avaliable_pci(pci_name):
    """Find the first virtual function (``virtfnN`` entry) of *pci_name* not enabled.

    Scans ``/sys/bus/pci/devices/<pci_name>/virtfn0``, ``virtfn1``, ...
    while the corresponding ``enable`` file is readable.  For the first
    entry whose ``enable`` value is 0:

    * its resource*, rom and reset sysfs files are chowned to qemu:qemu;
    * ``(bus, slot, function)`` is returned as strings sliced from its PCI
      name, e.g. ``0000:05:10.1`` -> ``("05", "10", "1")``.

    Returns ``(None, None, None)`` when *pci_name* is None (eth_pciname
    failed) or no such entry exists.
    """
    if pci_name is None:
        return (None, None, None)

    base = "/sys/bus/pci/devices/%s" % pci_name
    # Was a bare Python-2 `print` to stdout; debug output goes through log().
    log("%s/virtfn%d/enable" % (base, 0))

    vf_index = 0
    while os.access("%s/virtfn%d/enable" % (base, vf_index), os.R_OK):
        if virtfn_enabled(pci_name, vf_index) == 0:
            vf_link = "%s/virtfn%d" % (base, vf_index)
            vf_name = os.path.basename(os.readlink(vf_link))
            real_path = os.path.realpath(vf_link)
            log(real_path)
            # Shell call kept on purpose: the shell expands resource*.
            os.system('/bin/chown qemu:qemu  %s/resource* %s/rom %s/reset '
                      % (real_path, real_path, real_path))
            return (vf_name[5:7], vf_name[8:10], vf_name[-1])
        vf_index += 1

    return (None, None, None)


def find_eth_from_br(br_name):
    """Return the first port of bridge *br_name* whose name starts with "eth".

    Ports are read from ``/sys/class/net/<br_name>/brif/`` in os.listdir
    order.  Returns None when no such port exists.  OSError propagates if
    the bridge directory cannot be listed.
    """
    ports = os.listdir("/sys/class/net/%s/brif/" % br_name)
    return next((port for port in ports if port[0:3] == "eth"), None)


def solar_iov(domxml):
    """Give the guest a free virtual function of its bridge's eth* port.

    The steps are:

    1. Take the bridge and MAC of the guest's first ``<interface>``.
    2. Find the bridge's ``eth*`` port.
    3. Pick the port's first VF that is not enabled.
    4. Write the guest MAC into that VF's ``mac_addr`` sysfs file.
    5. Attach the VF to the guest as a PCI passthrough ``<hostdev>``.

    Raises RuntimeError with a descriptive message when the bridge has no
    ``eth*`` port or no free VF is found, instead of failing on a
    ``0000:None:None.None`` path.
    """
    br, mac = get_bridge_info(domxml)
    eth = find_eth_from_br(br)
    if eth is None:
        raise RuntimeError("bridge %s has no eth* port" % br)

    bus, slot, function = find_avaliable_pci(eth_pciname(eth))
    if bus is None:
        raise RuntimeError("no free virtual function on %s" % eth)

    mac_path = "/sys/bus/pci/devices/0000:%s:%s.%s/mac_addr" % (bus, slot, function)
    with open(mac_path, 'w') as handle:
        handle.write(mac)

    pin_pci(domxml, bus, slot, function)

def local_eth(domxml):
    """Write ``"+<mac>"`` to the bridge port NIC's ``local_addrs`` sysfs file.

    The MAC and bridge come from the guest's first ``<interface>``.  The
    target is ``/sys/class/net/<eth>/device/local_addrs``, where *eth* is
    the bridge's first ``eth*`` port.  NOTE(review): presumably this
    registers the MAC as a local address with the NIC driver — confirm
    against the driver documentation.

    Raises RuntimeError when the bridge has no ``eth*`` port, rather than
    writing to a ``/sys/class/net/None/...`` path.
    """
    br, mac = get_bridge_info(domxml)
    eth = find_eth_from_br(br)
    if eth is None:
        raise RuntimeError("bridge %s has no eth* port" % br)

    with open("/sys/class/net/%s/device/local_addrs" % eth, 'w') as handle:
        handle.write("+" + mac)

def parse_vmname(vmname):
    """Split a VM name of the form ``<function>-<N>`` into ``(function, N)``.

    *function* must be one of infraserver, webserver, appserver, dbserver,
    mailserver or idleserver.  *N* must be an all-digit string with a value
    greater than zero and is returned as an int.  Any other name yields
    ``(None, None)``.
    """
    rejected = (None, None)
    parts = vmname.split('-')

    if len(parts) != 2:
        return rejected
    role, number = parts
    if not number.isdigit():
        return rejected
    index = int(number)
    if index <= 0:
        return rejected
    if role not in ("infraserver", "webserver", "appserver",
                    "dbserver", "mailserver", "idleserver"):
        return rejected
    return (role, index)

def pin_cpu_mem(fun_name, tid, domxml):
    """Pin the guest's vCPUs and memory according to *tid*.

    * tid 1-12  -> nodeset "0", cpuset "0-7,16-23", mode 'strict'
    * tid 13-24 -> nodeset "1", cpuset "8-15,24-31", mode 'strict'
    * tid 25    -> nodeset "0"/cpuset "0-7,16-23" for mailserver and
      infraserver, otherwise nodeset "1"/cpuset "8-15,24-31"; mode
      'preferred'
    * any other tid -> *domxml* is left unchanged
    """
    first_cpus = "0-7,16-23"
    second_cpus = "8-15,24-31"

    if tid in range(1, 13):
        pin_mem(domxml, "0", 'strict')
        pin_cpu(domxml, first_cpus)
    elif tid in range(13, 25):
        pin_mem(domxml, "1", 'strict')
        pin_cpu(domxml, second_cpus)
    elif tid == 25:
        use_first = fun_name in ("mailserver", "infraserver")
        pin_cpu(domxml, first_cpus if use_first else second_cpus)
        pin_mem(domxml, "0" if use_first else "1", 'preferred')
          

def solar_bind(fun_name, tid, domxml):
    """Set up the guest's network binding based on its function name.

    infraserver, webserver, appserver and mailserver guests go through
    solar_iov(); every other function goes through local_eth().  *tid* is
    currently unused.
    """
    uses_vf = fun_name in ("infraserver", "webserver", "appserver", "mailserver")
    binder = solar_iov if uses_vf else local_eth
    binder(domxml)
          

if __name__ == "__main__":
    # Usage: setting_policy <domain-xml-file>
    # Rewrites the file in place for recognised VM names and always dumps a
    # copy of the resulting XML to /tmp/<vmname>.xml.
    log("started")
    if len(sys.argv) != 2:
        sys.exit(0)

    dom_file = sys.argv[1]

    try:
        domxml = xml.dom.minidom.parse(dom_file)
        vmname = get_vm_name(domxml)
        (fun_name, tid) = parse_vmname(vmname)
        if fun_name:
            pin_cpu_mem(fun_name, tid, domxml)
            solar_bind(fun_name, tid, domxml)
            create_vdisk(fun_name, tid, domxml)

        # Serialize once.  toxml(encoding=...) returns encoded bytes, so the
        # files are opened in binary mode (valid on Python 2 and 3).
        payload = domxml.toxml(encoding='utf-8')
        if fun_name:
            with open(dom_file, 'wb') as out:
                out.write(payload)
        # NOTE(review): predictable /tmp path written by this process; if it
        # runs as root, consider a private directory to avoid symlink attacks.
        with open("/tmp/%s.xml" % vmname, 'wb') as out:
            out.write(payload)

    except Exception:
        # log() is silent unless BDEBUG is set, so failures were previously
        # invisible; always report them to syslog.
        syslog.syslog(syslog.LOG_ERR,
                      'setting_policy:[unexpected error]: %s\n'
                      % traceback.format_exc())
        sys.exit(2)
