//Compile with g++ budgetbluescreen.cpp -o budgetbluescreen
//Requires EasyBMP

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <math.h>

#include "EasyBMP.cpp"
#define PI (3.141592653589793)
using namespace std;

// Frame dimensions in pixels, taken from ref.bmp in main().
int width, height;
// Reference: background-only frame (ref.bmp). Backdrop: replacement image (backdrop.bmp).
BMP Reference, Backdrop;
// Forward declarations for the processing pipeline defined below.
float get_color_distance(unsigned char r1,unsigned char g1,unsigned char b1,unsigned char r2,unsigned char g2,unsigned char b2);
void doframe(int num);
// NOTE(review): get_hue is declared but no definition or caller is visible in this file.
int get_hue(int ri, int gi, int bi);
int get_bright(RGBApixel *pix);
void regionalise();
unsigned int getregionsize(unsigned int regionnum);
void togglemaskregion(unsigned int regionnum);
void eraseregion(unsigned int regionnum);
void regionreplace(unsigned int replace, unsigned int with);
unsigned int classpixel(float delta, unsigned char refbright, unsigned char pixbright);
void erode();
void cellularfilt(unsigned char *mask);
void makeinitialmask(BMP ThisFrame, unsigned char *mask, unsigned char *hintmap);
void remove_small_regions(unsigned char *mask);
void boxblur(unsigned char *map, unsigned int iterations);

// Colour-similarity cutoff used by classpixel(): a delta below this marks the
// pixel as foreground. Stored as 1-1/N, where N is the user's "threshhold" value.
float threshhold;
// A frame pixel brighter than its reference pixel by more than this is foreground.
int brightthresh;
// Reference pixels darker than this are always background (0 disables).
unsigned char backclampblackthresh;
// Frame pixels darker than this are always foreground (0 disables).
unsigned char foreclampblackthresh;
// Minimum connected-region size in pixels; smaller regions get flipped. 0 disables.
unsigned int regionthresh;
// Per-pixel region labels used by the region helpers; allocated per frame in doframe().
unsigned int *regions;
// Per-pixel classification, one byte per pixel, allocated per frame in doframe().
// After processing, nonzero means background (replaced by the Backdrop pixel).
unsigned char *mask;
// Number of boxblur() passes applied to the mask before each refine pass.
unsigned int maskblurrad;
// Number of hint-guided refine passes per frame.
unsigned int refineitr;
// Number of erode() passes applied to the final mask.
unsigned int erodeitr;
// Scale applied to (hint - 127) when the blurred previous mask is used as a hint.
float hintstrength;
// Returns true when the name part of a "name=value" argument (the first
// namelen characters of arg) is exactly `name`. Compares in place, so the
// name never has to be copied out of argv.
static bool param_is(const char *arg, size_t namelen, const char *name){
  return strlen(name)==namelen && strncmp(arg, name, namelen)==0;
}

// Entry point. With no arguments, prints usage and exits with status 1.
// Otherwise parses name=value parameters, loads ref.bmp and backdrop.bmp from
// the working directory, and processes frames start..end (inclusive).
int main(int argc, char *argv[]){
  brightthresh=64;
  maskblurrad=10;
  refineitr=3;
  hintstrength=0.005;
  if(argc<=1){
    printf("Two files are expected in the CWD: ref.bmp holding a reference frame of the background, and backdrop.bmp holding the backdrop to replace it.\n\n");
    printf("Paramaters are in the form <name>=<val>. If not specified, defaults will be used:\n");
    printf("threshhold=int Vector dot threshhold. Lower values favor more translucency. This is a non-linear setting: A value of infinity would correspond to unconditional opacity. Default 15.\n");
    printf("brightdthresh=(char) Brightness delta threshhold from 0 to 255. Just put 255 to disable. Used to improve handling of reflective highlights. If your actor wears glasses, you'll want this. Default %d.\n", brightthresh);
    printf("minregion=(int) Smallest feature threshhold in pixels. This serves as a very effective mask noise removal method, but you'll have to disable it if objects are being thrown or your props have holes in. Default 800.\n");
    printf("backblackclampthresh=(char) Reference black clamp to background threshold. 0 to disable. Use if your backdrop has black elements. Your actor won't be able to walk in front of them, and it does require the background be a bit more diffusely lit or shadows will be mistaken for foreground objects. Default 0 (disabled).\n");
    printf("foreclampblackthresh=(char) Black clamp to foreground threshhold. 0 to disable. Turning it up improves the handling of very dark clothing, at the expense of making the program much fussier about having a semi-decent lighting setup and mostly wrinkle-free backdrop. Default 0 (disabled).\n");
    printf("maskblur=(int) Mask hint blur radius. Turning this up will make the mask edges 'smoother.' Removes things that poke out of a larger region. Default %u\n", maskblurrad);
    printf("hintstrength=(float) Previous iteration hint strength. Default %f\n", hintstrength);
    printf("refineitr=(int) Refine operations. Default %u\n", refineitr);
    printf("erode=(int) Erode passes before final mask. Makes the mask just slightly smaller, to eliminate the 'halo' effect. Default 0.\n");
    printf("start=(int) and end=int() - first and last frame numbers, inclusive, to process. Expecting n.bmp in In/ and writing to same in Out/.\n");
    printf("\n\nYou should find better results if you run the video through a temporal smooth first. A little less noise makes a big difference.\n");
    printf("This is a PROOF OF CONCEPT program. It is not intended for production use, and as such it isn't performance optimised and doesn't come with a nice clean interface. If you want something a bit more refined, tell me so I'll have some motivation to spend time on it. Or do it yourself. The algorithm isn't complicated.\n");
    printf("http://birds-are-nice.me/video/bluescreen.shtml\n");
    return(1);
  }
  threshhold=1.00-(1.00/15);
  regionthresh=800;
  backclampblackthresh=0;
  foreclampblackthresh=0;
  erodeitr=0;
  unsigned int start=0;
  unsigned int end=0;
  for(int argnum=1; argnum<argc; argnum++){
    const char *arg=argv[argnum];
    const char *eq=strchr(arg, '=');
    if(!eq){
      fprintf(stderr, "Ignoring argument without '=': %s\n", arg);
      continue;
    }
    const size_t namelen=(size_t)(eq-arg);
    const char *val=eq+1;
    if(param_is(arg, namelen, "threshhold"))
      threshhold=1-(1.00/(float)atoi(val));
    else if(param_is(arg, namelen, "brightdthresh"))
      brightthresh=atoi(val);
    else if(param_is(arg, namelen, "minregion"))
      regionthresh=atoi(val);
    // The usage text spells this "backblackclampthresh"; accept both spellings.
    else if(param_is(arg, namelen, "backclampblackthresh") ||
            param_is(arg, namelen, "backblackclampthresh"))
      backclampblackthresh=atoi(val);
    else if(param_is(arg, namelen, "foreclampblackthresh"))
      foreclampblackthresh=atoi(val);
    else if(param_is(arg, namelen, "start"))
      start=atoi(val);
    else if(param_is(arg, namelen, "end"))
      end=atoi(val);
    else if(param_is(arg, namelen, "maskblur"))
      maskblurrad=atoi(val);
    else if(param_is(arg, namelen, "refineitr"))
      refineitr=atoi(val);
    else if(param_is(arg, namelen, "erode"))
      erodeitr=atoi(val);
    else if(param_is(arg, namelen, "hintstrength"))
      hintstrength=atof(val);
    else
      fprintf(stderr, "Ignoring unknown parameter: %s\n", arg);
  }

  // Without both images there is nothing meaningful to do.
  if(!Reference.ReadFromFile("ref.bmp")){
    fprintf(stderr, "Could not read ref.bmp.\n");
    return(1);
  }
  if(!Backdrop.ReadFromFile("backdrop.bmp")){
    fprintf(stderr, "Could not read backdrop.bmp.\n");
    return(1);
  }
  width=Reference.TellWidth();
  height=Reference.TellHeight();
  printf("Dimensions %d,%d\n", width, height);
  printf("Threshholds %f, %d.\n", threshhold, brightthresh);
  for(unsigned int n=start; n<=end; n++)
    doframe(n);
  return(0);
}
/*
unsigned int abs(int val){
if(val>=0)
  return(val);
return(0-val);
}*/

unsigned int classpixel(float delta, unsigned char refbright, unsigned char pixbright){
  // Classify one pixel: 1 = background, 0 = foreground.
  // A reference pixel below the background black clamp is background outright.
  if(refbright<backclampblackthresh)
    return(1);
  // Otherwise any one of these makes the pixel foreground: it is below the
  // foreground black clamp, it brightened too much relative to the reference,
  // or its colour similarity dropped below the threshold.
  const int brightening=(int)pixbright-(int)refbright;
  const bool tooDark=pixbright<foreclampblackthresh;
  const bool tooBright=brightening>brightthresh;
  const bool tooDifferent=delta<threshhold;
  return (tooDark || tooBright || tooDifferent) ? 0 : 1;
}

// Classify every pixel of ThisFrame against the Reference background and store
// classpixel()'s verdict in mask (1 = background, 0 = foreground).
// If hintmap is non-NULL, (hintmap[i]-127)*hintstrength is added to the colour
// similarity first, so high hint values push a pixel toward background.
// hintmap may be the same buffer as mask: each pixel's hint is read before that
// same pixel is overwritten, and no other pixel's hint is touched.
// NOTE(review): ThisFrame is taken by value, so the image is copied on every
// call; passing by reference would need the forward declaration changed too.
void makeinitialmask(BMP ThisFrame, unsigned char *mask, unsigned char *hintmap){
  for( int y=0; y < height; y++)
    for( int x=0; x < width ; x++){
      // Look each pixel up once rather than re-indexing both images per channel.
      RGBApixel *ref=Reference(x, y);
      RGBApixel *pix=ThisFrame(x, y);
      const int idx=x+(y*width);
      float delta=get_color_distance(ref->Red, ref->Green, ref->Blue,
                                     pix->Red, pix->Green, pix->Blue);
      unsigned char refbright=(ref->Red+ref->Green+ref->Blue)/3;
      unsigned char pixbright=(pix->Red+pix->Green+pix->Blue)/3;
      if(hintmap)
        delta+=(hintmap[idx]-127)*hintstrength;
      mask[idx]=classpixel(delta, refbright, pixbright);
    }
}

void cellularfilt(unsigned char *mask){
  // Iterated 3x3 majority filter over a mask of 0/1 values (as produced by
  // classpixel()). In each pass every interior pixel's new value is computed
  // from the previous pass: the new value is parked in bit 1 while bit 0 still
  // holds the old one, and afterwards all interior pixels are shifted right.
  // Border pixels are never changed. Stops once a pass flips 4 or fewer
  // pixels, or after 32 passes.
  unsigned int flipped=0, passes=0;
  do{
    flipped=0;
    for( int y=1; y < height-1; y++)
      for( int x=1; x < width-1 ; x++){
        unsigned char *cell=&mask[x+(y*width)];
        // Count the set pixels in the 3x3 window (cell included), old values only.
        int votes=0;
        for(int dy=-1; dy<=1; dy++)
          for(int dx=-1; dx<=1; dx++)
            votes+=cell[dx+(dy*width)] & 1;
        if(votes>=5)
          *cell|=2;
        // 1 = was set and is now clear, 2 = was clear and is now set.
        if(*cell==1 || *cell==2)
          flipped++;
      }
    // Promote the new values (bit 1) into bit 0.
    for( int y=1; y < height-1; y++)
      for( int x=1; x < width-1 ; x++)
        mask[x+(y*width)]>>=1;
    passes++;
  }while(flipped>4 && passes<32);
}
// Frees the per-frame mask and region buffers and resets the globals, so that
// nothing keeps pointing at released memory.
static void free_frame_buffers(){
  free(regions);
  regions=NULL;
  free(mask);
  mask=NULL;
}

// Process one frame: read In/<num>.bmp, estimate and refine the background
// mask against the reference, replace background pixels with the backdrop, and
// write Out/<num>.bmp. Failures are reported on stderr and the frame is skipped.
void doframe(int num){
  printf("Processing frame %d.\n", num);
  BMP ThisFrame;
  char filename[128];
  mask=(unsigned char *)malloc(width*height);
  regions=(unsigned int*)malloc(width*height*sizeof(unsigned int));
  if(!mask || !regions){
    fprintf(stderr, "Out of memory allocating buffers for frame %d.\n", num);
    free_frame_buffers();
    return;
  }
  snprintf(filename, sizeof(filename), "In/%d.bmp", num);
  // A failed read would leave a 1x1 image behind; don't process that.
  if(!ThisFrame.ReadFromFile(filename)){
    fprintf(stderr, "Could not read %s; skipping frame.\n", filename);
    free_frame_buffers();
    return;
  }
  printf("Read %s.\n", filename);
  printf("Estimating mask.\n");
  makeinitialmask(ThisFrame, mask, NULL);
  printf("Mask estimated.\n");
  //This iterative filter (It's actually a form of cellular automata) makes the mask tidier by sharpening edges.
  cellularfilt(mask);

  if(regionthresh){
    printf("Regionalising.\n");
    remove_small_regions(mask);
  }

  for(unsigned int refinepass=0; refinepass<refineitr; refinepass++){
    // Expand the 0/1 mask to 0/255 so the blur yields a 0..255 hint centred on 127.
    for(int n=0; n<(width*height); n++)
      if(mask[n])
        mask[n]=255;
    boxblur(mask, maskblurrad);
    // The blurred mask is its own hint; each pixel's hint is read before it is overwritten.
    makeinitialmask(ThisFrame, mask, mask);
    cellularfilt(mask);
    if(regionthresh)
      remove_small_regions(mask);
  }

  for(unsigned int n=0; n<erodeitr; n++)
    erode();

  printf("Applying mask.\n");fflush(stdout);
  snprintf(filename, sizeof(filename), "In/%d.bmp", num);
  printf("Re-reading input %s.\n", filename);fflush(stdout);
  // NOTE(review): makeinitialmask() receives ThisFrame by value, so it should
  // still be unmodified here and this re-read looks redundant; confirm before removing.
  if(!ThisFrame.ReadFromFile(filename)){
    fprintf(stderr, "Could not re-read %s; skipping frame.\n", filename);
    free_frame_buffers();
    return;
  }
  printf("  Done\n");fflush(stdout);
  // Nonzero mask means background: take the backdrop pixel there.
  for( int y=0; y < height ; y++)
    for( int x=0; x < width; x++)
      if(mask[x+(y*width)])
        ThisFrame.SetPixel(x,y,Backdrop.GetPixel(x,y));

  snprintf(filename, sizeof(filename), "Out/%d.bmp", num);
  if(!ThisFrame.WriteToFile(filename))
    fprintf(stderr, "Could not write %s.\n", filename);
  printf("Frame done.\n");
  free_frame_buffers();
}

// Flip the mask value of every 4-connected region smaller than regionthresh
// pixels, then clear all region labels to 0. Prints how many regions were seen
// and how many were flipped.
// mask must be the global mask buffer (as every caller passes): regionalise()
// and the fallback helpers operate on the global directly.
void remove_small_regions(unsigned char *mask){
  regionalise();
  const unsigned int pixels=width*height;
  unsigned int regioncount=0, exregions=0;

  // Size the per-label table from the largest label actually in use.
  unsigned int maxlabel=0;
  for(unsigned int n=0;n<pixels;n++)
    if(regions[n]>maxlabel)
      maxlabel=regions[n];
  unsigned int *sizes=(unsigned int*)calloc((size_t)maxlabel+1, sizeof(unsigned int));

  if(!sizes){
    // Fallback: per-region full-image scans. Slow, but needs no extra memory.
    fprintf(stderr, "Low memory: using slow region removal.\n");
    for(unsigned int n=0;n<pixels;n++)
      if(regions[n]!=0){
        const unsigned int label=regions[n];
        if(getregionsize(label)<regionthresh){
          togglemaskregion(label);
          exregions++;
        }
        eraseregion(label);
        regioncount++;
      }
  }
  else{
    // One histogram pass replaces a full-image scan per region.
    for(unsigned int n=0;n<pixels;n++)
      sizes[regions[n]]++;
    // Label 0 is reserved, so start counting regions at 1.
    for(unsigned int label=1;label<=maxlabel;label++)
      if(sizes[label]){
        regioncount++;
        if(sizes[label]<regionthresh)
          exregions++;
      }
    for(unsigned int n=0;n<pixels;n++){
      if(regions[n]!=0 && sizes[regions[n]]<regionthresh)
        mask[n]=(!mask[n]);
      regions[n]=0;
    }
    free(sizes);
  }
  printf("Processed %u regions, erased %u.\n", regioncount, exregions);
}

unsigned int getregionsize(unsigned int regionnum){
  // Number of pixels whose label equals regionnum.
  unsigned int count=0;
  const unsigned int *label=regions;
  for(unsigned int remaining=width*height; remaining>0; remaining--, label++)
    count+=(*label==regionnum);
  return count;
}

void togglemaskregion(unsigned int regionnum){
  // Logically negate the global mask (0 -> 1, nonzero -> 0) at every pixel labelled regionnum.
  const unsigned int *label=regions;
  unsigned char *cell=mask;
  for(unsigned int remaining=width*height; remaining>0; remaining--, label++, cell++)
    if(*label==regionnum)
      *cell=(*cell==0);
}

void eraseregion(unsigned int regionnum){
  // Reset every pixel labelled regionnum to the reserved label 0.
  unsigned int *label=regions;
  for(unsigned int remaining=width*height; remaining>0; remaining--, label++)
    if(*label==regionnum)
      *label=0;
}

void regionreplace(unsigned int replace, unsigned int with){
  // Relabel across the whole image: every pixel labelled `replace` becomes `with`.
  unsigned int *label=regions;
  for(unsigned int remaining=width*height; remaining>0; remaining--, label++)
    if(*label==replace)
      *label=with;
}

void regionalise(){
  unsigned int nextregion=1; //Zero is specially reserved.
  regions[0]=nextregion++;

  for(int n=1;n<width;n++)
    if(mask[n]==mask[n-1])
      regions[n]=regions[n-1];
    else
      regions[n]=nextregion++;
  for(int n=1;n<height;n++)
    if(mask[n*width]==mask[(n-1)*width])
      regions[n*width]=regions[(n-1)*width];
    else
      regions[n*width]=nextregion++;
  for(int y=1;y<height;y++)
    for(int x=1;x<width;x++){
      if(mask[x+(y*width)]==mask[x+((y-1)*width)])
        regions[x+(y*width)]=regions[x+((y-1)*width)];
      else if(mask[x+(y*width)]==mask[x+(y*width)-1])
        regions[x+(y*width)]=regions[x+(y*width)-1];
      else regions[x+(y*width)]=nextregion++;
    }

  for(int y=1;y<height;y++)
    for(int x=1;x<width;x++){
      if((mask[x+(y*width)]==mask[x+(y*width)-1])&&(regions[x+(y*width)]!=regions[x+(y*width)-1])){
        if(regions[x+(y*width)]>regions[x+(y*width)-1])
          regionreplace(regions[x+(y*width)], regions[x+(y*width)-1]);
        else
          regionreplace(regions[x+(y*width)-1], regions[x+(y*width)]);
      }
      if((mask[x+(y*width)]==mask[x+((y-1)*width)])&&(regions[x+(y*width)]!=regions[x+((y-1)*width)])){
        if(regions[x+(y*width)]>regions[x+((y-1)*width)])
          regionreplace(regions[x+(y*width)], regions[x+((y-1)*width)]);
        else
          regionreplace(regions[x+((y-1)*width)], regions[x+(y*width)]);
      }
    }
}


// One pass of separable 3x3 dilation of the nonzero entries of the global
// mask: first 1x3 along rows, then 3x1 along columns. Since nonzero mask means
// background, this grows the background and shrinks the foreground by about a
// pixel. Edge pixels keep their own value but do not spread into their
// neighbours. If a scratch buffer cannot be allocated, the pass stops early.
void erode(){
  unsigned char *line=(unsigned char *)malloc(width);
  if(!line){
    fprintf(stderr, "Out of memory in erode().\n");
    return;
  }
  for(int y=0;y<height;y++){
    memset(line, 0, width);
    line[0]=mask[y*width];line[width-1]=mask[(y*width)+width-1];
    for(int x=1;x<(width-1);x++)
      if(mask[x+(y*width)]){
        line[x-1]=1;line[x]=1;line[x+1]=1;
      }
    for(int x=0;x<width;x++)
      mask[x+(y*width)]=line[x];
  }
  free(line);
  //Separable function. Yay.
  unsigned char *col=(unsigned char *)malloc(height);
  if(!col){
    fprintf(stderr, "Out of memory in erode().\n");
    return;
  }
  for(int x=0;x<width;x++){ //Rather untidy cachewise, but... it'll do.
    memset(col, 0, height);
    // Top and bottom pixels of column x; the bottom one is in row height-1.
    col[0]=mask[x];col[height-1]=mask[x+((height-1)*width)];
    for(int y=1;y<(height-1);y++)
      if(mask[x+(y*width)]){
        col[y]=1;col[y-1]=1;col[y+1]=1;
      }
    for(int y=0;y<height;y++)
      mask[x+(y*width)]=col[y];
  }
  free(col);
}


int get_bright(RGBApixel *pix){
  // Integer mean of the three colour channels.
  int sum=pix->Red;
  sum+=pix->Green;
  sum+=pix->Blue;
  return sum/3;
}

// Angle in radians between the colour vectors (r1,g1,b1) and (r2,g2,b2).
// A zero-length vector has no direction, so it is treated as matching
// (angle 0) instead of letting 0/0 produce NaN. The cosine is clamped to
// [-1,1] because rounding can push it just past 1, where acos() returns NaN.
static double color_vector_angle(int r1, int g1, int b1, int r2, int g2, int b2){
  const double dot=(double)(r1*r2+g1*g2+b1*b2);
  const double len=sqrt((double)(r1*r1+g1*g1+b1*b1))*sqrt((double)(r2*r2+g2*g2+b2*b2));
  if(len==0.0)
    return 0.0;
  double c=dot/len;
  if(c>1.0)
    c=1.0;
  else if(c<-1.0)
    c=-1.0;
  return acos(c);
}

// Colour similarity of two RGB pixels: 1 when both colour directions match,
// falling to 0 when they are maximally different. Despite the name, larger
// means more alike. The score averages the angle between the raw colour
// vectors with the angle between their complements (255-x). It is then
// normalised by acos(0) = pi/2, the largest angle possible between two
// non-negative vectors. The result is always finite; pure black or white
// pixels no longer yield NaN.
float get_color_distance(unsigned char r1,unsigned char g1,unsigned char b1,unsigned char r2,unsigned char g2,unsigned char b2){
  const double direct=color_vector_angle(r1, g1, b1, r2, g2, b2);
  const double inverse=color_vector_angle(255-r1, 255-g1, 255-b1,
                                          255-r2, 255-g2, 255-b2);
  const double half_pi=acos(0.0);
  return (float)(1.0-(((direct+inverse)/2.0)/half_pi));
}

// Blur map (width*height bytes) in place using `iterations` passes of a
// separable 3-tap box filter: a 1x3 integer average along each row, then a
// 3x1 average down each column. In every pass, the first and last entries of
// each row and column keep their values. If scratch buffers cannot be
// allocated, map is left unchanged and a message goes to stderr.
void boxblur(unsigned char *map, unsigned int iterations){
  unsigned char *line=(unsigned char *)malloc(width);
  unsigned char *col=(unsigned char *)malloc(height);
  if(!line || !col){
    fprintf(stderr, "Out of memory in boxblur().\n");
    free(line);
    free(col);
    return;
  }
  for(unsigned int n=0;n<iterations;n++){
    for(int y=0;y<height;y++){
      unsigned char *row=map+(y*width);
      line[0]=row[0];line[width-1]=row[width-1];
      for(int x=1;x<(width-1);x++)
        line[x]=(row[x-1]+row[x]+row[x+1])/3;
      for(int x=0;x<width;x++)
        row[x]=line[x];
    }
    //Separable function. Yay.
    for(int x=0;x<width;x++){ //Rather untidy cachewise, but... it'll do.
      // Top and bottom pixels of column x; the bottom one is in row height-1.
      col[0]=map[x];col[height-1]=map[x+((height-1)*width)];
      for(int y=1;y<(height-1);y++)
        col[y]=(map[((y-1)*width)+x]+map[(y*width)+x]+map[((y+1)*width)+x])/3;
      for(int y=0;y<height;y++)
        map[x+(y*width)]=col[y];
    }
  }
  free(line);
  free(col);
}


