-- Removes leading and trailing space characters (' ' only; tabs are kept).
-- Always returns exactly one value, the trimmed string.
local function trim(s)
	local inner = s:match("^ *(.-) *$")
	return inner
end
-- Collects every space-separated token appearing in any mnemonic (the keys of
-- instrs) into a set. Tokens in this set are treated as keywords, not as
-- numbers or labels, when a source line is parsed.
local function validWordsFromInstrs(instrs)
	local wordSet = {}
	for mnemonic in pairs(instrs) do
		for token in mnemonic:gmatch("[^ ]+") do
			wordSet[token] = true
		end
	end
	return wordSet
end
-- Parses a numeric literal and returns (value, width in bytes).
-- Accepted forms, each optionally preceded by '-':
--   $1F, 0x1F, 1Fh   hexadecimal; width = ceil(digit count / 2) bytes, so the
--                    number of digits written selects the operand size
--   31               decimal; width 1 if in -128..255, 2 if in -32768..65535
-- Surrounding spaces are ignored. Raises "invalid number ..." for malformed
-- literals and "out-of-range number ..." for decimals outside -32768..65535.
local function decodeNumber(n)
	n = trim(n)
	local text = n
	local sign = 1
	if n:sub(1, 1) == "-" then
		sign = -1
		n = n:sub(2)
	end

	local digits
	if n:sub(1, 1) == "$" then
		digits = n:sub(2)
	elseif n:sub(1, 2) == "0x" then
		digits = n:sub(3)
	elseif n:sub(-1) == "h" then
		digits = n:sub(1, -2)
	end
	if digits then
		local v = digits:find("^%x+$") and tonumber(digits, 16)
		if not v then error("invalid number "..text) end
		-- Round odd digit counts up to whole bytes ("$F" is a byte, "$FFF" a
		-- word); math.ceil also yields an integer, so "imm"..(len*8) stays "imm8".
		return sign*v, math.ceil(#digits/2)
	end

	local v = tonumber(n) or error("invalid number "..text)
	v = sign*v -- the sign was stripped above and must be re-applied
	if v >= -128 and v <= 255 then return v, 1
	elseif v >= -32768 and v <= 65535 then return v, 2
	else error("out-of-range number "..v) end
end
-- Converts one source line into its canonical mnemonic for lookup in instrs.
-- Numeric literals and label references are replaced by "imm8"/"imm16"
-- placeholders; tokens listed in validWords are kept as keywords.
-- Returns the trimmed, single-spaced mnemonic and the immediates in source
-- order: { val = number, len = bytes } for literals and
-- { label = name, len = bytes } for label references.
-- A label gets len 1 only when the line with that label written as imm8 is a
-- known mnemonic; otherwise it gets len 2.
local function mnemFromLine(line, instrs, validWords)
	local imms = {}
	-- Working copy: doubling every space and padding both ends gives each token
	-- its own leading and trailing space, so the " token " patterns below can
	-- match neighbouring tokens. Declared before the helpers because addLabel
	-- inspects the partially converted line.
	local mnem = " "..line:gsub(" ", "  ").." "

	-- Records a numeric literal and returns its placeholder.
	local function addNum(n)
		local val, len = decodeNumber(n)
		table.insert(imms, { val = val, len = len })
		return " imm"..(len*8).." "
	end

	-- Records a label reference and returns its placeholder.
	local function addLabel(n)
		local name = trim(n)
		-- Substitute the padded " name " token (a whole token, so label "r"
		-- cannot match inside "jr") in the line as converted so far, where
		-- numbers are already placeholders and trailing spaces are trimmed.
		local asImm8 = trim((mnem:gsub(" "..name.." ", " imm8 ", 1))):gsub(" +", " ")
		local len = instrs[asImm8] and 1 or 2
		table.insert(imms, { label = name, len = len })
		return " imm"..(len*8).." "
	end

	-- Wraps a converter so keyword tokens are left untouched (nil keeps the match).
	local function unlessKeyword(convert)
		return function(n)
			if not validWords[trim(n)] then return convert(n) end
		end
	end

	mnem = mnem:gsub(" %-?%$[0-9a-fA-F]+ ", addNum)
	mnem = mnem:gsub(" %-?0x[0-9a-fA-F]+ ", addNum)
	mnem = mnem:gsub(" %-?[0-9a-fA-F]+h " , unlessKeyword(addNum))
	mnem = mnem:gsub(" %-?[0-9]+ "        , unlessKeyword(addNum))
	mnem = mnem:gsub(" [a-zA-Z0-9_]+ "    , unlessKeyword(addLabel))
	mnem = trim(mnem):gsub(" +", " ")

	return mnem, imms
end
-- Stores one byte at state.curAddr and advances the address by one.
-- Accepts -128..255; the value is stored modulo 256 (two's complement for
-- negatives). Raises an error if the address already holds data.
local function addByte(state, val)
	assert(val >= -128 and val <= 255, "invalid byte "..val)
	local memory, addr = state.memory, state.curAddr
	assert(not memory[addr], "overwriting memory at "..addr)
	memory[addr] = val % 256
	state.curAddr = addr + 1
end
-- Stores a 16-bit value big-endian (high byte first) at state.curAddr via
-- addByte. Accepts -32768..65535, mirroring addByte's signed/unsigned range
-- and the 2-byte range decodeNumber produces; negatives are stored in two's
-- complement (e.g. -1 -> FF FF).
local function addWord(state, val)
	assert(val>=-32768 and val<=65535, "invalid word "..val)
	local word = val % 65536
	addByte(state, math.floor(word/256))
	addByte(state, word % 256)
end
-- Assembles one instruction line into state.memory at state.curAddr.
-- Numeric opcodes emit the opcode byte when >= 0. Function entries (data
-- pseudo-instructions) return (bytes to reserve, whether to emit immediates).
-- Literal immediates are written directly. Label immediates reserve space
-- and are queued in state.labelReplacements for a later pass; 1-byte label
-- operands are flagged rel.
local function assembleInstruction(line, state, instrs, validWords)
	local mnem, imms = mnemFromLine(line, instrs, validWords)
	local opcode = instrs[mnem]
	if not opcode then
		error("invalid instruction "..line.." (mnem "..mnem..")")
	end

	local emitImmediates = true
	if type(opcode) == "function" then
		local reserve
		reserve, emitImmediates = opcode(imms)
		state.curAddr = state.curAddr + reserve
	elseif opcode >= 0 then
		addByte(state, opcode)
	end
	if not emitImmediates then return end

	for _, imm in ipairs(imms) do
		if imm.val then
			if imm.len == 1 then
				addByte(state, imm.val)
			elseif imm.len == 2 then
				addWord(state, imm.val)
			else
				error("invalid imm len")
			end
		elseif imm.label then
			local labelRef = { name = imm.label, addr = state.curAddr, len = imm.len, rel = imm.len == 1 }
			table.insert(state.labelReplacements, labelRef)
			state.curAddr = state.curAddr + imm.len
		else
			error("invalid imm")
		end
	end
end
-- Handlers for assembler directives (".name rest" lines). Each receives the
-- assembler state and the rest of the line after the directive name.
--   .fn <file>    record the current source file name in state.fileName
--   .ln <n>       record the current source line number in state.lineNum
--   .org <addr>   move the output address to <addr>
--   .align <n>    advance the output address to the next multiple of <n>
local directiveFunctions = {
	fn    = function(state, fn) state.fileName = fn end,
	ln    = function(state, ln) state.lineNum = tonumber(ln) end,
	org   = function(state, addr) state.curAddr = decodeNumber(addr) end,
	align = function(state, alns)
		local aln = decodeNumber(alns)
		-- x % 0 is NaN in Lua, which would silently corrupt curAddr.
		if aln <= 0 then error("invalid alignment "..alns) end
		local misalign = state.curAddr % aln
		if misalign ~= 0 then state.curAddr = state.curAddr + (aln - misalign) end
	end,
}
-- Assembles preprocessed code (one statement per line) and returns the memory
-- table (address -> byte). Each line is a directive (".name args"), a label
-- definition ("name:"), an instruction, or blank.
-- Label operands are patched in a second pass once every label is known:
--   * 2-byte operands receive the label's absolute address
--   * 1-byte operands receive a signed offset from the byte after the operand
-- Errors are prefixed with the "file:line" recorded by the .fn/.ln directives.
local function assembleCode(code, instrs)
	local validWords = validWordsFromInstrs(instrs)

	local state = {
		lineNum = 0,
		fileName = "",
		curAddr = 0,
		memory = {},
		labelReplacements = {},
		labelAddrs = {},
	}

	-- "file:line" of the statement currently being processed.
	local function location()
		return tostring(state.fileName)..":"..tostring(state.lineNum)
	end

	-- Handles one trimmed, non-empty statement.
	local function assembleLine(line)
		if line:sub(1, 1)=="." then -- directive
			local dir, rest = line:match("^%.([^ ]+) *(.*)$")
			assert(dir and rest, "no directive on line "..line)
			local dirf = directiveFunctions[dir] or error("invalid directive "..dir)
			dirf(state, rest)
		elseif line:sub(-1)==":" then -- label
			local name = line:sub(1, -2)
			assert(not state.labelAddrs[name], "redefinition of label "..name)
			state.labelAddrs[name] = state.curAddr
		else
			assembleInstruction(line, state, instrs, validWords)
		end
	end

	-- Writes the final value of one queued label reference.
	local function patchLabelRef(rep)
		local labelAddr = state.labelAddrs[rep.name] or error("no label named "..rep.name, 0)
		state.curAddr = rep.addr
		if rep.len==1 then
			-- Stored as a signed byte, so only -128..127 is representable; addByte
			-- alone would accept 128..255 and silently turn it into a backward offset.
			local offset = labelAddr-(rep.addr+1)
			if offset < -128 or offset > 127 then
				error("relative offset to label "..rep.name.." out of range ("..offset..")", 0)
			end
			addByte(state, offset)
		elseif rep.len==2 then addWord(state, labelAddr)
		else error("invalid labelreplace len", 0) end
	end

	-- First pass: emit code and record label addresses.
	for rawLine in code:gmatch("[^\n]+") do
		-- Trim so trailing spaces (e.g. left before a stripped comment) do not
		-- hide a "label:" definition.
		local line = trim(rawLine)
		if line ~= "" then
			local firstNewRef = #state.labelReplacements + 1
			local ok, err = pcall(assembleLine, line)
			if not ok then error(location()..": "..tostring(err), 0) end
			-- Remember where each label reference came from, for second-pass errors.
			for i = firstNewRef, #state.labelReplacements do
				state.labelReplacements[i].where = location()
			end
		end
	end

	-- Second pass: patch label operands.
	for _, rep in ipairs(state.labelReplacements) do
		local ok, err = pcall(patchLabelRef, rep)
		if not ok then error(tostring(rep.where or "?")..": "..tostring(err), 0) end
	end

	return state.memory
end

-- Returns the entire text of the file at path fn.
-- Raises "could not open file <fn>: <reason>", where <reason> is io.open's
-- error message, if the file cannot be opened.
local function readFile(fn)
	local fi, reason = io.open(fn, "r")
	if not fi then error("could not open file "..fn..": "..tostring(reason)) end
	local text = fi:read("*a")
	fi:close()
	return text
end

-- Splits l on commas and returns the space-trimmed pieces as an array.
-- Empty fields (as between ",,") are skipped; a field of only spaces becomes "".
local function separateCommas(l)
	local parts = {}
	for piece in l:gmatch("[^,]+") do
		parts[#parts + 1] = trim(piece)
	end
	return parts
end
-- Expands .define macros, then turns each "\" into a newline.
-- Two macro forms are supported; definition lines are removed from the output:
--   .define NAME(a, b) body   function-like: "NAME(x, y)" expands to " body "
--                             with each whole-word a / b replaced by x / y
--   .define NAME body         object-like: every whole-word NAME becomes body
-- Function-like macros are expanded before object-like ones are collected.
-- Raises on redefinition, bad parameter names, or too few call arguments.
local function preprocessCode(code)
	-- Matches NAME only as a whole identifier (not inside "MYNAME" or "NAME2").
	local function wordPattern(name)
		return "%f[a-zA-Z0-9_]"..name.."%f[^a-zA-Z0-9_]"
	end

	local funcmacros = {}
	code = code:gsub("%.define ([a-zA-Z0-9_]+)%(([^%)]+)%) ([^\n]+)", function(name, args, repl)
		assert(not funcmacros[name], "Redefinition of macro "..name)
		local argt = separateCommas(args)
		for _, arg in ipairs(argt) do
			assert(not arg:find("[^a-zA-Z0-9_]"), "invalid character in macro arg name: "..name.." "..arg)
		end
		repl = " "..repl.." "
		funcmacros[name] = function(callargs)
			local callargt = separateCommas(callargs)
			local bindings = {}
			for argidx, arg in ipairs(argt) do
				local callarg = callargt[argidx]
				if not callarg then
					error("macro "..name.." expects "..#argt.." argument(s), got "..#callargt)
				end
				bindings[arg] = callarg
			end
			-- One pass over whole identifiers: adjacent uses ("a a") are all
			-- replaced, and a value that spells another parameter name is not
			-- substituted again. Table values are inserted literally.
			return (repl:gsub("[a-zA-Z0-9_]+", bindings))
		end
		return ""
	end)
	for name, replf in pairs(funcmacros) do
		code = code:gsub("%f[a-zA-Z0-9_]"..name.." *%(([^%)]+)%)", replf)
	end

	local simplemacros = {}
	code = code:gsub("%.define ([a-zA-Z0-9_]+) ([^\n]+)", function(name, repl)
		assert(not simplemacros[name], "Redefinition of macro "..name)
		simplemacros[name] = repl
		return ""
	end)
	for name, repl in pairs(simplemacros) do
		-- Replace every occurrence; a function replacement keeps any '%' literal.
		code = code:gsub(wordPattern(name), function() return repl end)
	end

	code = code:gsub("\\", "\n")

	return code
end
-- Normalises spacing so every operand is a single space-separated token:
--   * commas become spaces
--   * '[' and ']' get spaces on both sides
--   * leading whitespace and blank lines are removed
--   * runs of spaces collapse to one
local function fixCode(code)
	code = code:gsub(",", " ")
	-- Plain "]"/"[" in the replacement: Lua 5.2+ raises "invalid use of '%' in
	-- replacement string" for '%' followed by anything but a digit or '%'.
	code = code:gsub("%]", " ] ")
	code = code:gsub("%[", " [ ")
	code = code:gsub("\n[ \t\r\n]*", "\n")
	code = code:gsub(" +", " ")
	return code
end
-- Escape sequences accepted inside string literals -> the character they denote.
local stringEscapes = { ["\\"] = "\\", ["n"] = "\n", ["r"] = "\r", ["t"] = "\t", ["0"] = "\0", ["\""] = "\"", ["\'"] = "\'", }
-- Converts raw source text into the assembler's line format:
--   * ".ln N" precedes each source line so later stages know line numbers
--   * comments (from '#', ';' or '/') are dropped up to end of line
--   * each string-literal character (escapes included) becomes a "$XX" line
--   * a line containing '\' ends with a "\" marker instead of a newline;
--     preprocessCode later turns that marker into a newline
--   * tabs become spaces, '\r' is dropped, '{' and '}' are ignored
-- fn (the file name) is currently unused.
-- Raises on characters outside the accepted set, unknown string escapes, and
-- unterminated string literals.
local function prefixCode(code, fn)
	local outt = {}
	local linenum = 1
	local function out(c) assert(type(c)=="string"); table.insert(outt, c); end
	local function outn(n) out("$"..string.format("%02X", n).."\n"); end
	local state = "code" -- code, comment, string, stringesc
	local skipnl = false

	-- Ends the current source line. Shared by the code and comment states so
	-- line numbering and '\' continuation also work after a comment.
	local function newline()
		linenum = linenum+1
		if skipnl then out("\\")
		else out("\n") out(".ln "..linenum) out("\n") end
		skipnl = false
	end

	out(".ln 1"); out("\n");
	for i = 1, #code do
		local c = code:sub(i, i)
		if state=="code" then
			if     c=="\r"                              then
			elseif c=="\n"                              then newline()
			elseif c=="\t" or c==" "                    then out(" ")
			elseif c=="#"  or c==";" or c=="/"          then state = "comment"
			elseif c=="\""                              then state = "string"
			-- '-' is accepted so the signed literals decodeNumber parses can be written.
			elseif c:find("^[a-zA-Z0-9_%.:%$%(%)%*,%[%]%-]$") then out(c)
			elseif c=="\\"                              then skipnl = true
			elseif c=="{" or c=="}"                     then -- brace blocks are not implemented; ignored
			else error("invalid char "..c) end
		elseif state=="comment" then
			if c=="\n" then state = "code"; newline() end
		elseif state=="string" then
			if     c=="\\" then state = "stringesc"
			elseif c=="\"" then state = "code"
			else outn(c:byte()) end
		elseif state=="stringesc" then
			local esc = stringEscapes[c] or error("invalid escape "..c)
			-- Emit the escaped character as a data byte like any other string char.
			outn(esc:byte()); state = "string";
		end
	end
	if state=="string" or state=="stringesc" then error("unterminated string literal") end
	return table.concat(outt)
end
-- Reads source file fn, converts it with prefixCode, and recursively splices
-- in files named by ".include <path>" lines. A ".fn" directive is emitted
-- before each file's code and again after each include, so later stages
-- attribute lines to the right file.
-- NOTE(review): prefixCode treats '/' as a comment start and drops '\', so an
-- include path containing a directory separator is mangled before it gets
-- here; only bare file names currently work.
local function includeFile(fn)
	fn = fn:gsub("\\", "/")
	local code = readFile(fn)
	code = prefixCode(code, fn)
	code = ".fn "..fn.."\n"..code
	-- The dot is escaped so only a literal ".include" matches. The path is
	-- trimmed because prefixCode keeps trailing spaces (e.g. before a comment).
	code = code:gsub("%.include ([^\r\n]+)", function(fn2)
		return "\n"..includeFile(trim(fn2)).."\n"..".fn "..fn.."\n"
	end)
	return code
end
-- Builds the instruction lookup table. Keys are mnemonics; values are either
-- an opcode (from arch.instructions) or, for the built-in data
-- pseudo-instructions, a function(imms) returning (bytes to reserve,
-- whether to emit the immediates). Architecture entries without mnem are skipped.
local function instrsFromArch(arch)
	-- Product of the values of all immediates (the element count of an array).
	local function arraySize(imms)
		local product = 1
		for i = 1, #imms do product = product * imms[i].val end
		return product
	end
	local function emitImmediate() return 0, true end
	local function reserve(bytes) return function() return bytes, false end end
	local function array(elemSize)
		return function(imms) return arraySize(imms) * elemSize, false end
	end

	local instrs = {
		imm8  = emitImmediate,
		imm16 = emitImmediate,
		byte  = reserve(1),
		word  = reserve(2),
		["byte imm8"]      = emitImmediate,
		["word imm16"]     = emitImmediate,
		["byte [ imm8 ]"]  = array(1),
		["byte [ imm16 ]"] = array(1),
		["word [ imm8 ]"]  = array(2),
		["word [ imm16 ]"] = array(2),
	}
	for _, instr in ipairs(arch.instructions) do
		if instr.mnem then instrs[instr.mnem] = instr.opcode end
	end
	return instrs
end
-- Full pipeline for one source file:
--   1. include/convert it (includeFile)
--   2. expand macros (preprocessCode)
--   3. normalise spacing (fixCode)
--   4. assemble it against the "rom-8608-defs" instruction set
-- Returns the memory table (address -> byte).
local function assembleFile(fn)
	local source = fixCode(preprocessCode(includeFile(fn)))
	local instrs = instrsFromArch(require("rom-8608-defs"))
	return assembleCode(source, instrs)
end

-- Prints a hex dump of mem (address -> byte) over 0x0000-0xFFFF, 16 bytes per
-- row, with "--" for unset bytes. Rows without data are omitted, and a "..."
-- line marks each gap between printed rows.
local function printMemory(mem)
	local prevRow = -16
	for rowAddr = 0, 0xFFF0, 16 do
		local cells = {}
		local hasData = false
		for col = 0, 15 do
			local byte = mem[rowAddr + col]
			if byte then
				hasData = true
				cells[#cells + 1] = string.format("%02X ", byte)
			else
				cells[#cells + 1] = "-- "
			end
		end
		if hasData then
			if rowAddr ~= prevRow + 16 then print("...") end
			print(string.format("%04X", rowAddr).." | "..table.concat(cells))
			prevRow = rowAddr
		end
	end
end

-- Use the host-provided TorqueScript bridge `ts` when one exists; otherwise
-- fall back to a stub whose call/eval do nothing, so the assembler can also
-- run under a plain Lua interpreter.
local ts = ts
if not ts then
	local function noop() end
	ts = { call = noop, eval = noop }
end
-- Define TorqueScript helpers that forward shiftBrick / plantBrick commands to
-- the server; plantBrickAt invokes them through ts.call.
ts.eval [[
	function commandShiftBrick(%x, %y, %z) { commandToServer('shiftBrick', %x, %y, %z); }
	function commandPlantBrick() { commandToServer('plantBrick'); }
]]
-- Shifts the brick from brickpos to pos via commandShiftBrick, passing the
-- delta remapped as (dy, -dx, dz). It then plants the brick and updates
-- brickpos in place to equal pos.
local function plantBrickAt(brickpos, pos)
	local delta = {}
	for axis = 1, 3 do delta[axis] = pos[axis] - brickpos[axis] end
	ts.call("commandShiftBrick", delta[2], -delta[1], delta[3])
	ts.call("commandPlantBrick")
	for axis = 1, 3 do brickpos[axis] = pos[axis] end
end
-- Builds the ROM from bricks: plants a brick at (x, y, z) for every 1 bit of
-- mem in the window [offset, offset + len), or the whole ROM when len is nil.
-- romsize = {x, y, z} in bits. Each (x, y) column holds romsize[3]/8
-- consecutive bytes, with bit 0 of each byte at the lowest z. Columns are
-- numbered x fastest, then y. offset defaults to 0.
-- Raises if:
--   * romsize[3] is not a multiple of 8
--   * len exceeds the ROM capacity
--   * (when len is nil) any byte of mem lies outside [offset, offset + capacity)
local function buildMemory(mem, romsize, offset, len)
	offset = offset or 0

	-- A z size that is not a whole number of bytes would produce fractional
	-- addresses and silently empty bits.
	if romsize[3] % 8 ~= 0 then
		error("rom z size must be a multiple of 8 (got "..romsize[3]..")")
	end
	local bytesPerColumn = romsize[3]/8
	local rombytes = romsize[1]*romsize[2]*bytesPerColumn
	if len and len>rombytes then error("rom not big enough to hold "..len.." bytes (holds "..rombytes..")") end
	if not len then
		for i = 0, 0xFFFF do
			if mem[i] and (i<offset or i>=offset+rombytes) then error("memory does not fit in rom at addr "..string.format("%04X", i)) end
		end
	end

	local brickpos = {0, 0, 0}
	for x = 0, romsize[1]-1 do
		for y = 0, romsize[2]-1 do
			local columnAddr = offset + bytesPerColumn*(x + y*romsize[1])
			for z = 0, romsize[3]-1 do
				local addr = columnAddr + math.floor(z/8)
				local data = ((not len) or addr<offset+len) and mem[addr] or 0
				-- 2^n rather than math.pow, which Lua 5.4 (and 5.3 without compat) lacks.
				local bit = math.floor(data / 2^(z%8)) % 2
				if bit==1 then plantBrickAt(brickpos, {x, y, z}) end
			end
		end
	end
end

-- Parses whitespace-separated numbers into an array, e.g. "16 16 8" -> {16, 16, 8}.
-- Tokens that are not numbers are skipped.
local function strtovec(str)
	local vec = {}
	for token in str:gmatch("[^ \t\r\n]+") do
		vec[#vec + 1] = tonumber(token)
	end
	return vec
end
-- Entry point, deliberately global: TorqueScript reaches it via luacall.
-- Assembles fn, prints a hex dump, then builds the result as a brick ROM.
-- Arguments arrive as strings:
--   romsizes   "x y z" in bits; an empty string skips the brick build
--   offsets    the first address stored in the ROM
--   lens       the number of bytes to store (optional)
function AssembleFile(fn, romsizes, offsets, lens)
	-- Locals, so these do not leak into the global environment.
	local offset = tonumber(offsets)
	local len = tonumber(lens)
	local romsize = strtovec(romsizes)
	local mem = assembleFile(fn)
	printMemory(mem)
	if #romsize == 0 then return end
	assert(#romsize==3, "incorrect rom size")
	buildMemory(mem, romsize, offset, len)
end
-- Expose AssembleFile to TorqueScript as AssembleFile(%fn, %romsize, %offset, %len).
ts.eval [[
	function AssembleFile(%fn, %romsize, %offset, %len) { luacall("AssembleFile", %fn, %romsize, %offset, %len); }
]]

-- When the global `arg` is set (as the standalone lua interpreter does), assemble
-- the file named on the command line, defaulting to programs/keyboard.asm. The
-- output is a 16x16x8 ROM at offset 0 holding 256 bytes.
if arg then
	local sourceFile = arg[1] or "programs/keyboard.asm"
	AssembleFile(sourceFile, "16 16 8", "0", "256")
end