require_relative "./ops.rb"
require_relative "./lex.rb"

# Node types stored in the flat `nodes` array that `parse` fills in. Nodes
# refer to one another by integer index into that array; until `parse_main`
# resolves them, an index may temporarily be a Proc that yields the index.

# Application of a non-block op: `token` is the op's token, `args` the indices
# of its argument nodes.
AST = Struct.new(:token, :args)
# Application of a block op: like AST, plus `block` (index of the node that is
# the block's single result) and `block_arg` (index of the block's ArgNode).
BlockAST = Struct.new(:token, :args, :block, :block_arg)
# The argument of the main program or of a block; `arg_ind` is set to the
# node's own index in `nodes` when it is created.
ArgNode = Struct.new(:arg_ind)
# A read of a named register; `name` is the key into the registers hash.
VarNode = Struct.new(:name)
# A data literal: `impl` is the op built by `parse_data`, `str` the
# concatenated source text of the tokens that made up the literal.
DataNode = Struct.new(:impl, :str)

# Key in the registers hash under which the value set/read with `@` is stored.
RegisterName = "@ register"

# Stack of node indices used while parsing one program/block body.
#
# Popping from an empty stack is allowed: it is counted as a use of an
# implicit value and answered with a lambda that looks the value up later,
# once `finalize` has appended the implicit values to the end of the stack.
class Stack
  def initialize
    @stack = []
    @negative_used = 0 # number of pops that were served while the stack was empty
  end

  # Pushes a node index (or a lazy lambda) and returns the stack for chaining.
  def push(ind)
    @stack.push(ind)
    self
  end

  # Removes and returns the top entry. On an empty stack, returns a lambda
  # that reads the entry counted from the end of the stack; that lambda is only
  # meaningful after `finalize` has run.
  def pop
    return @stack.pop unless @stack.empty?

    @negative_used += 1
    depth = @negative_used
    lambda { @stack[-depth] }
  end

  # Returns the top entry and leaves it on the stack (implemented as pop + push,
  # so peeking at an empty stack also counts as an implicit use).
  def peek
    pop.tap { |top| push(top) }
  end

  # Number of stored entries minus the implicit uses; may be negative.
  def size
    @stack.length - @negative_used
  end

  # Appends `arg_ind` once per implicit argument and returns
  # [entries that form the result, number of implicit arguments added].
  # No implicit argument is added only for the main program when nothing was
  # popped from an empty stack and at least one value is present.
  def finalize(is_main, arg_ind)
    complete = is_main && @negative_used == 0 && size >= 1
    implicit_args = complete ? 0 : [1 - size, 1].max
    @stack.concat(Array.new(implicit_args, arg_ind))
    [@stack[0, size], implicit_args]
  end
end

# Parses a whole program.
#
# Collects the register names, splits unknown identifiers into characters,
# marks which block openers own an explicit `>`, runs `parse` for the main
# body and finally replaces every lazy (Proc) index by the index it yields.
#
# Returns [nodes, main_inds, registers].
# Raises IogiiError when a `>` is left over after the main body.
def parse_main(tokens)
  registers = get_var_ids_set(tokens)
  tokens = split_ids(tokens, registers)
  add_implicit_ends(tokens, registers)

  nodes = []
  main_inds, _, block_arg_ind, token_ind = *parse(tokens, 0, nodes, true, nil, [], registers, false)
  raise unless block_arg_ind == 0 # ir assumes 0 => 0
  if token_ind <= tokens.size
    # parse stopped on a ">" instead of running out of tokens
    raise IogiiError.new("program has an extra >", tokens[token_ind-1])
  end

  # Follows a chain of lazy references until a non-Proc value is reached.
  force = lambda do |value|
    value = value.call while Proc === value
    value
  end

  nodes.each do |node|
    case node
    when AST
      node.args.map! { |arg| force.call(arg) }
    when BlockAST
      node.args.map! { |arg| force.call(arg) }
      node.block = force.call(node.block)
      node.block_arg = force.call(node.block_arg)
    when ArgNode
      node.arg_ind = force.call(node.arg_ind)
    end
  end

  main_inds.map! { |ind| force.call(ind) }
  registers.each do |name, value|
    registers[name] = force.call(value) if Proc === value
  end

  [nodes, main_inds, registers]
end

# Replaces every :id token that is neither a known op nor a register name by
# the tokens returned from its `split`, and returns the flattened token list.
# A warning is issued when such a token was surrounded by whitespace.
def split_ids(tokens, registers)
  expanded = tokens.map do |tok|
    recognised = tok.type != :id || lookup_op(tok.str) || registers[tok.str]
    next tok if recognised

    if tok.surrounded_by_whitespace
      warn("splitting %p into chars because no op/var by that name, but it is surrounded by whitespace" % tok.str, tok)
    end
    tok.split
  end
  expanded.flatten
end

# Collects the names assigned with `~`: for every token whose text is "~",
# the text of the following token becomes a key (mapped to true).
#
# Raises IogiiError when "~" is the last token or is followed by another "~".
def get_var_ids_set(tokens)
  tokens.each_with_index.each_with_object({}) do |(tok, i), names|
    next unless tok.str == "~"

    target = tokens[i + 1]
    raise IogiiError.new("a token must follow assignment operator", tok) unless target
    raise IogiiError.new("cannot set ~", tok) if target.str == "~"

    names[target.str] = true
  end
end

# Placeholder stack that never runs out: it ignores what is pushed and always
# answers 0. `add_implicit_ends` passes it to `parse` as both the outer stack
# and the var stack for its trial parses.
class DummyStack
  # Always yields 0.
  def pop
    0
  end

  # Always yields 0.
  def peek
    0
  end

  # Discards the value; returns self so pushes can be chained.
  def <<(_value)
    self
  end

  # Never reports being empty.
  def empty?
    false
  end
end
DUMMY_STACK = DummyStack.new

# Tells whether every closer can be paired with an earlier, otherwise unused
# opener. Openers left without a closer are fine.
#
# begins - token indices of block openers
# ends   - token indices of closers (">")
#
# Returns true when `ends` is empty, false when there are closers but no
# openers, otherwise the result of a single left-to-right scan that tracks how
# many openers are still unmatched.
# Raises (bare RuntimeError) if a closer index coincides with an opener index
# or with another closer index.
def can_brackets_match(begins, ends)
  return true if ends.empty?
  return false if begins.empty?

  openers = {}
  begins.each { |i| openers[i] = true }

  closers = {}
  ends.each do |i|
    raise if openers[i] || closers[i] # one position cannot hold two brackets
    closers[i] = true
  end

  # Walk all bracket positions in source order; a closer met while no opener
  # is pending can never be matched.
  pending = 0
  (openers.keys + closers.keys).sort.each do |i|
    if openers[i]
      pending += 1
    elsif pending.zero?
      return false
    else
      pending -= 1
    end
  end
  true
end

# Decides, for every ">" in the token list, which block opener it closes, and
# marks that opener's token with `explicit_end = true`. Openers that are never
# marked are later parsed with an implicit end (see the `!t.explicit_end`
# argument in `parse`).
#
# Strategy:
# - Work through the ">" tokens from left to right.
# - For each ">", try the candidate openers starting at the right-most one. If
#   a trial parse from that opener matches, move on to the next ">"; otherwise
#   leave that opener to end implicitly and try the previous opener.
# - If the opener could not end implicitly before the ">", or dropping it
#   would leave too few openers for the remaining ">"s, force the explicit
#   match (implicit args/ops will be used).
#
# tokens    - token list (after split_ids); tokens are mutated via explicit_end=
# registers - hash of register names; a token whose text is a register name is
#             never treated as an opener, and if ">" itself is a register name
#             no ">" is treated as a closer
#
# Raises IogiiError "unmatched >" when the closers cannot be matched at all.
def add_implicit_ends(tokens, registers)
  # candidate openers that lie before the ">" currently being resolved
  begins = []
  # indices of all ">" tokens (none if ">" is used as a register name)
  ends = registers['>'] ? [] : tokens.size.times.filter{|ti| tokens[ti].str == ">" }
  # indices of all tokens that are not registers and look up to an op with a block
  all_begins = tokens.size.times.filter{|ti| t=tokens[ti]
      !registers[t.str] && (o=lookup_op(t.str)) && o.block
    }
  raise IogiiError.new "unmatched >", tokens[ends[-1]] if !can_brackets_match(all_begins, ends)
  prev_end = 0
  ends.each.with_index{|e,ei|
    # add the openers located between the previous ">" and this one
    begins += all_begins.filter{|bi| bi >= prev_end && bi < e }
    begin_token = nil
    loop {
      begin_token = tokens[begins.last]
      # trial parse of the block body starting right after the candidate opener;
      # uses a throwaway node list and dummy outer/var stacks, with implicit_end off
      block_inds, implicit_args, _, _, could_end_earlier = *parse(tokens, begins.last + 1, [], false, DUMMY_STACK, DUMMY_STACK, registers, false)
      begins.pop
      # parse never saw its stack size reach 0, i.e. there was no point at which
      # an implicit end would have triggered, so this opener takes the ">"
      break if !could_end_earlier
      # this block terminates normally
      break if block_inds.size == 1 && implicit_args == 1
      # without this opener the remaining openers cannot cover the remaining
      # ">"s (this one included), so the match is forced
      # todo definitely not efficient to call this over and over
      break if !can_brackets_match(begins + all_begins.filter{|bi| bi >= e }, ends[ei..-1])
    }
    # the last opener tried is the one that owns this ">"
    begin_token.explicit_end = true
    prev_end = e
  }
end

# Parses one body (the main program or the body of a block op), appending the
# nodes it creates to `nodes`, and recursing for nested block ops.
#
# tokens       - full token list
# token_ind    - index of the first token of this body
# nodes        - shared flat node array; node references are indices into it
# is_main      - true for the top-level program (only forwarded to Stack#finalize)
# outer        - stack of the enclosing body, used by "!"; nil in main
# var_stack    - stack shared by "[" / "]"
# registers    - hash of register names; values are overwritten with node
#                references as "~" / "@" assignments are parsed
# implicit_end - when true, the body ends as soon as its stack size is 0 after
#                a token
#
# Returns [result_inds, implicit_args, arg_ind, token_ind, could_end_earlier]:
# the first two come from Stack#finalize, arg_ind is the index of this body's
# ArgNode, token_ind is the position after the last token consumed, and
# could_end_earlier tells whether the stack size was 0 after some token while
# implicit_end was off.
# Raises IogiiError for the syntax errors reported below.
def parse(tokens, token_ind, nodes, is_main, outer, var_stack, registers, implicit_end)
  could_end_earlier = false
  stack=Stack.new
  # every body starts with its own argument node
  nodes << ArgNode.new(arg_ind = nodes.size)

  loop {
    t=tokens[token_ind]
    # advance before the end-of-input test: after running out of tokens
    # token_ind is tokens.size + 1, after stopping on ">" it is <= tokens.size
    token_ind+=1
    break if token_ind > tokens.size
    # register names win over every other meaning of a token (including ">");
    # register values are `true` or node references, all truthy in Ruby (0 too)
    if registers[t.str]
      stack.push nodes.size
      nodes << VarNode.new(t.str)
      next
    end
    break if t.str == ">"
    if t.str == "!"
      # take a value from the enclosing body's stack
      raise IogiiError.new "no outer stack in main", t if !outer
      stack.push outer.pop
    elsif t.str == "$"
      # this body's own argument
      stack.push arg_ind
    elsif t.str == "["
      # remember the current top of stack without removing it
      var_stack << stack.peek
    elsif t.str == "]"
      # bring back the most recently remembered value
      raise IogiiError.new "var stack is empty, cannot pop", t if var_stack.empty?
      stack.push var_stack.pop
    elsif t.str == "~"
      # bind the name given by the next token to the current top of stack,
      # then skip the name token
      registers[tokens[token_ind].str] = stack.peek
      token_ind += 1
    elsif t.str == "@"
      # "@" reads the register if another "@" occurs later in the token list,
      # otherwise it stores the current top of stack
      if tokens[token_ind..-1].any?{|t|t.str == "@"} # get
        stack.push nodes.size
        nodes << VarNode.new(RegisterName)
      else # set
        registers[RegisterName] = stack.peek
      end
    else case t.type
    when :int, :char, :str
      # parse_data consumes the whole literal and returns the updated token_ind
      data_impl, token_ind = *parse_data(tokens, orig_token_ind=token_ind)
      stack.push nodes.size
      # source text of all tokens that formed the literal
      str = tokens[orig_token_ind-1...token_ind].map(&:str).join
      nodes << DataNode.new(data_impl, str)
    when :id, :sym
      op = lookup_op(t.str)
      raise IogiiError.new "unknown op", t if !op
      # pops come off in reverse, so flip them back into push order
      args = op.parse_nargs.times.map{ stack.pop }.reverse
      if op.block
        # nested body; it ends implicitly unless add_implicit_ends marked this
        # opener as owning an explicit ">"
        block_inds, _, block_arg_ind, token_ind, _ =*parse(tokens, token_ind, nodes, false, stack, var_stack, registers, !t.explicit_end)
        raise IogiiError.new "excessive values in block and implicit ops not yet implemented", t if block_inds.size > 1
        raise if block_inds.size < 1
        block_ind = block_inds[0]
        nodes << BlockAST.new(t, args, block_ind, block_arg_ind)
      else
        nodes << AST.new(t, args)
      end
      # the op's result is the node just appended
      stack.push nodes.size-1
      # dup/mdup additionally leave their first argument on the stack
      stack.push args[0] if op.name == "mdup" || op.name == "dup"
      # "del" leaves nothing behind: drop the result that was just pushed
      if "del" == t.str
        stack.pop
      end
    when :commas
      raise IogiiError.new "commas must follow data or op that can vectorize", t
    else raise "unknown token type %p" % t.type.to_s
    end end
    if stack.size == 0
      break if implicit_end
      # an implicit end would have been possible here
      could_end_earlier = true
    end
  }
  [*stack.finalize(is_main, arg_ind), arg_ind, token_ind, could_end_earlier]
end

# Parses a data literal that begins at tokens[token_ind-1].
#
# The literal is the longest run of :int/:char/:str tokens and :commas tokens
# in which no two data tokens are adjacent. The nesting depth is the length of
# the longest comma token; a literal mixing element types is treated as :str,
# and :str adds one level of depth.
#
# Returns [op built by new_op("data", ...) yielding the parsed value,
#          token index just past the literal].
def parse_data(tokens, token_ind)
  datum_types = [:int, :char, :str]
  after_datum = false
  run = tokens[token_ind-1..-1].take_while do |tok|
    is_datum = datum_types.include?(tok.type)
    # a datum directly after another datum starts a new literal
    keep = is_datum ? !after_datum : tok.type == :commas
    after_datum = is_datum
    keep
  end
  token_ind += run.size - 1

  separators, data = run.partition { |tok| tok.type == :commas }
  depth = separators.map { |tok| tok.str.size }.max || 0

  kinds = data.map(&:type).uniq
  token_type = kinds.size > 1 ? :str : kinds.first

  # parsed with the comma depth, before the extra :str level is added
  value = parse_list(run, depth, token_type)

  type, rank =
    case token_type
    when :str then [CharType, depth + 1]
    when :char then [CharType, depth]
    when :int then [IntType, depth]
    else raise
    end

  op = new_op("data", type_and_rank_to_str(type, rank)) { value }
  [op, token_ind]
end

# Turns the tokens of a data literal into a value nested `depth` levels deep.
#
# At depth 0 the first token is parsed as a single datum (IogiiError if there
# is none). Otherwise the tokens are split on the separator made of `depth`
# commas, each part is parsed one level shallower, and the results are wrapped
# with to_lazy_list.
def parse_list(tokens,depth,to_type)
  unless depth == 0
    separator = "," * depth
    groups = split_tokens(tokens, separator)
    elements = groups.map { |group| parse_list(group, depth - 1, to_type) }
    return to_lazy_list(elements)
  end

  raise IogiiError.new('found ", ," invalid data format') if tokens.empty?
  parse_datum(tokens[0], to_type)
end

# Converts one data token into its value.
#
# :int  - the integer, or `str` applied to its `const` when to_type is :str
# :str  - the text after the opening quote (and before the closing quote, if
#         the token ends with one), run through parse_str and str_to_lazy_list
# :char - the ordinal of the parsed character, or [ord.const, Null] when
#         to_type is :str; raises IogiiError "empty char" if nothing follows
#         the leading marker character
# Any other token type yields nil.
def parse_datum(token,to_type)
  case token.type
  when :int
    number = token.str.to_i
    to_type == :str ? str(number.const) : number
  when :str
    text = token.str
    # a lone quote character is an opening quote, not a closing one
    closed = text[-1] == '"' && text.size != 1
    last = closed ? -2 : -1
    str_to_lazy_list(parse_str(text[1..last]))
  when :char
    text = token.str
    raise IogiiError.new("empty char", token) if text.size < 2
    code = parse_char(text[1..-1]).ord
    to_type == :str ? [code.const, Null] : code
  end
end
