#!/usr/bin/python3
#
#------------------------------------------------------------------------------
# dbx_template_scale_aws
#
# take input template and scale data nodes to desired number
#------------------------------------------------------------------------------

from __future__ import print_function
import sys, os
import argparse
import pprint, json, copy

# Optional local override of the debug level; fall back to 0 (off) when no
# `debug` module (or no DEBUG name inside it) can be imported.
try:
	from debug import DEBUG
except ImportError:
	DEBUG = 0


#------------------------------------------------------------------------------
# Names and template fragment used to generate the data-node resources
#------------------------------------------------------------------------------
DBX_DATA_INSTANCE = 'dbXDataInstance'

# The input template does not contain a dbXDataInstance resource (by default
# it creates only a single node), so the data-node resource is defined here
# as a JSON fragment and replicated on the fly according to the nodes
# parameter.
#
# NOTE: data_instance_template (parsed below) is shared module state; any code
# that modifies it in place changes it for every later use, so copy it first.
data_instance = '''
{
	"dbXDataInstance": 
    {
      "Type": "AWS::EC2::Instance",
      "DependsOn" : "NatGateway",
      "Properties": 
      {
        "UserData" : { "Fn::Base64" : {"Fn::Join" : [ "\\n", [
          "#cloud-config",
          {"Fn::Join" : [ ": ", ["repo_update", "false"] ] },
          {"Fn::Join" : [ ": ", ["repo_upgrade", "none"] ] },
          {"Fn::Join" : [ ": ", ["package_upgrade", "false"] ] },
          "",
          "#dbx",
          {"Fn::Join" : [ " : ", ["dbx-cluster", { "Ref" : "AWS::StackName" }] ] },
          {"Fn::Join" : [ " : ", ["dbx-head", { "Fn::GetAtt" : [ "dbXHeadInstance", "PrivateIp" ] }] ] }
        ] ] }},
        "ImageId" : { "Fn::FindInMap" : [ "AWSRegion2AMI", {"Ref" : "AWS::Region"}, "AMI" ] },
        "InstanceType": {"Ref" : "InstanceType"},
        "EbsOptimized" : { "Fn::FindInMap" : [ "EBSOptimization", {"Ref" : "InstanceType"}, "Optimized" ] },
        "KeyName": {"Ref" : "KeyName"},
        "PlacementGroupName": {"Fn::If" : ["UsePlacementGroup",{"Ref" : "PlacementGroup"},{"Ref" : "AWS::NoValue"}] },
        "NetworkInterfaces" : [{
          "GroupSet"                 : [{"Ref" : "dbxSG"}],
          "AssociatePublicIpAddress" : "false",
          "DeviceIndex"              : "0",
          "DeleteOnTermination"      : "true",
          "SubnetId"                 : {"Ref" : "PrivateSubnet"},
          "PrivateIpAddress"         : "10.0.0.11"
        }],
        "Tags" : [
          { "Key" : "Name", "Value" : {"Fn::Join" : [ "-", [{ "Ref" : "AWS::StackName" }, "Data"] ] } },
          { "Key" : "Application", "Value" : { "Ref" : "AWS::StackId"} },
          { "Key" : "SGDependency" , "Value" : { "Ref" : "dbxSGIngress" } }
        ],
        "BlockDeviceMappings" : [
        {
          "DeviceName" : "/dev/sdn",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        },
		{
          "DeviceName" : "/dev/sdo",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        },
		{
          "DeviceName" : "/dev/sdp",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        },
		{
          "DeviceName" : "/dev/sdq",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        },
        {
          "DeviceName" : "/dev/sdr",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        },
		{
          "DeviceName" : "/dev/sds",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        },
		{
          "DeviceName" : "/dev/sdt",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        },
		{
          "DeviceName" : "/dev/sdu",
          "Ebs" : {
            "VolumeType" : "gp2",
            "DeleteOnTermination" : "true",
            "Encrypted" : {"Ref" : "EncryptVolumes"},
            "VolumeSize" : "128"
          }
        }
        ]
      }
    }
}	
'''

# Parsed form of data_instance; its single top-level key is DBX_DATA_INSTANCE.
data_instance_template = json.loads(data_instance)
# Literal inside the "Name" tag value that gets the node number appended
# (e.g. "Data" -> "Data01") for each replicated data node.
DBX_DATA_NAME = 'Data'

#------------------------------------------------------------------------------
# whole_num
#
# convert input to int and check if it is a whole number (integer >= 0)
# return None on error
#------------------------------------------------------------------------------
def _whole_num(value):
	try:
		number = int(value)
	except:
		number = None
	else:
		if number < 0:
			return None
	return number


#------------------------------------------------------------------------------
# validate_network_address
# simple validation -  make sure there are 4 valid octets
#------------------------------------------------------------------------------
def _validate_network_address(addr):
	"""Return True if addr looks like a dotted-quad IPv4 address.

	The check is deliberately simple: addr must split on '.' into exactly
	four parts, and each part must be accepted by _whole_num and lie in the
	octet range 0..255.
	"""
	octets = addr.split('.')
	if len(octets) != 4:
		return False

	for octet in octets:
		number = _whole_num(octet)
		# 255 is a legal octet value (e.g. 10.0.255.5); only > 255 is invalid
		if number is None or number > 255:
			return False
	return True


#------------------------------------------------------------------------------
# next_network_address
#
# compute next network address
#
# TODO: make more general - currently returns None if the lowest octet would reach 255
#------------------------------------------------------------------------------
def _next_network_address(addr):
	"""Return addr with its last octet incremented by one, or None.

	Only the last octet changes (no carry into the higher octets).  None is
	returned when addr is rejected by _validate_network_address, or when the
	incremented last octet would be 255 or more.
	"""
	if _validate_network_address(addr) != True:
		return None

	parts = addr.split('.')
	bumped = _whole_num(parts[-1]) + 1
	if bumped >= 255:
		return None

	parts[-1] = str(bumped)
	return '.'.join(parts)


#------------------------------------------------------------------------------
# find_change_string
#
# recurse a nested (json) data structure to find each occurrence of a particular
# string and replace it (values only, not keys)
#------------------------------------------------------------------------------
def _find_change_string(obj, string, newString):
	if isinstance(obj, str) == True:
		if obj == string:
			obj = newString
	elif isinstance(obj, dict) == True:
		for key in obj:
			key = _find_change_string(obj[key], string, newString)
	elif isinstance(obj, list) == True:
		for i, item in enumerate(obj):
			obj[i] = _find_change_string(item, string, newString)
	return obj


#------------------------------------------------------------------------------
# find_change_key_value_string
#
# recurse a nested (json) data structure to find each occurrence of a particular
# key and replace its string value
#
# TODO: test this!
#------------------------------------------------------------------------------
def _find_change_key_value_string(obj, key, newString):
	if isinstance(obj, dict) == True:
		for nkey in obj:
			if isinstance(obj[nkey], str) == True:
				if nkey == key:
					nkey = newValue
			else:
				nkey = _find_change_key_value_string(obj[nkey], key, newString)
	elif isinstance(obj, list) == True:
		for i, item in enumerate(obj):
			obj[i] = _find_change_key_value_string(item, key, newString)
	return obj


#------------------------------------------------------------------------------
# process_template
#
# main parsing routine
# finds the data node instance and replicates it as specified
#------------------------------------------------------------------------------
def _advance_private_addresses(instance):
	"""Increment every PrivateIpAddress on instance's network interfaces.

	Interfaces without a PrivateIpAddress are skipped.  Returns False as soon
	as an address cannot be advanced by _next_network_address, True
	otherwise.  instance is modified in place.
	"""
	for networkInterface in instance['Properties']['NetworkInterfaces']:
		if 'PrivateIpAddress' not in networkInterface:
			continue
		if DEBUG: print(networkInterface['PrivateIpAddress'])
		nextAddr = _next_network_address(networkInterface['PrivateIpAddress'])
		if nextAddr is None:
			return False
		networkInterface['PrivateIpAddress'] = nextAddr
	return True


def _process_template(nodes, ifp, ofp, tab):
	"""Read a JSON template from ifp, add `nodes` data-node instances, write to ofp.

	Parameters:
		nodes -- number of data-node instances to add (expected >= 0)
		ifp   -- readable file object containing the JSON template
		ofp   -- writable file object that receives the resulting JSON
		tab   -- indent passed to json.dump (None for compact output)

	Returns 0 on success, 1 on error (a message is written to stderr).
	"""
	try:
		template = json.load(ifp)
	except (ValueError, IOError, OSError):
		sys.stderr.write("Error: invalid input file\n")
		return 1
	# everything below indexes the template by key, so it must be a JSON object
	if not isinstance(template, dict):
		sys.stderr.write("Error: invalid input file\n")
		return 1

	if DEBUG > 1:
		print('*')
		pp = pprint.PrettyPrinter(indent=4)
		pp.pprint(template)
		print('Nodes =',nodes)
		print('*')

	# Work on a private copy of the data-node fragment: it is shared module
	# state, and mutating it (stripping attributes, advancing IPs) would make
	# every later call start from an already-modified fragment.
	dataInstance = copy.deepcopy(data_instance_template[DBX_DATA_INSTANCE])

	# if there is a mapping constant defined for nodes, change it to new number
	try:
		instanceValues = template['Mappings']['Constants']['InstanceValues']
	except (KeyError, TypeError):
		instanceValues = None
	if isinstance(instanceValues, dict) and 'DataNodes' in instanceValues:
		instanceValues['DataNodes'] = nodes

	# a 'Vpc' parameter marks a template for an existing VPC
	parameters = template.get('Parameters')
	if isinstance(parameters, dict) and 'Vpc' in parameters:
		# add the Data subnet details to the input vpc template
		if nodes > 0:
			parameters['SubnetData'] = {"Type" : "AWS::EC2::Subnet::Id", "Description" : "SubnetId of an existing private subnet (for the primary network) in your Virtual Private Cloud (VPC)", "ConstraintDescription" : "must be an existing private subnet in the selected Virtual Private Cloud."}
		# remove the attributes not required for existing VPC template
		dataInstance.pop('DependsOn', None)
		primaryInterface = dataInstance['Properties']['NetworkInterfaces'][0]
		primaryInterface.pop('PrivateIpAddress', None)
		# place the data nodes in the SubnetData subnet parameter
		primaryInterface['SubnetId']['Ref'] = "SubnetData"

	for node in range(nodes):
		nodeID = str(node+1).zfill(2)

		# each node after the first gets the next private address; advancing
		# only when another node needs it avoids failing after the last node
		if node > 0 and not _advance_private_addresses(dataInstance):
			sys.stderr.write("Error: bad network address\n")
			return 1

		# make a copy of the fragment under a new name
		dataInstanceName = DBX_DATA_INSTANCE + nodeID
		if DEBUG: print(dataInstanceName)
		template['Resources'][dataInstanceName] = copy.deepcopy(dataInstance)

		# update new instance "Name" tag (e.g. "Data" -> "Data01")
		tags = template['Resources'][dataInstanceName]['Properties']['Tags']
		for tag in tags:
			if tag['Key'] == 'Name':
				tag['Value'] = _find_change_string(tag['Value'], DBX_DATA_NAME, DBX_DATA_NAME+nodeID)

	json.dump(template, ofp, sort_keys=True, indent=tab, separators=(',', ':'))

	return 0


def _usage():
	print("aws_template_scale <input_template> <num_data_nodes>")


#==============================================================================
# main
#
# collect and validate input args
#==============================================================================
if __name__ == '__main__':

	parser = argparse.ArgumentParser(description='process AWS template for dbx to scale data nodes')
	parser.add_argument("nodes", type=int, help="number of data nodes")
	parser.add_argument("--ifile", "-i", type=str, help="input file")
	parser.add_argument("--ofile", "-o", type=str, help="output file")
	parser.add_argument("--tab", "-t", nargs='?', type=int, default=None, const=None, help="output tab (indent) spacing")
	parser.add_argument("--debug", nargs='?', type=int, default=None, const=None, help="enable debug output level")

	args = parser.parse_args()

	# Validate every argument before opening any file: opening the output
	# file truncates it, so a bad argument must not destroy existing output.
	if args.nodes < 0:
		sys.stderr.write("Error: Number of nodes must be >= 0\n")
		sys.exit(1)

	if args.tab is not None and args.tab < 0:
		sys.stderr.write("Error: Tab must be >= 0\n")
		sys.exit(1)

	# apply --debug first so it also governs the file-opening messages below
	if args.debug is not None:
		DEBUG = args.debug
		print("DEBUG =", DEBUG)

	ifp = sys.stdin
	if args.ifile is not None:
		try:
			ifp = open(args.ifile, 'r')
		except (IOError, OSError):
			sys.stderr.write("Error: Could not open input file: "+args.ifile+"\n")
			sys.exit(1)
		if DEBUG: print('Found input file:', args.ifile)

	ofp = sys.stdout
	if args.ofile is not None:
		try:
			ofp = open(args.ofile, 'w')
		except (IOError, OSError):
			sys.stderr.write("Error: Could not open output file: "+args.ofile+"\n")
			if ifp is not sys.stdin:
				ifp.close()
			sys.exit(1)
		if DEBUG: print('Found output file:', args.ofile)

	try:
		status = _process_template(args.nodes, ifp, ofp, args.tab)
	finally:
		# close only the files opened here, never the standard streams
		if ifp is not sys.stdin:
			ifp.close()
		if ofp is not sys.stdout:
			ofp.close()

	sys.exit(status)

