4.2 Functions for reward matrix and querying initiating mutations

Function to create the reward matrix. create_reward_matrix takes a data frame of observed clones (Known_mat, one row per mutation, with a Genes column followed by one column per clone) and a named vector of clone weights keyed by genotype string. It enumerates every possible mutational state, keeps only transitions that stay in place or gain exactly one mutation, rewards a transition with the weight of its destination state when that state corresponds to an observed clone (zero otherwise, and always zero for "stay"), and runs a single iteration of Q-learning with the ReinforcementLearning package to attach a Q value to every state/action pair. For three mutations, the state 1_1_0 encodes a clone carrying the first two mutations but not the third. A usage sketch follows the function definition.

create_reward_matrix<-function(Known_mat,weights){
  
  num_type <- 2
  num_mutations <- nrow(Known_mat); 
  mutant_names  <- Known_mat$Genes
  num_clones    <- ncol(Known_mat)
  num_states    <- num_type^num_mutations
  
  possible_mut_list<- unlist(apply(as.matrix(Known_mat),1,function(x){list(0:max(unique(as.numeric(x[-1])))) }),recursive = FALSE)
 
  states<-data.frame(expand.grid(possible_mut_list))
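  # Pair every state with every other state; only transitions that stay in the
  # same state or gain exactly one mutation survive the filtering below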
  state_interactions<-data.frame(expand.grid(apply(states[,1:num_mutations],1,function(x){paste(x,collapse="_",sep="_")}),
                                             apply(states[,1:num_mutations],1,function(x){paste(x,collapse="_",sep="_")})))
  
  state_interactions$possible<-ifelse(apply(state_interactions,1,function(x){
    A<-as.numeric(do.call(cbind,strsplit(as.character(x[1]),split="_")))
    B<-as.numeric(do.call(cbind,strsplit(as.character(x[2]),split="_")))
    sum(abs(A-B))<=1
  }),0,NA)
  
  state_interactions$action<-apply(state_interactions,1,function(x){
    A<-as.numeric(do.call(cbind,strsplit(as.character(x[1]),split="_")))
    B<-as.numeric(do.call(cbind,strsplit(as.character(x[2]),split="_")))
    if(!is.na(x["possible"])){
      if(sum(abs(B-A))==0){
        return("stay")
      } else{
        return(mutant_names[which((B-A)==1)])
      }
    }
  })
  
  dat<-setNames(state_interactions%>%filter(action%in%c(mutant_names,"stay")),
                c("State","NextState","Reward","Action"))[,c(1,4,2,3)]
  
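  # Reaching a state listed in the clone weights earns that state's weight;
  # all other moves, including "stay", earn zero reward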
  dat$Reward <- as.numeric(apply(dat,1,function(x){
    ifelse(x$NextState%in%names(weights),weights[x$NextState],x$Reward)
  }))
  dat$Reward <- as.numeric(apply(dat,1,function(x){
    ifelse(x$Action%in%"stay",0,x$Reward)
  }))
  dat$State <- as.character(dat$State)
  dat$NextState <- as.character(dat$NextState)
  dat$Action <- as.character(dat$Action)
  
  control <- list(alpha = 0.8, gamma = 0.9)
  model <- ReinforcementLearning(data = dat, s = "State", a = "Action", r = "Reward",  s_new = "NextState",  iter =  1,control=control)
  x<- model$Q
  rownames(x) <- substring(rownames(x),1)
  Q_mat <- setNames(melt(x),c("State","Action","Q"))
  set<-inner_join(dat,Q_mat,by=c("State","Action"))
  set$Valid <- TRUE
  return(set)
  }
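
A minimal usage sketch. The layout of Known_mat (a Genes column followed by one column per observed clone) and a weights vector named by genotype string follow the conventions assumed in the function body; the gene names, genotypes, and weight values below are invented purely for illustration. dplyr, reshape2 (for melt) and ReinforcementLearning must be loaded for the function to run.

library(dplyr)
library(reshape2)
library(ReinforcementLearning)

# Toy input: three mutations observed across three clones (illustrative only)
Known_mat <- data.frame(Genes  = c("DNMT3A", "NPM1", "FLT3"),
                        Clone1 = c(1, 0, 0),
                        Clone2 = c(1, 1, 0),
                        Clone3 = c(1, 1, 1))

# Clone weights keyed by genotype string, e.g. observed clone frequencies
weights <- c("1_0_0" = 0.50,   # DNMT3A only
             "1_1_0" = 0.35,   # DNMT3A + NPM1
             "1_1_1" = 0.15)   # DNMT3A + NPM1 + FLT3

reward_set <- create_reward_matrix(Known_mat, weights)
head(reward_set)   # columns: State, Action, NextState, Reward, Q, Valid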

Function for retraining with reinforcement learning. create_reward_matrix_retrain builds the same transition and reward table as create_reward_matrix, then refines the initial single-iteration model with 1000 further iterations of epsilon-greedy Q-learning (alpha = 0.8, gamma = 0.9, epsilon = 0.4). A usage sketch comparing the two follows the function definition.

create_reward_matrix_retrain<-function(Known_mat,weights){
  
  num_type <- 2
  num_mutations <- nrow(Known_mat); 
  mutant_names  <- Known_mat$Genes
  num_clones    <- ncol(Known_mat)
  num_states    <- num_type^num_mutations
  
  possible_mut_list<- unlist(apply(as.matrix(Known_mat),1,function(x){list(0:max(unique(as.numeric(x[-1])))) }),recursive = FALSE)
 
  states<-data.frame(expand.grid(possible_mut_list))
  state_interactions<-data.frame(expand.grid(apply(states[,1:num_mutations],1,function(x){paste(x,collapse="_",sep="_")}),
                                             apply(states[,1:num_mutations],1,function(x){paste(x,collapse="_",sep="_")})))
  
  state_interactions$possible<-ifelse(apply(state_interactions,1,function(x){
    A<-as.numeric(do.call(cbind,strsplit(as.character(x[1]),split="_")))
    B<-as.numeric(do.call(cbind,strsplit(as.character(x[2]),split="_")))
    sum(abs(A-B))<=1
  }),0,NA)
  
  state_interactions$action<-apply(state_interactions,1,function(x){
    A<-as.numeric(do.call(cbind,strsplit(as.character(x[1]),split="_")))
    B<-as.numeric(do.call(cbind,strsplit(as.character(x[2]),split="_")))
    if(!is.na(x["possible"])){
      if(sum(abs(B-A))==0){
        return("stay")
      } else{
        return(mutant_names[which((B-A)==1)])
      }
    }
  })
  
  dat<-setNames(state_interactions%>%filter(action%in%c(mutant_names,"stay")),
                c("State","NextState","Reward","Action"))[,c(1,4,2,3)]
  
  dat$Reward <- as.numeric(apply(dat,1,function(x){
    ifelse(x$NextState%in%names(weights),weights[x$NextState],x$Reward)
  }))
  dat$Reward <- as.numeric(apply(dat,1,function(x){
    ifelse(x$Action%in%"stay",0,x$Reward)
  }))
  dat$State <- as.character(dat$State)
  dat$NextState <- as.character(dat$NextState)
  dat$Action <- as.character(dat$Action)
  
  control <- list(alpha = 0.8, gamma = 0.9)
  model1 <- ReinforcementLearning(data = dat, s = "State", a = "Action", r = "Reward",  s_new = "NextState",  iter =  1,control=control)
  model <- ReinforcementLearning(data = dat, s = "State", a = "Action", r = "Reward",  s_new = "NextState",  iter =  1000,control=list(alpha = 0.8, gamma = 0.9,epsilon=0.4),model=model1)

  
  x<- model$Q
  rownames(x) <- substring(rownames(x),1)
  Q_mat <- setNames(melt(x),c("State","Action","Q"))
  set<-inner_join(dat,Q_mat,by=c("State","Action"))
  set$Valid <- TRUE
  return(set)
  }
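
A minimal sketch contrasting the single-pass and retrained Q values on the same toy Known_mat and weights used above (both are illustrative). The retrained Q values typically differ because the second call performs 1000 additional epsilon-greedy updates.

set_single  <- create_reward_matrix(Known_mat, weights)
set_retrain <- create_reward_matrix_retrain(Known_mat, weights)

# Compare the Q value assigned to each state/action pair by the two schemes
comparison <- inner_join(set_single, set_retrain,
                         by = c("State", "Action", "NextState"),
                         suffix = c("_single", "_retrain"))
head(comparison[, c("State", "Action", "Q_single", "Q_retrain")])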

Function to query initiating mutations. query_initiating_mutations takes the reward/Q table for one sample (the output of the functions above), identifies every mutation that can be acquired first from the root state in which no mutations are present, and for each one repeatedly walks greedy, Q-maximizing trajectories through the state graph, invalidating explored edges until no valid path out of that initiating state remains. It returns, for each initiating mutation, the distinct trajectories discovered together with their stepwise and cumulative rewards. A usage sketch follows the function definition.

query_initiating_mutations<-function(graph_results_test){
  start_index<-paste(rep(0,length(strsplit(graph_results_test$State[1],split="_")[[1]])),sep="_",collapse="_")
  possible_starting_actions<-graph_results_test%>%filter(State==start_index&Action!="stay")%>%pull(Action)
  final_results<-list()
  initating_action_count<-0
  for(initating_action in possible_starting_actions){
   # print(initating_action)
    set <- graph_results_test
    initating_action_count<-initating_action_count+1
    storage_results<- list()
    branches<-0
    state_to_kill <- set%>%filter(State==start_index&Action==initating_action)%>%pull(NextState)
    start_killed <- sum(set%>%filter(State==state_to_kill)%>%pull(Valid))
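    # Keep sampling trajectories from this initiating action until every valid
    # transition out of its first destination state has been exhausted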
    while(start_killed>0){
      #print(branches)
      # print(start_killed)
      branches <- branches +1
      number_of_mutations<-0
      state_log<- list()
      optimal_reward<-list()
      action_log<-list()
      current_state<- start_index
      indicator<-TRUE
      nextState<-0
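      # Walk one trajectory: take the initiating action first, then repeatedly
      # take the highest-Q valid action whose destination still has valid onward moves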
      while(current_state!=nextState)  {
        # print(number_of_mutations)
        number_of_mutations <- number_of_mutations+1
        if(number_of_mutations==1){
          state_log[[number_of_mutations]] <- start_index
        }
        current_state  <- state_log[[number_of_mutations]]
        nextState_indicator<- FALSE
        
        while(nextState_indicator==FALSE){
          
          if(number_of_mutations==1){
            max_potential_action_index<-  set%>%
              filter(State==current_state&Action==initating_action)
          } else {
            max_potential_action_index <- set%>%
              filter(State==current_state&Valid==TRUE)%>%
              filter(Q==max(Q))%>%slice_sample(n=1)
          }
          if(nrow(max_potential_action_index)==0){
            break
          }
          max_potential_action <- max_potential_action_index%>%pull(NextState)
          next_valid_action <- any(set%>%filter(State==max_potential_action&Action!="stay")%>%pull(Valid))  
          if(next_valid_action==TRUE){
            nextState <-max_potential_action
            current_action <-  max_potential_action_index%>%pull(Action)
            nextState_indicator <- TRUE
            break
          } else{
            set[set$State%in%max_potential_action_index["State"]&
                  set$Action%in%max_potential_action_index["Action"],"Valid"] <- FALSE  
          }
        }
        if(nrow(set%>%filter(State==current_state&Action==current_action))==0){
          optimal_reward[[number_of_mutations]] <-NA
        } else {
          optimal_reward[[number_of_mutations]] <- set%>%
            filter(State==current_state&Action==current_action)%>%
            pull(Reward) 
        }
        state_log[[number_of_mutations+1]]<- nextState
        action_log[[number_of_mutations]] <- current_action 
        if(current_action==nextState){
          indicator <- FALSE
          state_log[[number_of_mutations+1]]<-NULL
          break
        }
      }
      optimal_reward[[number_of_mutations+1]] <- NA
      action_log[[number_of_mutations+1]] <- NA
      storage_results[[branches]] <-data.frame("states"=do.call(rbind,state_log),#[1:(length(state_log)-1)]),
                                               "actions"=do.call(rbind,action_log),
                                               "reward"=do.call(rbind,optimal_reward),
                                               "nextState"=do.call(rbind,c(state_log[2:length(state_log)],NA)) )
      storage_results[[branches]] <- storage_results[[branches]]%>%
        filter(states!=nextState)
      storage_results[[branches]]$cumulative_reward <- cumsum(storage_results[[branches]]$reward)
      
      #storage_results[[branches]] <-storage_results[[branches]][1:which.max(storage_results[[branches]]$cumulative_reward), ]
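      # Invalidate the last edge taken so the next pass explores a different branch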
      set[set$State%in%current_state&set$Action%in%current_action,"Valid"] <- FALSE
      start_killed <- sum(set%>%filter(State==state_to_kill)%>%pull(Valid))
    }
    final_results[[initating_action_count]]<-storage_results[!duplicated(storage_results)]
  }
  names(final_results)<-possible_starting_actions
  return(final_results)
}
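
A minimal usage sketch. It assumes graph_results is a named list holding the output of create_reward_matrix_retrain for each sample, with sample identifiers as names; building final_results this way matches how the summarization function below indexes it.

final_results <- setNames(lapply(names(graph_results), function(sample){
  query_initiating_mutations(graph_results[[sample]])
}), names(graph_results))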

Trajectory summarization. For a given sample, the function below takes the trajectories returned by query_initiating_mutations, drops trajectories that accumulate no reward, truncates each remaining trajectory at its point of maximum cumulative reward, prunes terminal states that do not correspond to observed clones, and returns a single table of edges annotated with the initiating mutation and with whether both endpoints were observed in the sample. With optimal_mutants_only = TRUE, only the initiating mutation whose trajectories accumulate the most reward is kept. A usage sketch follows the function definition.

trajectory_summariztion <- function(sample,optimal_mutants_only=FALSE){
   #Extract out sample of interest
        all_results <-final_results[[sample]]

        #apply over each potential initating mutation and identify the stepwise trajectory that accumulates the most reward        
        all_results_filtered<-setNames(lapply(names(all_results),function(initiating_mutation){
                                #print(initiating_mutation)
                                storage_results<-all_results[[initiating_mutation]]
                                storage_results[lapply(storage_results,function(x){sum(x$reward,na.rm = TRUE)})==0]<-NULL
                                  if(length(storage_results)==0){
                                    return(NULL)
                                  }
                                storage_results<-lapply(storage_results,function(x){x[1:which.max(x$cumulative_reward),]})
                               
                                if(length(storage_results)==0){
                                  print("error")
                                 return(NULL)
                                  break
                                } else {
                                 
                                # Extract columns of interest
                                final<-do.call(rbind,storage_results)[,c("states","nextState","reward","actions")]
                                
                                # Remove decisions that do not result in a state change, or terminal nodes that do not exist
                                nodes_to_remove <- setdiff(
                                                      setdiff(final$nextState,
                                                                   final$states),
                                                          final_sample_summary[[sample]]$Clones$Clone)
                               final <- final%>%
                                              filter(!nextState%in%nodes_to_remove)%>%
                                              distinct()%>%
                                              mutate("initiating_mutation"=initiating_mutation)%>%
                                              unite(col="edge",states,nextState,sep="->",remove=FALSE)%>%
                                              relocate(edge,.after = last_col())
                               return(final) 
                               }
                              }),names(all_results))
      
      if(length(all_results_filtered)==0){
        return(NULL)
      } else if(all(sapply(all_results_filtered,is.null))){
        return(NULL)
      } else {

      
      optimal<-names(which.max(do.call(c,lapply(all_results_filtered,function(x){
                                                  sum(x$reward)
                                        }))))
      all_mutants <-unique(names(all_results_filtered))
        
      mutation_output <-if(optimal_mutants_only){
            (optimal)
          } else{
            (all_mutants)
          }

      final<-do.call(rbind,all_results_filtered)%>%
                                  filter(initiating_mutation%in%mutation_output)%>%
                                  mutate(observed=ifelse(states%in%final_sample_summary[[sample]]$Clones$Clone&
                                                         nextState%in%final_sample_summary[[sample]]$Clones$Clone,
                                                                       "Yes","No"))
  return(final)
      }
}
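
A minimal usage sketch; it assumes final_results and final_sample_summary are named lists keyed by sample, as the function body implies, and "Sample1" is a placeholder sample name. tidyr supplies unite() and dplyr supplies the remaining verbs.

library(tidyr)

trajectories <- trajectory_summariztion("Sample1", optimal_mutants_only = TRUE)
head(trajectories)   # states, nextState, reward, actions, initiating_mutation, edge, observed

Function to plot the optimal trajectory graph for a sample. It summarizes the sample with optimal_mutants_only = TRUE, builds a directed igraph object from the resulting edges, sizes vertices by observed clone frequency (with a small default size for unobserved intermediate states), and highlights observed clones and observed edges in color. A usage sketch follows the function definition.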

plot_optimal_graph_for_trajectory_new<-function(sample,optimal_mutants_only){
  
      final <- trajectory_summariztion(sample,optimal_mutants_only=TRUE)
  
      if(final=="error"|is.null(final)){
        return("error")
      } else{
      
      graph<-graph_from_data_frame(final,directed=T)
      weights<-final_sample_summary[[sample]]$Clones$Count/sum(final_sample_summary[[sample]]$Clones$Count)
      names(weights) <-final_sample_summary[[sample]]$Clones$Clone
      weight_subset<-weights[names(weights)%in%names(V(graph))]
      nodes_to_add_names<-setdiff(names(V(graph)),names(weights))
      nodes_to_add <- rep(0.1,length(nodes_to_add_names))
      names(nodes_to_add)<-nodes_to_add_names
      weight_final <- c(weight_subset,nodes_to_add)[names(V(graph))]

      clone_colors <-ifelse(names(V(graph))%in%final_sample_summary[[sample]]$Clones$Clone,brewer.pal(5,"Reds")[5],"grey80")
      observe_edges <- final%>%filter(observed=="Yes")%>%pull(edge)
      plot(graph,layout=layout_as_tree,
                  vertex.color=ifelse(names(V(graph))%in%final_sample_summary[[sample]]$Clones$Clone,brewer.pal(5,"Reds")[5],"grey80"),
                  vertex.frame.color=ifelse(names(V(graph))%in%final_sample_summary[[sample]]$Clones$Clone ,brewer.pal(5,"Reds")[5],"grey80"),
                  vertex.size=log2(1+weight_final*500),
                  vertex.label=NA,
                  edge.color=ifelse(edge_attr(graph)$edge %in%observe_edges,brewer.pal(5,"Blues")[5],"grey80"))

       }
}
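
A minimal usage sketch for the plotting helper; igraph and RColorBrewer provide graph_from_data_frame() and brewer.pal(), and the sample name is a placeholder. Note that the helper always calls the summarization step with optimal_mutants_only = TRUE, so only the trajectories of the optimal initiating mutation are drawn.

library(igraph)
library(RColorBrewer)

plot_optimal_graph_for_trajectory_new("Sample1", optimal_mutants_only = TRUE)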