Yichen Yang (yycdavid), MIT, Cambridge, MA

uwplse/tamago 12

Re-implementation of the TASO compiler using equality saturation

josepablocam/preprocessing 0

Study of preprocessing impact on code search performance

yycdavid/ConditionalJuMP.jl 0

Automatic transformation of implications and complementarity into mixed-integer models in Julia

yycdavid/courses 0

fast.ai Courses

yycdavid/egg 0

egraphs good

yycdavid/tamago 0

Re-implementation of the TASO compiler using equality saturation

yycdavid/tensorflow 0

Computation using data flow graphs for scalable machine learning

push event yycdavid/taso

Yichen Yang
commit sha 12d1262268fb0e6b0d9fee07aa281451462b52ef
New exp scripts

Yichen Yang
commit sha ecd4de36ddffaf512a74e717d5301a3f3d644b1c
Merge branch 'master' of github.com:yycdavid/taso

Yichen Yang
commit sha 26e102a5dbfd650d9cedcdc056efb84fab65e262
Tidy exp scripts

pushed 11 days ago

push event yycdavid/tamago

root
commit sha 1df5ae6c6444fed21e14f563a4dbc5ff8bf230dc
Remove debug prints

root
commit sha 0e492c2a0b701785a9e49a4cd67563764af4cd83
Merge branch 'master' of https://github.com/uwplse/tamago

Yichen Yang
commit sha e468da6732bc98a055a4beed36e4cba2967c1056
Tidy

pushed 11 days ago

push event uwplse/tamago

Yichen Yang
commit sha e468da6732bc98a055a4beed36e4cba2967c1056
Tidy

pushed 11 days ago

push event uwplse/tamago

Yichen Yang
commit sha e79495a5c67cf76ca70ff8319c3dd5c612b4eb00
Add batchnorm cost

pushed 15 days ago

push event yycdavid/tamago

Yichen Yang
commit sha e79495a5c67cf76ca70ff8319c3dd5c612b4eb00
Add batchnorm cost

pushed 15 days ago

push event yycdavid/taso

Yichen Yang
commit sha 6f1367a178ff1e9a8cc632f9cfac1e3bb18098cb
Add flags for alpha and iter

Yichen Yang
commit sha c3b13ca2a4746ca1a7b6b430341a370f07d9c280
Merge

pushed 15 days ago

push event uwplse/tamago

Yichen Yang
commit sha 391a937f6344094703baa3ca31005da0a1b60479
Add dropout cost

pushed 15 days ago

push event yycdavid/tamago

root
commit sha 33ff0328f5d07970fb7f9e961cdb8661e3fe0e90
batchnorm broken

root
commit sha 33f29e49c63d36f5630c9e5e8af7454a2f8095b7
Got batchnorm and mobilenet to work

root
commit sha bb0963b26d7345ace765ee2b7051c69a96f7a623
Finish mobilenet

root
commit sha 0b2e5c1f3a54f027a8f0c4d2b3e4597ac92a7880
Finish vgg.

root
commit sha aba6d3b47558f7483b516dd076c95de1aef5b04b
Add squeezenet.

Yichen Yang
commit sha 13b1f3507fd79afbb8e46421de4ad444eb61bc52
Add run saturation only flag, exp scripts for ablation

root
commit sha e39fbeceaae5b78b0bc22f9585165b2fd8025dce
Wrap up experiments.

Yichen Yang
commit sha e20dcb0e1c9c62290b13cf9560be38e5dd4037c3
Plot speedup trend with split axis

Yichen Yang
commit sha 53d77b661d0e36ec75a3d9a457c9497987c76b96
Plot optimizer time trend

Yichen Yang
commit sha 8d0111003b57759de8b051fad8db9f40376b0e8a
Plot number of enodes trend

Yichen Yang
commit sha b349f7b20656a0009cd76a95b6e6a1e7f7dc43f7
Limit time inside multi run one

Yichen Yang
commit sha f19b0e3942bb7551c492ae9fd0a5de3bb3cfd588
Exp scripts

Yichen Yang
commit sha f51088ec6694be9d21a531f7b8771277334093e7
Merge

Yichen Yang
commit sha 391a937f6344094703baa3ca31005da0a1b60479
Add dropout cost

pushed 15 days ago

push event uwplse/tamago

Yichen Yang
commit sha 13b1f3507fd79afbb8e46421de4ad444eb61bc52
Add run saturation only flag, exp scripts for ablation

Yichen Yang
commit sha e20dcb0e1c9c62290b13cf9560be38e5dd4037c3
Plot speedup trend with split axis

Yichen Yang
commit sha 53d77b661d0e36ec75a3d9a457c9497987c76b96
Plot optimizer time trend

Yichen Yang
commit sha 8d0111003b57759de8b051fad8db9f40376b0e8a
Plot number of enodes trend

Yichen Yang
commit sha b349f7b20656a0009cd76a95b6e6a1e7f7dc43f7
Limit time inside multi run one

Yichen Yang
commit sha f19b0e3942bb7551c492ae9fd0a5de3bb3cfd588
Exp scripts

Yichen Yang
commit sha f51088ec6694be9d21a531f7b8771277334093e7
Merge

pushed 16 days ago

push event uwplse/tamago

Yichen Yang
commit sha 260a3c5b269fa6a49ce0b38df1fc3908b74a7b3b
Add node limit in multi

Yichen Yang
commit sha 606f0258625ae7657fe944aaa0f719b6980c51e4
Inspect number of applied rules

Yichen Yang
commit sha ec31f239472ee0a711642f7cfc6a6b6956009b3d
Less printing

pushed 21 days ago

push event yycdavid/tamago

Yichen Yang
commit sha 260a3c5b269fa6a49ce0b38df1fc3908b74a7b3b
Add node limit in multi

Yichen Yang
commit sha 606f0258625ae7657fe944aaa0f719b6980c51e4
Inspect number of applied rules

Yichen Yang
commit sha ec31f239472ee0a711642f7cfc6a6b6956009b3d
Less printing

pushed 21 days ago

push event uwplse/tamago

Yichen Yang
commit sha 890671fa4a743ec255f307b3d35cbff58bd5e923
Support more ops for onnx to rust pipeline

pushed 22 days ago

push event yycdavid/tamago

Yichen Yang
commit sha b2d4a6d7ccb881666e37335833d4053ff4ae0836
Merge pull request #1 from yycdavid/master (Optimizer)

Remy Wang
commit sha 8028321904a7a3c13559f918543dd106b83fb4c0
Turn on release flag and dump ONNX

Remy Wang
commit sha 7def4772272ee3f80767e09b01c63c749f11412d
Cleanup.

Remy Wang
commit sha 325e41a9e73347e7466c1e18e175b75fd06ddab1
Clean up export_model interface.

Remy Wang
commit sha 303f13ef078fe7028fa381ea152b991ec1ae3679
Do not save model for now

Remy Wang
commit sha 14609ab1e55dcb626ff99852cb61f4e187514799
Document save_model

Remy Wang
commit sha d9af23a7ad817b0b37d0c93496e0e4769dc9e981
Add nasnet onnx.

Remy Wang
commit sha de2175e8c8a788c952251d30dd2b7168d5244907
Increase multi-iter

Remy Wang
commit sha 234840b2aa65be7d949445fda3799422d307bed7
Compare release and debug

Remy Wang
commit sha 54087f305dba8d1dc854d040030568a1f3ddb15a
Start parser.

Remy Wang
commit sha 8eb27c3ede3594fdadeaff487d6cfcd9271222e0
Deserialize model.

Yichen Yang
commit sha 890671fa4a743ec255f307b3d35cbff58bd5e923
Support more ops for onnx to rust pipeline

pushed 22 days ago

push event yycdavid/taso

Yichen Yang
commit sha 23559fe101e4987d8624e0ce847171a12150363d
Add inceptionv3

pushed 22 days ago

MemberEvent

push event uwplse/tamago

root
commit sha 2b6e4e9cae3285543711dbab0341e93ec7f0eba2
Update docker files to include graphviz, bind mount.

yycdavid
commit sha 4a27a5b87b9c2c7b7f6580d961b59d223d58343b
Test email setting

yycdavid
commit sha d4e0811b677bdb61a0b13ee7f44d4a5121484887
Converting taso rules into rw rules in egg

yycdavid
commit sha 944374bccee3a9f950fbcc8c22d20a13522873bb
Update

yycdavid
commit sha 25b975d1c8087432b7540d1313cec5fe5e187537
Add input shape, metadata type

yycdavid
commit sha 43ddabcd60a7b78359bf3e7754c7e77c666120b9
Boilerplate code

yycdavid
commit sha 0882d3eb719479faf204b0df2b8cce4fa00a2949
Update docker env to get most recent egg

yycdavid
commit sha 5dfed4f2f6013bfa93fd0d2a1fb2aac9243b2869
Check applier backbone implemented

yycdavid
commit sha 47f50d9f8009074d6ae709367d66215f1917b19f
Add gdb to Dockerfile

yycdavid
commit sha ee62cabe3e4adc294555e52bbbc970e8d9c20385
Change check applier to pass object instead of pointer

yycdavid
commit sha 1c65177a4ccdbb500a535b773ef88a3cdd193ed5
Conditional applier working, tested shape checking and not re-measure cost

yycdavid
commit sha fedfb59647d9ed9c60643311bb7ab47e9d917b1d
Get cost done, greedy extraction with TASO cost tested

yycdavid
commit sha e0ecb404a7dee2d6a6865c41007046d7a82cb435
Evaluate full graph runtime

yycdavid
commit sha a8242b651bac3a18aa5198e65104d20d24027c2a
Test commit for pull request

yycdavid
commit sha 690da218c13f262921f57d7d3efff209d83f493b
Test commit for pull request

yycdavid
commit sha 1afa18f9f0c9631afb98ef5e98bb11c0a9210c15
Merge branch 'dev' of github.com:yycdavid/tamago into dev

Yichen Yang
commit sha cfc993490be14a52cae864c9f1aeaaef44a36bee
Merge pull request #1 from yycdavid/dev (Test commit for pull request)

yycdavid
commit sha ef83678d353c36740d01f16f2e9006cd10ab65c9
Add documentations

yycdavid
commit sha f314d00ef7a65f50e7ae36971a2afc863c57c271
Change comment style

Yichen Yang
commit sha 0621c1c7fafccf3b7ef782427a0117720aeafb85
Merge pull request #2 from yycdavid/dev (Add documentations)

pushed a month ago

PR merged uwplse/tamago

Optimizer
+5637 -182
0 comments
33 changed files
yycdavid

PR closed a month ago

PR opened uwplse/tamago

Optimizer
+5637 -182
0 comments
33 changed files

PR created a month ago

push event yycdavid/tamago

yycdavid
commit sha 625f4a3581f85b61cf1f7f5bccc9bc4f3e934273
Compute stats for runtime

yycdavid
commit sha 329cd6a4d2165216c1e7ec47d7478b5346474949
Save greedy solution for ILP

yycdavid
commit sha 5bcfb8731e876c09871d402ab62e8aee49516978
Initialize ILP with greedy solution

yycdavid
commit sha 07cccb3dad99e713808b8dbcf74153f8bdfce2ea
Add flag to set number of threads and time limit for ILP

yycdavid
commit sha f2bfeb0271005ed21b18d2c59abf84dbe9930611
Add noop to combine outputs

yycdavid
commit sha f806a4e35b404c4cc46449011a5d00378304e0c6
Fix bert combine output

yycdavid
commit sha fcf21ce4995f6fa3e0d1993681ac3a096fcf0f3a
Update gitignore

yycdavid
commit sha bda00b776bfcb1cb41cc0c24e0409e5a140e70bf
Add flag for filter cycle and multi-pattern iterations

yycdavid
commit sha 08e3d8a3cc5c82da0d347f6bf46c74493c7b341b
Filter cycle in multi rules

yycdavid
commit sha 63ece8d1c2c861cb25cf801546f8b4e757aa818e
Apply half for symmetric multi pattern rules

yycdavid
commit sha bb86f457f40b50ee575ccdf619759fa023ff2648
Only zero cost for all weight concat

yycdavid
commit sha 9c767e0f10bfa71d67bf7df5980e51777d57183a
Refactor cost model code structure

yycdavid
commit sha 05ecf37aeb9b9fda2f42fefe9c056abaf0c13137
Discount cost for all weight ops

yycdavid
commit sha 58ff0c296b29460624540526c4940d8eab134f75
Bar plots on stats

yycdavid
commit sha a569039e82d9cda46f542e11ae7f55bd53a6d4de
Update stats plot

yycdavid
commit sha d7a24749e89e522a4ed0df9fc379497a365f4bb2
Pre-filter, add blacklist

yycdavid
commit sha f240bd4937af49835437f71c837380ff9446c13e
Check blacklist before apply

yycdavid
commit sha 2563c0717cd3481bedf43c6dfe44cf5c019994f3
Check blacklist in multi

yycdavid
commit sha 904020ea442aeec9903faee342b2eb71b6e77049
Check blacklist in get descendents

yycdavid
commit sha 97d57ab80de60a0d7ca44e71b33c6abfaf72bb0f
Get existing nodes before apply

pushed a month ago

PR merged yycdavid/tamago

Dev branch

Change list:

  • Initialize ILP with greedy solution
  • Noop to combine outputs
  • Cycle filtering (naive, and efficient)
  • Correct cost model (zero for all-weight cost)
  • Get stats and plot results
  • Add nasneta, and new ops involved
  • Shape inference in input interface

+2777 -598
0 comments
22 changed files
yycdavid

PR closed a month ago

push event yycdavid/tamago

Yichen Yang
commit sha 678f14dafd840e5467465cdd09d11e3b03201c15
Update based on reviews

pushed a month ago

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

[Quoted diff context from the review: changes to MultiPatterns that skip matches containing blacklisted nodes (contains_blacklist), extend check_pat to also collect the nodes that already exist before a rewrite, add cycle checks before applying multi-pattern rules (check_cycle_partial using pre-collected descendents, check_cycle making a fresh pass), and add post-processing that removes cycles by blacklisting the latest-added node in each cycle (remove_cycle_by_order, add_newly_added, remove_cycles, resolve_cycle, get_cycles, get_cycles_rec, update_blacklist).]

Done

yycdavid

comment created a month ago
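The cycle-removal strategy described in the reviewed diff (traverse from the root, collect a cycle, blacklist the element that was added latest, and repeat until none remain) can be sketched on a plain directed graph. This is a simplified illustration, not the tamago/egg implementation: `Graph`, `find_cycle`, and `remove_cycles_by_order` are invented names, and edges stand in for e-nodes, with an edge's index serving as its insertion order.

```rust
use std::collections::HashSet;

/// Simplified stand-in for the e-graph: a directed graph whose edges play the
/// role of e-nodes. An edge's index in `edges` is its insertion order, so
/// "added latest" means "largest index".
struct Graph {
    edges: Vec<(usize, usize)>, // (from, to)
}

/// DFS from `root`, skipping blacklisted edges. Returns the edge indices of
/// the first cycle found (a back edge to a node on the current path), or None.
fn find_cycle(g: &Graph, root: usize, blacklist: &HashSet<usize>) -> Option<Vec<usize>> {
    fn dfs(
        g: &Graph,
        node: usize,
        path_nodes: &mut Vec<usize>,
        path_edges: &mut Vec<usize>,
        visited: &mut HashSet<usize>,
        blacklist: &HashSet<usize>,
    ) -> Option<Vec<usize>> {
        path_nodes.push(node);
        for (i, &(from, to)) in g.edges.iter().enumerate() {
            if from != node || blacklist.contains(&i) {
                continue;
            }
            if let Some(pos) = path_nodes.iter().position(|&n| n == to) {
                // Back edge: the path edges from `pos` onward plus this edge
                // form a cycle.
                let mut cycle = path_edges[pos..].to_vec();
                cycle.push(i);
                return Some(cycle);
            }
            if visited.insert(to) {
                path_edges.push(i);
                if let Some(c) = dfs(g, to, path_nodes, path_edges, visited, blacklist) {
                    return Some(c);
                }
                path_edges.pop();
            }
        }
        path_nodes.pop();
        None
    }
    let mut visited = HashSet::from([root]);
    dfs(g, root, &mut Vec::new(), &mut Vec::new(), &mut visited, blacklist)
}

/// Repeatedly find a cycle and blacklist its latest-added edge, mirroring the
/// "pick the node that got added latest" strategy in the review.
fn remove_cycles_by_order(g: &Graph, root: usize) -> HashSet<usize> {
    let mut blacklist = HashSet::new();
    while let Some(cycle) = find_cycle(g, root, &blacklist) {
        let latest = *cycle.iter().max().unwrap();
        blacklist.insert(latest);
    }
    blacklist
}
```

Termination is guaranteed because a found cycle contains only non-blacklisted edges, so each iteration blacklists a new one; the loop stops once the graph reachable from the root, minus blacklisted edges, is acyclic.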

.newly_added+        .iter()+        .map(|node| node.clone().map_children(|id| runner.egraph.find(id)))+        .collect();+    let mut added_node_to_order = HashMap::<Mdl, usize>::new();+    for (i, node) in updated.iter().enumerate() {+        added_node_to_order.entry(node.clone()).or_insert(i);+    }+    // Remove cycles by adding nodes to blacklist+    remove_cycles(&mut runner.egraph, &added_node_to_order, runner.roots[0]);+}++/// Add newly added nodes in this pattern to the list of newly added nodes, for use in cycle+/// filtering+///+/// The newly added nodes are stored in the graph level metadata in egraph.analysis+///+/// # Parameters+///+/// - `pat`: the AST representation of the pattern. See egg::Pattern for more info+/// - `egraph`: E-graph of interest+/// - `subst`: mapping variable to eclass ID. See egg::Subst for more info.+/// - `existing_nodes`: the set of nodes within this pattern that already exists before this+///         pattern is applied+///+/// # Returns+///+/// A tuple of (HashSet<Mdl>, Id) where+///+/// - HashSet<Mdl>: The set of all nodes in this pattern+/// - Id: the Id into egraph of the matched root of this pattern+fn add_newly_added(+    pat: &[ENodeOrVar<Mdl>],+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    subst: &Subst,+    existing_nodes: &HashSet<Mdl>,+) -> (HashSet<Mdl>, Id) {+    match pat.last().unwrap() {+        ENodeOrVar::Var(w) => (HashSet::<Mdl>::new(), subst[*w]),+        ENodeOrVar::ENode(e) => {+            let children = e.children();+            let results: Vec<(HashSet<Mdl>, Id)> = children+                .iter()+                .map(|child| {+                    add_newly_added(+                        &pat[..usize::from(*child) + 1],+                        egraph,+                        subst,+                        existing_nodes,+                    )+                })+                .collect();++            let mut new_e = e.clone();+            let new_e_ch = new_e.children_mut();+            for 
(i, res) in results.iter().enumerate() {+                new_e_ch[i] = res.1;+            }++            let mut nodes_in_pat = HashSet::<Mdl>::new();+            for res in results.iter() {+                for node in res.0.iter() {+                    nodes_in_pat.insert(node.clone());+                }+            }+            nodes_in_pat.insert(new_e.clone());++            // Add to order list+            if !existing_nodes.contains(&new_e) {+                egraph.analysis.newly_added.push(new_e.clone());+            }++            (nodes_in_pat, egraph.lookup(new_e).unwrap())+        }+    }+}++/// Remove cycles in EGraph by adding nodes to blacklist+///+/// This function works by:+///     - Make a pass over egraph to get a set of cycles+///     - For each cycle, pick the node that got added latest and add it to blacklist+///     - Repeat until no cycles are left+fn remove_cycles(+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    added_node_to_order: &HashMap<Mdl, usize>,+    root: Id,+) {+    loop {+        let mut paths_from_root = HashMap::<Id, Vec<(Id, Mdl)>>::new();+        let mut cycles = Vec::<Vec<Mdl>>::new();++        get_cycles(egraph, root, &mut paths_from_root, &mut cycles);++        if cycles.len() == 0 {+            break;+        }+        for cycle in cycles.iter() {+            resolve_cycle(egraph, cycle, added_node_to_order);+        }+    }+}++/// Resolve cycle by adding node to blacklist+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `cycle`: list of nodes within the cycle+/// - `added_node_to_order`: HashMap, map from node to the order that it was added into the egraph+fn resolve_cycle(+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    cycle: &[Mdl],+    added_node_to_order: &HashMap<Mdl, usize>,+) {+    // Check if any node in cycle is already in blacklist+    let already_solved = cycle+        .iter()+        .any(|node| egraph.analysis.blacklist_nodes.contains(node));+    if !already_solved {+        
assert!(cycle.len() > 0);+        let (ord, n) = cycle+            .iter()+            .map(|node| {+                let order = added_node_to_order+                    .get(node)+                    .map_or(-1, |index| *index as i32);+                (order, node.clone())+            })+            .max_by_key(|(o, _)| *o)+            .unwrap();+        assert!(ord >= 0);+        egraph.analysis.blacklist_nodes.insert(n.clone());+    }+}++/// Traverse the EGraph and get a set of cycles (reachable from root)+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `root`: Id of root eclass+/// - `paths_from_root`: HashMap storing for each eclass, one path from the root eclass. The path+///         is represented by a list of (eclass, enode) that composes the path+/// - `cycles`: list of cycles. Each cycle is a list of nodes. A cycle of 1->2->4->3->1 will be+///         stored as [1,2,4,3]+fn get_cycles(+    egraph: &EGraph<Mdl, TensorAnalysis>,+    root: Id,+    paths_from_root: &mut HashMap<Id, Vec<(Id, Mdl)>>,+    cycles: &mut Vec<Vec<Mdl>>,+) {+    // Get a map from Id to the eclass objects, since egg doesn't provide accessing eclass from Id+    let id_to_class: HashMap<Id, &EClass<Mdl, ValTnsr>> =+        egraph.classes().map(|class| (class.id, class)).collect();++    get_cycles_rec(+        egraph,+        root,+        /*path_to_here=*/ Vec::<(Id, Mdl)>::new(),+        &id_to_class,+        paths_from_root,+        cycles,+    );+}++/// Traverse the EGraph in DFS order, update paths_from_root and cycles on the fly+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `eclass`: The current eclass that we are visiting+/// - `path_to_here`: A path from root to this eclass+/// - `id_to_class`: Map from eclass ID to the eclass objects+/// - `paths_from_root`: HashMap storing for each eclass, one path from the root eclass. The path+///         is represented by a list of (eclass, enode) that composes the path+/// - `cycles`: list of cycles. 
Each cycle is a list of nodes. A cycle of 1->2->4->3->1 will be+///         stored as [1,2,4,3]+fn get_cycles_rec(+    egraph: &EGraph<Mdl, TensorAnalysis>,+    eclass: Id,+    path_to_here: Vec<(Id, Mdl)>,+    id_to_class: &HashMap<Id, &EClass<Mdl, ValTnsr>>,+    paths_from_root: &mut HashMap<Id, Vec<(Id, Mdl)>>,+    cycles: &mut Vec<Vec<Mdl>>,+) {+    assert!(!paths_from_root.contains_key(&eclass));+    paths_from_root.insert(eclass, path_to_here.clone());++    let class = id_to_class.get(&eclass).unwrap();+    for node in class.iter() {+        if egraph.analysis.blacklist_nodes.contains(node) {+            continue;+        }+        for child in node.children().iter() {+            if !paths_from_root.contains_key(child) {+                // Haven't visited, so visit+                let mut path_to_child = path_to_here.clone();+                path_to_child.push((eclass, node.clone()));+                get_cycles_rec(+                    egraph,+                    *child,+                    path_to_child,+                    id_to_class,+                    paths_from_root,+                    cycles,+                );+            } else {+                // Visited before. Check if in path_to_here+                match path_to_here
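The blacklist-based cycle removal quoted above (DFS from the root to collect cycles, blacklist the latest-added node on each cycle, repeat until none remain) can be sketched on a plain adjacency-list graph. Bare `u32` ids stand in for eclasses/e-nodes here, so this is an illustration of the algorithm only, not the e-graph implementation:

```rust
use std::collections::{HashMap, HashSet};

/// DFS from `root`, skipping blacklisted nodes; returns one cycle (the list
/// of nodes on it) if any is reachable, else None. The current DFS path
/// plays the role of `path_to_here` in `get_cycles_rec`.
fn find_cycle(
    adj: &HashMap<u32, Vec<u32>>,
    root: u32,
    blacklist: &HashSet<u32>,
) -> Option<Vec<u32>> {
    fn dfs(
        adj: &HashMap<u32, Vec<u32>>,
        node: u32,
        blacklist: &HashSet<u32>,
        path: &mut Vec<u32>,
        visited: &mut HashSet<u32>,
    ) -> Option<Vec<u32>> {
        path.push(node);
        visited.insert(node);
        for &child in adj.get(&node).into_iter().flatten() {
            if blacklist.contains(&child) {
                continue;
            }
            if let Some(pos) = path.iter().position(|&n| n == child) {
                // Back edge: the cycle is the path suffix starting at `child`
                return Some(path[pos..].to_vec());
            }
            if !visited.contains(&child) {
                if let Some(cycle) = dfs(adj, child, blacklist, path, visited) {
                    return Some(cycle);
                }
            }
        }
        path.pop();
        None
    }
    dfs(adj, root, blacklist, &mut Vec::new(), &mut HashSet::new())
}

/// Repeatedly blacklist the latest-added node on each remaining cycle, as in
/// remove_cycles/resolve_cycle; `order[n]` is the step at which node n was
/// added (larger = later).
fn remove_cycles_by_order(
    adj: &HashMap<u32, Vec<u32>>,
    root: u32,
    order: &HashMap<u32, usize>,
) -> HashSet<u32> {
    let mut blacklist = HashSet::new();
    while let Some(cycle) = find_cycle(adj, root, &blacklist) {
        let latest = *cycle
            .iter()
            .max_by_key(|&&n| order.get(&n).copied().unwrap_or(0))
            .unwrap();
        blacklist.insert(latest);
    }
    blacklist
}

fn main() {
    // 0 -> 1 -> 2 -> 1: the 1->2->1 cycle is broken by blacklisting node 2,
    // the latest-added node on the cycle.
    let adj: HashMap<u32, Vec<u32>> =
        [(0, vec![1]), (1, vec![2]), (2, vec![1])].into_iter().collect();
    let order: HashMap<u32, usize> = [(0, 0), (1, 1), (2, 2)].into_iter().collect();
    let blacklist = remove_cycles_by_order(&adj, 0, &order);
    assert_eq!(blacklist, [2].into_iter().collect());
    println!("blacklisted: {:?}", blacklist);
}
```

Blacklisting the latest-added node mirrors the intuition that rewrites applied later are the ones that introduced the cycle, so earlier structure is preserved.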
Done

yycdavid

comment created time in a month

Pull request review comment yycdavid/tamago

Dev branch

 impl MultiPatterns {
                 if compatible(&subst_1_dec, &subst_2_dec, &map_1.var_map) {
                     // If so, merge two substitutions
                     let merged_subst = merge_subst(subst_1_dec, subst_2_dec, &map_1.var_map);
+                    // Check if any source pattern contains blacklisted nodes
+                    if self.filter_after {
+                        if contains_blacklist(
+                            rule.0.ast.as_ref(),
+                            &mut runner.egraph,
+                            &merged_subst,
+                        )
+                        .0 || contains_blacklist(
+                            rule.1.ast.as_ref(),
+                            &mut runner.egraph,
+                            &merged_subst,
+                        )
+                        .0
+                        {
+                            continue;
+                        }
+                    }

                     // check_pat on both dst patterns
-                    if check_pat(rule.2.ast.as_ref(), &mut runner.egraph, &merged_subst).0 {
-                        if check_pat(rule.3.ast.as_ref(), &mut runner.egraph, &merged_subst).0 {
+                    let (valid_1, _, _, existing_1) = check_pat(
+                        rule.2.ast.as_ref(),
+                        &mut runner.egraph,
+                        &merged_subst,
+                        /*get_exist_nodes=*/ self.filter_after,
+                    );
+                    let (valid_2, _, _, existing_2) = check_pat(
+                        rule.3.ast.as_ref(),
+                        &mut runner.egraph,
+                        &merged_subst,
+                        /*get_exist_nodes=*/ self.filter_after,
+                    );
+                    if valid_1 && valid_2 {
+                        let cycle_check_passed = if self.no_cycle {
+                            if self.filter_after {
+                                // Do pre-filtering using the pre-collected descendents info
+                                self.check_cycle_partial(
+                                    &runner.egraph,
+                                    &merged_subst,
+                                    &map_1.var_map,
+                                    &map_2.var_map,
+                                    match_1.eclass,
+                                    match_2.eclass,
+                                )
+                            } else {
+                                // Check cycle by make a pass in egraph
+                                let mut descendents: HashMap<Id, HashSet<Id>> = Default::default();
+                                check_cycle(
+                                    &runner.egraph,
+                                    &merged_subst,
+                                    &map_1.var_map,
+                                    &map_2.var_map,
+                                    match_1.eclass,
+                                    match_2.eclass,
+                                    &mut descendents,
+                                )
+                            }
+                        } else {
+                            true
+                        };
+                        if cycle_check_passed {
                             // apply dst patterns, union
                             let id_1 =
                                 rule.2
                                     .apply_one(&mut runner.egraph, match_1.eclass, &merged_subst)
                                     [0];
-                            runner.egraph.union(id_1, match_1.eclass);
+
                             let id_2 =
                                 rule.3
                                     .apply_one(&mut runner.egraph, match_2.eclass, &merged_subst)
                                     [0];
+
+                            // Add the newly added nodes to the ordering list
+                            if self.filter_after {
+                                let existing_1 = existing_1.unwrap();
+                                let existing_2 = existing_2.unwrap();
+                                let (nodes_in_1, _) = add_newly_added(
+                                    rule.2.ast.as_ref(),
+                                    &mut runner.egraph,
+                                    &merged_subst,
+                                    &existing_1,
+                                );
+                                let existing_2_updated: HashSet<Mdl> = existing_2
+                                    .iter()
+                                    .chain(nodes_in_1.iter())
+                                    .map(|node| node.clone())
+                                    .collect();
+                                add_newly_added(
+                                    rule.3.ast.as_ref(),
+                                    &mut runner.egraph,
+                                    &merged_subst,
+                                    &existing_2_updated,
+                                );
+                            }
+
+                            runner.egraph.union(id_1, match_1.eclass);
                             runner.egraph.union(id_2, match_2.eclass);
+                        } else {
                         }
                     }
                 }
             }
         }
     }
+
+    /// Returns true if there will not be a cycle introduced by applying this rule.
+    ///
+    /// Checking based on descendents collected at beginning of run_one(). If any input node
+    /// contains descendents that is any of the matched output class, then there can be a cycle
+    /// created.
+    ///
+    /// # Parameters
+    ///
+    /// - `egraph`: egraph of interest
+    /// - `input_subst`: substitution containing the input variables
+    /// - `var_map_1`: keys of this map contains all the input variables in source pattern 1
+    /// - `var_map_2`: keys of this map contains all the input variables in source pattern 2
+    /// - `out_class_1`: Id of the matched eclass of the output of source pattern 1
+    /// - `out_class_2`: Id of the matched eclass of the output of source pattern 2
+    fn check_cycle_partial(
+        &self,
+        egraph: &EGraph<Mdl, TensorAnalysis>,
+        input_subst: &Subst,
+        var_map_1: &HashMap<egg::Var, egg::Var>,
+        var_map_2: &HashMap<egg::Var, egg::Var>,
+        out_class_1: Id,
+        out_class_2: Id,
+    ) -> bool {
+        // Get all input eclass IDs
+        let input_ids: HashSet<Id> = var_map_1
+            .iter()
+            .chain(var_map_2.iter())
+            .map(|(var, _)| *input_subst.get(*var).unwrap())
+            .collect();
+        // Check descendents of the input eclasses
+        for id in input_ids.iter() {
+            let descendents = self.descendents.as_ref().unwrap();
+            let descendents_input = descendents.get(id).unwrap();
+            if descendents_input.contains(&out_class_1) || descendents_input.contains(&out_class_2)
+            {
+                return false;
+            }
+        }
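`check_cycle_partial` above pre-filters using descendent sets collected once per `run_one()`: a multi-pattern application is rejected when either matched output class is already a descendent of any matched input class, since the new nodes would make the output reachable from those inputs. A minimal sketch on an integer DAG (plain `u32` ids stand in for eclass `Id`s; the graph is assumed acyclic):

```rust
use std::collections::{HashMap, HashSet};

/// Compute the full descendent set of every node of a DAG by memoized DFS,
/// analogous to the per-eclass descendents collected at the start of run_one().
fn descendents(adj: &HashMap<u32, Vec<u32>>) -> HashMap<u32, HashSet<u32>> {
    fn go(adj: &HashMap<u32, Vec<u32>>, node: u32, memo: &mut HashMap<u32, HashSet<u32>>) {
        if memo.contains_key(&node) {
            return;
        }
        memo.insert(node, HashSet::new()); // placeholder; final set written below
        let mut desc = HashSet::new();
        for &child in adj.get(&node).into_iter().flatten() {
            desc.insert(child);
            go(adj, child, memo);
            desc.extend(memo[&child].iter().copied());
        }
        memo.insert(node, desc);
    }
    let mut memo = HashMap::new();
    for &n in adj.keys() {
        go(adj, n, &mut memo);
    }
    memo
}

/// A rewrite that makes `outputs` reachable from its matched `inputs` is
/// rejected if some output is already a descendent of some input -- exactly
/// the case where a cycle could be introduced.
fn cycle_check(desc: &HashMap<u32, HashSet<u32>>, inputs: &[u32], outputs: &[u32]) -> bool {
    inputs.iter().all(|i| {
        outputs
            .iter()
            .all(|o| !desc.get(i).map_or(false, |d| d.contains(o)))
    })
}

fn main() {
    // 0 -> 1 -> 2
    let adj: HashMap<u32, Vec<u32>> = [(0, vec![1]), (1, vec![2])].into_iter().collect();
    let d = descendents(&adj);
    assert!(d[&0].contains(&2));
    // Safe: output 0 is not a descendent of input 2
    assert!(cycle_check(&d, &[2], &[0]));
    // Unsafe: output 2 is already a descendent of input 0
    assert!(!cycle_check(&d, &[0], &[2]));
    println!("descendents of 0: {:?}", d[&0]);
}
```

Because the descendent sets are collected once per iteration, this check is conservative (hence "partial"): it can pass applications that the full per-application pass would reject, which is why the post-processing cycle removal still runs afterwards.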

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 fn check_pat(
                     for (i, res) in results.iter().enumerate() {
                         new_e_ch[i] = res.1.unwrap();
                     }
-                    let looked = egraph.lookup(new_e);
+                    let looked = egraph.lookup(new_e.clone());
                     match looked {

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 impl Applier<Mdl, TensorAnalysis> for CheckApply {
     }
 }

+/// Check if the matched graph of the pattern contains any blacklisted nodes
+///
+/// # Returns
+///
+/// A tuple of (bool, Option<Id>) where
+///
+/// - bool: true if the nodes in this pattern contains some node in blacklist
+/// - Option<Id>: if the nodes in this pattern do not contain blacklisted
+///     nodes, then this is the Id of the matched EClass of the root of this pattern(pat.last())
+fn contains_blacklist(
+    pat: &[ENodeOrVar<Mdl>],
+    egraph: &mut EGraph<Mdl, TensorAnalysis>,
+    subst: &Subst,
+) -> (bool, Option<Id>) {
+    match pat.last().unwrap() {
+        ENodeOrVar::Var(w) => (false, Some(subst[*w])),
+        ENodeOrVar::ENode(e) => {
+            let children = e.children();
+            let results: Vec<(bool, Option<Id>)> = children
+                .iter()
+                .map(|child| contains_blacklist(&pat[..usize::from(*child) + 1], egraph, subst))
+                .collect();
+
+            let mut contains = false;
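`contains_blacklist` (like `check_pat` and `add_newly_added` elsewhere in this PR) recurses over egg's flattened pattern AST: `pat` is a slice in topological order, each child is an index into a prefix of the slice, and `pat.last()` is the root, hence the `&pat[..usize::from(*child) + 1]` slices. A toy version, with a hypothetical `Node` type standing in for `ENodeOrVar<Mdl>`, shows the shape of that recursion:

```rust
/// Toy stand-in for egg's flattened pattern AST: nodes live in a slice in
/// topological order, children are indices into the prefix, and the root is
/// the last element. (`Node` is hypothetical, not egg's ENodeOrVar.)
#[derive(Debug)]
enum Node {
    Leaf(&'static str),
    Op(&'static str, Vec<usize>),
}

/// Collect all leaf names under the root (the last element), recursing on
/// prefix slices exactly like contains_blacklist does with
/// `&pat[..usize::from(*child) + 1]`.
fn leaves(pat: &[Node]) -> Vec<&'static str> {
    match pat.last().unwrap() {
        Node::Leaf(name) => vec![*name],
        Node::Op(_, children) => children
            .iter()
            .flat_map(|&c| leaves(&pat[..c + 1]))
            .collect(),
    }
}

fn main() {
    // (add x (mul y z)) flattened as [x, y, z, mul(1,2), add(0,3)]
    let pat = vec![
        Node::Leaf("x"),
        Node::Leaf("y"),
        Node::Leaf("z"),
        Node::Op("mul", vec![1, 2]),
        Node::Op("add", vec![0, 3]),
    ];
    assert_eq!(leaves(&pat), vec!["x", "y", "z"]);
    println!("{:?}", leaves(&pat));
}
```

Slicing to `..c + 1` works because a node can only reference earlier entries, so each child subtree is fully contained in the prefix ending at the child's own index.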

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 fn extract_by_ilp(
     if no_order {
         arg_vec.push("--no_order");
     }
+    if initialize {
+        arg_vec.push("--initialize")
+    }
+    match matches.value_of("ilp_time_sec") {
+        Some(time_lim) => {
+            arg_vec.push("--time_lim_sec");
+            arg_vec.push(time_lim);
+        }
+        None => (),
+    }
+    match matches.value_of("ilp_num_threads") {
+        Some(num_thread) => {
+            arg_vec.push("--num_thread");
+            arg_vec.push(num_thread);
+        }
+        None => (),
+    }

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 fn optimize(matches: clap::ArgMatches) {
         runner_ext.egraph.dot().to_svg("target/ext.svg").unwrap();
     }

+    let time_start = get_full_graph_runtime(&runner_start, false);
+    println!("Start graph runtime: {}", time_start);
+
     let time_ext = get_full_graph_runtime(&runner_ext, true);
     println!("Extracted graph runtime: {}", time_ext);

-    let time_start = get_full_graph_runtime(&runner_start, false);
-    println!("Start graph runtime: {}", time_start);
+    match matches.value_of("out_file") {
+        Some(outf) => {
+            let mut file = OpenOptions::new()
+                .append(true)
+                .create(true)
+                .open(outf)
+                .unwrap();
+
+            // Stats to write: original runtime, optimized runtime, saturation time, extraction time,
+            // number of nodes, number of eclasses, number of possible programs
+            let data = json!({
+                "original": time_start,
+                "optimized": time_ext,
+                "saturation": sat_duration.as_secs_f32(),
+                "extraction": ext_secs,
+                "nodes": num_enodes,
+                "classes": num_classes,
+                "programs": num_programs,
+                "iter": num_iter_sat,
+            });
+            let sol_data_str =
+                serde_json::to_string(&data).expect("Fail to convert json to string");
+
+            if let Err(e) = writeln!(file, "{}", sol_data_str) {
+                eprintln!("Couldn't write to file: {}", e);
+            }
+        }
+        None => (),
+    }

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 impl GraphConverter {
     /// Takes in the parameters for the new input, construct the node in RexExpr,
     /// return the Id (index) of this input node in the RecExpr. This is the
     /// pattern for all these op functions.
-    pub fn new_input(&mut self, dims: &[i32]) -> Id {
+    pub fn new_input(&mut self, dims: &[i32]) -> TensorInfo {
         let name = self.name_gen.new_input_name() + "@" + &dims.iter().join("_");
         let node = Mdl::Var(Symbol::from(name));
         let name_id = self.rec_expr.add(node);

         let new_node = Mdl::Input([name_id]);
-        self.rec_expr.add(new_node)
+        let (shape, n_dim) = self.shape_from_dim(dims);
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: shape,
+            n_dim: n_dim,

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 impl GraphConverter {
             stride_w_id,
             padding_id,
             activation_id,
-            inpt,
-            wght,
+            inpt.id,
+            wght.id,
         ]);
-        self.rec_expr.add(new_node)
+
+        // Get shape
+        let mut shape = [0; MAX_DIM];
+        let input_h = inpt.shape[2];
+        let input_w = inpt.shape[3];
+        let kernel_h = wght.shape[2];
+        let kernel_w = wght.shape[3];
+
+        let (output_h, output_w) = self.get_conv_shape(
+            input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding,
+        );
+        shape[0] = inpt.shape[0];
+        shape[1] = wght.shape[0];
+        shape[2] = output_h;
+        shape[3] = output_w;
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: shape,
+            n_dim: 4,
+        }
     }

-    pub fn relu(&mut self, inpt: Id) -> Id {
-        let new_node = Mdl::Relu(inpt);
-        self.rec_expr.add(new_node)
+    pub fn relu(&mut self, inpt: TensorInfo) -> TensorInfo {
+        let new_node = Mdl::Relu(inpt.id);
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: inpt.shape,
+            n_dim: inpt.n_dim,
+        }
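The conv2d shape computation above delegates to a `get_conv_shape` helper whose body is outside the quoted hunk. The standard output-extent arithmetic it presumably implements (for same/valid padding) can be sketched as:

```rust
/// Standard conv2d output-extent arithmetic: "same" padding keeps
/// ceil(input / stride); "valid" padding fits only full kernel windows.
/// (This is the textbook formula, not necessarily get_conv_shape's exact body.)
fn conv_out_dim(input: i32, kernel: i32, stride: i32, same_padding: bool) -> i32 {
    if same_padding {
        (input + stride - 1) / stride // ceil division
    } else {
        (input - kernel) / stride + 1
    }
}

fn main() {
    // 5x5 input, 3x3 kernel, stride 1: valid -> 3, same -> 5
    assert_eq!(conv_out_dim(5, 3, 1, false), 3);
    assert_eq!(conv_out_dim(5, 3, 1, true), 5);
    // stride 2 roughly halves the extent (rounding up under same padding)
    assert_eq!(conv_out_dim(7, 3, 2, false), 3);
    assert_eq!(conv_out_dim(7, 3, 2, true), 4);
    println!("ok");
}
```

Applied per spatial axis, this yields the `output_h`/`output_w` written into `shape[2]`/`shape[3]` in the hunk, with `shape[0]` taken from the input batch and `shape[1]` from the weight's output-channel count.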

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 impl GraphConverter {
             stride_w_id,
             padding_id,
             activation_id,
-            inpt,
-            wght,
+            inpt.id,
+            wght.id,
         ]);
-        self.rec_expr.add(new_node)
+
+        // Get shape
+        let mut shape = [0; MAX_DIM];
+        let input_h = inpt.shape[2];
+        let input_w = inpt.shape[3];
+        let kernel_h = wght.shape[2];
+        let kernel_w = wght.shape[3];
+
+        let (output_h, output_w) = self.get_conv_shape(
+            input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding,
+        );
+        shape[0] = inpt.shape[0];
+        shape[1] = wght.shape[0];
+        shape[2] = output_h;
+        shape[3] = output_w;
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: shape,
+            n_dim: 4,
+        }
     }

-    pub fn relu(&mut self, inpt: Id) -> Id {
-        let new_node = Mdl::Relu(inpt);
-        self.rec_expr.add(new_node)
+    pub fn relu(&mut self, inpt: TensorInfo) -> TensorInfo {
+        let new_node = Mdl::Relu(inpt.id);
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: inpt.shape,
+            n_dim: inpt.n_dim,
+        }
     }

-    pub fn tanh(&mut self, inpt: Id) -> Id {
-        let new_node = Mdl::Tanh(inpt);
-        self.rec_expr.add(new_node)
+    pub fn tanh(&mut self, inpt: TensorInfo) -> TensorInfo {
+        let new_node = Mdl::Tanh(inpt.id);
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: inpt.shape,
+            n_dim: inpt.n_dim,
+        }
     }

-    pub fn sigmoid(&mut self, inpt: Id) -> Id {
-        let new_node = Mdl::Sigmoid(inpt);
-        self.rec_expr.add(new_node)
+    pub fn sigmoid(&mut self, inpt: TensorInfo) -> TensorInfo {
+        let new_node = Mdl::Sigmoid(inpt.id);
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: inpt.shape,
+            n_dim: inpt.n_dim,
+        }
     }

-    pub fn add(&mut self, inpt_1: Id, inpt_2: Id) -> Id {
-        let new_node = Mdl::Ewadd([inpt_1, inpt_2]);
-        self.rec_expr.add(new_node)
+    pub fn add(&mut self, inpt_1: TensorInfo, inpt_2: TensorInfo) -> TensorInfo {
+        let new_node = Mdl::Ewadd([inpt_1.id, inpt_2.id]);
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: inpt_1.shape,
+            n_dim: inpt_1.n_dim,
+        }
     }

-    pub fn matmul(&mut self, inpt_1: Id, inpt_2: Id) -> Id {
+    pub fn matmul(&mut self, inpt_1: TensorInfo, inpt_2: TensorInfo) -> TensorInfo {
         let activation = ACTNONE;
         let act_id = self.add_or_get_val(activation);

-        let new_node = Mdl::Matmul([act_id, inpt_1, inpt_2]);
-        self.rec_expr.add(new_node)
+        let new_node = Mdl::Matmul([act_id, inpt_1.id, inpt_2.id]);
+
+        let mut shape = inpt_1.shape;
+        let n_dim = inpt_1.n_dim;
+        shape[n_dim - 1] = inpt_2.shape[n_dim - 1];
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: shape,
+            n_dim: n_dim,
+        }
     }

-    pub fn mul(&mut self, inpt_1: Id, inpt_2: Id) -> Id {
-        let new_node = Mdl::Ewmul([inpt_1, inpt_2]);
-        self.rec_expr.add(new_node)
+    pub fn mul(&mut self, inpt_1: TensorInfo, inpt_2: TensorInfo) -> TensorInfo {
+        let new_node = Mdl::Ewmul([inpt_1.id, inpt_2.id]);
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: inpt_1.shape,
+            n_dim: inpt_1.n_dim,
+        }
     }

-    pub fn concat(&mut self, axis: i32, ndim: i32, inpt_1: Id, inpt_2: Id) -> Id {
+    pub fn concat(
+        &mut self,
+        axis: i32,
+        ndim: i32,
+        inpt_1: TensorInfo,
+        inpt_2: TensorInfo,
+    ) -> TensorInfo {
         // Only support concat of 2 inputs for now
         // To support more, pass in a slice and create more concat nodes here
         let axis_id = self.add_or_get_val(axis);
         let ndim_id = self.add_or_get_val(ndim);

-        let new_node = Mdl::Concat([axis_id, ndim_id, inpt_1, inpt_2]);
-        self.rec_expr.add(new_node)
+        let new_node = Mdl::Concat([axis_id, ndim_id, inpt_1.id, inpt_2.id]);
+
+        let mut shape = inpt_1.shape;
+        let n_dim = inpt_1.n_dim;
+        shape[axis as usize] += inpt_2.shape[axis as usize];
+
+        TensorInfo {
+            id: self.rec_expr.add(new_node),
+            shape: shape,
+            n_dim: n_dim,
+        }
+    }
+
+    pub fn concat_multi(&mut self, axis: i32, ndim: i32, inputs: &[TensorInfo]) -> TensorInfo {
+        let n_inputs = inputs.len();
+        // We can add supports for other number of inputs later when needed.
+        // We need to add a new Concat op for each number of inputs
+        assert!(n_inputs == 5);
+
+        let axis_id = self.add_or_get_val(axis);
+        let ndim_id = self.add_or_get_val(ndim);
+
+        let new_node = Mdl::Concat5([
+            axis_id,
+            ndim_id,
+            inputs[0].id,
+            inputs[1].id,
+            inputs[2].id,
+            inputs[3].id,
+            inputs[4].id,
+        ]);
+
+        let mut shape = inputs[0].shape;
+        let n_dim = inputs[0].n_dim;
+        for i in 1..n_inputs {
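The matmul and concat cases in this hunk infer output shapes directly: matmul keeps the lhs shape except for the last dimension, which comes from the rhs, and concat sums extents along the concatenation axis. A standalone sketch of those two rules (with `MAX_DIM` assumed to be 4 here, matching the 4-D shapes this converter builds):

```rust
// Shape rules used by the converter above, on fixed-size shape arrays as in
// TensorInfo. MAX_DIM = 4 is an assumption for this sketch.
const MAX_DIM: usize = 4;

/// Matmul keeps the lhs shape but takes the last dimension from the rhs.
fn matmul_shape(a: [i32; MAX_DIM], b: [i32; MAX_DIM], n_dim: usize) -> [i32; MAX_DIM] {
    let mut out = a;
    out[n_dim - 1] = b[n_dim - 1];
    out
}

/// Concat sums the extents of its inputs along `axis`; all other dims match.
fn concat_shape(inputs: &[[i32; MAX_DIM]], axis: usize) -> [i32; MAX_DIM] {
    let mut out = inputs[0];
    for s in &inputs[1..] {
        out[axis] += s[axis];
    }
    out
}

fn main() {
    // (2x3) x (3x5) -> 2x5
    assert_eq!(matmul_shape([2, 3, 0, 0], [3, 5, 0, 0], 2), [2, 5, 0, 0]);
    // concat along axis 1: 2x3 ++ 2x4 -> 2x7
    assert_eq!(concat_shape(&[[2, 3, 0, 0], [2, 4, 0, 0]], 1), [2, 7, 0, 0]);
    println!("ok");
}
```

The slice-based `concat_shape` generalizes the pairwise `concat` and the fixed-arity `concat_multi` in the hunk, which needs a dedicated `Concat5` operator because each arity is a distinct enode variant.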

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

+from __future__ import print_function
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import json
+import scipy
+import scipy.stats
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Analysis script, get the statistics we want')
+    parser.add_argument('--mode', type=str, default='runtime',
+        help='Mode of analysis')
+    parser.add_argument('--file', type=str, default='data.txt',
+        help='File for the input data to analyze')
+
+    return parser.parse_args()
+
+# Plot speedup and optimizer time together in a same bar plot
+def plot_runtime_and_speed(benchmark_name, result):
+    width = 0.8
+    x_locs = [0, 1, 2, 3, 6, 7]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
+
+    fig, ax1 = plt.subplots()
+    ax1.bar(x_locs[0], result['orig_runtime'], width=width, label='Original', color=colors[0])
+    ax1.bar(x_locs[1], result['taso'], width=width, label='TASO', color=colors[1])
+    ax1.bar(x_locs[2], result['greedy'], width=width, label='Greedy', color=colors[2])
+    ax1.bar(x_locs[3], result['ilp'], width=width, label='ILP', color=colors[3])
+
+    runtimes = [result['orig_runtime'], result['taso'], result['greedy'], result['ilp']]
+
+    ax1.set_ylabel('Graph runtime (milliseconds)')
+
+    ax2 = ax1.twinx()
+    ax2.bar(x_locs[4], result['taso_total_time'], width=width, label='TASO total', color=colors[4])
+    ax2.bar(x_locs[4], result['taso_best_time'], width=width, label='TASO best', color=colors[5])
+    ax2.bar(x_locs[5], result['ilp_time'], width=width, label='Sat.+ILP', color=colors[6])
+
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+
+    ax2.set_ylabel('Optimizer time (seconds)')
+    ax1.set_xlabel(benchmark_name)
+
+    lines, labels = ax1.get_legend_handles_labels()
+    lines2, labels2 = ax2.get_legend_handles_labels()
+    ax2.legend(lines + lines2, labels + labels2, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})

Make a TODO for later

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

                  let existing_2 = existing_2.unwrap();+                                let (nodes_in_1, _) = add_newly_added(+                                    rule.2.ast.as_ref(),+                                    &mut runner.egraph,+                                    &merged_subst,+                                    &existing_1,+                                );+                                let existing_2_updated: HashSet<Mdl> = existing_2+                                    .iter()+                                    .chain(nodes_in_1.iter())+                                    .map(|node| node.clone())+                                    .collect();+                                add_newly_added(+                                    rule.3.ast.as_ref(),+                                    &mut runner.egraph,+                                    &merged_subst,+                                    &existing_2_updated,+                                );+                            }++                            runner.egraph.union(id_1, match_1.eclass);                             runner.egraph.union(id_2, match_2.eclass);+                        } else {                         }                     }                 }             }         }     }++    /// Returns true if there will not be a cycle introduced by applying this rule.+    ///+    /// Checking based on descendents collected at beginning of run_one(). 
If any input node+    /// contains descendents that is any of the matched output class, then there can be a cycle+    /// created.+    ///+    /// # Parameters+    ///+    /// - `egraph`: egraph of interest+    /// - `input_subst`: substitution containing the input variables+    /// - `var_map_1`: keys of this map contains all the input variables in source pattern 1+    /// - `var_map_2`: keys of this map contains all the input variables in source pattern 2+    /// - `out_class_1`: Id of the matched eclass of the output of source pattern 1+    /// - `out_class_2`: Id of the matched eclass of the output of source pattern 2+    fn check_cycle_partial(+        &self,+        egraph: &EGraph<Mdl, TensorAnalysis>,+        input_subst: &Subst,+        var_map_1: &HashMap<egg::Var, egg::Var>,+        var_map_2: &HashMap<egg::Var, egg::Var>,+        out_class_1: Id,+        out_class_2: Id,+    ) -> bool {+        // Get all input eclass IDs+        let input_ids: HashSet<Id> = var_map_1+            .iter()+            .chain(var_map_2.iter())+            .map(|(var, _)| *input_subst.get(*var).unwrap())+            .collect();+        // Check descendents of the input eclasses+        for id in input_ids.iter() {+            let descendents = self.descendents.as_ref().unwrap();+            let descendents_input = descendents.get(id).unwrap();+            if descendents_input.contains(&out_class_1) || descendents_input.contains(&out_class_2)+            {+                return false;+            }+        }+        true+    }+}++/// Do post-processing to remove cycles in the egraph, by adding nodes to blacklist.+pub fn remove_cycle_by_order(runner: &mut Runner<Mdl, TensorAnalysis, ()>) {+    // Update blacklist (canonicalize Ids with egraph.find())+    update_blacklist(&mut runner.egraph);++    // Update newly_added (canonicalize Ids with egraph.find()) and construct hashmap+    // for newly_added+    let updated: Vec<Mdl> = runner+        .egraph+        .analysis+        
.newly_added+        .iter()+        .map(|node| node.clone().map_children(|id| runner.egraph.find(id)))+        .collect();+    let mut added_node_to_order = HashMap::<Mdl, usize>::new();+    for (i, node) in updated.iter().enumerate() {+        added_node_to_order.entry(node.clone()).or_insert(i);+    }+    // Remove cycles by adding nodes to blacklist+    remove_cycles(&mut runner.egraph, &added_node_to_order, runner.roots[0]);+}++/// Add newly added nodes in this pattern to the list of newly added nodes, for use in cycle+/// filtering+///+/// The newly added nodes are stored in the graph level metadata in egraph.analysis+///+/// # Parameters+///+/// - `pat`: the AST representation of the pattern. See egg::Pattern for more info+/// - `egraph`: E-graph of interest+/// - `subst`: mapping variable to eclass ID. See egg::Subst for more info.+/// - `existing_nodes`: the set of nodes within this pattern that already exists before this+///         pattern is applied+///+/// # Returns+///+/// A tuple of (HashSet<Mdl>, Id) where+///+/// - HashSet<Mdl>: The set of all nodes in this pattern+/// - Id: the Id into egraph of the matched root of this pattern+fn add_newly_added(+    pat: &[ENodeOrVar<Mdl>],+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    subst: &Subst,+    existing_nodes: &HashSet<Mdl>,+) -> (HashSet<Mdl>, Id) {+    match pat.last().unwrap() {+        ENodeOrVar::Var(w) => (HashSet::<Mdl>::new(), subst[*w]),+        ENodeOrVar::ENode(e) => {+            let children = e.children();+            let results: Vec<(HashSet<Mdl>, Id)> = children+                .iter()+                .map(|child| {+                    add_newly_added(+                        &pat[..usize::from(*child) + 1],+                        egraph,+                        subst,+                        existing_nodes,+                    )+                })+                .collect();++            let mut new_e = e.clone();+            let new_e_ch = new_e.children_mut();+            for 
(i, res) in results.iter().enumerate() {+                new_e_ch[i] = res.1;+            }++            let mut nodes_in_pat = HashSet::<Mdl>::new();+            for res in results.iter() {+                for node in res.0.iter() {+                    nodes_in_pat.insert(node.clone());+                }+            }+            nodes_in_pat.insert(new_e.clone());++            // Add to order list+            if !existing_nodes.contains(&new_e) {+                egraph.analysis.newly_added.push(new_e.clone());+            }++            (nodes_in_pat, egraph.lookup(new_e).unwrap())+        }+    }+}++/// Remove cycles in EGraph by adding nodes to blacklist+///+/// This function works by:+///     - Make a pass over egraph to get a set of cycles+///     - For each cycle, pick the node that got added latest and add it to blacklist+///     - Repeat until no cycles are left+fn remove_cycles(+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    added_node_to_order: &HashMap<Mdl, usize>,+    root: Id,+) {+    loop {+        let mut paths_from_root = HashMap::<Id, Vec<(Id, Mdl)>>::new();+        let mut cycles = Vec::<Vec<Mdl>>::new();++        get_cycles(egraph, root, &mut paths_from_root, &mut cycles);++        if cycles.len() == 0 {+            break;+        }+        for cycle in cycles.iter() {+            resolve_cycle(egraph, cycle, added_node_to_order);+        }+    }+}++/// Resolve cycle by adding node to blacklist+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `cycle`: list of nodes within the cycle+/// - `added_node_to_order`: HashMap, map from node to the order that it was added into the egraph+fn resolve_cycle(+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    cycle: &[Mdl],+    added_node_to_order: &HashMap<Mdl, usize>,+) {+    // Check if any node in cycle is already in blacklist+    let already_solved = cycle+        .iter()+        .any(|node| egraph.analysis.blacklist_nodes.contains(node));+    if !already_solved {+        
assert!(cycle.len() > 0);+        let (ord, n) = cycle+            .iter()+            .map(|node| {+                let order = added_node_to_order+                    .get(node)+                    .map_or(-1, |index| *index as i32);+                (order, node.clone())+            })+            .max_by_key(|(o, _)| *o)+            .unwrap();+        assert!(ord >= 0);+        egraph.analysis.blacklist_nodes.insert(n.clone());+    }+}++/// Traverse the EGraph and get a set of cycles (reachable from root)+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `root`: Id of root eclass+/// - `paths_from_root`: HashMap storing for each eclass, one path from the root eclass. The path+///         is represented by a list of (eclass, enode) that composes the path+/// - `cycles`: list of cycles. Each cycle is a list of nodes. A cycle of 1->2->4->3->1 will be+///         stored as [1,2,4,3]+fn get_cycles(+    egraph: &EGraph<Mdl, TensorAnalysis>,+    root: Id,+    paths_from_root: &mut HashMap<Id, Vec<(Id, Mdl)>>,+    cycles: &mut Vec<Vec<Mdl>>,+) {+    // Get a map from Id to the eclass objects, since egg doesn't provide accessing eclass from Id+    let id_to_class: HashMap<Id, &EClass<Mdl, ValTnsr>> =+        egraph.classes().map(|class| (class.id, class)).collect();++    get_cycles_rec(+        egraph,+        root,+        /*path_to_here=*/ Vec::<(Id, Mdl)>::new(),+        &id_to_class,+        paths_from_root,+        cycles,+    );+}++/// Traverse the EGraph in DFS order, update paths_from_root and cycles on the fly+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `eclass`: The current eclass that we are visiting+/// - `path_to_here`: A path from root to this eclass+/// - `id_to_class`: Map from eclass ID to the eclass objects+/// - `paths_from_root`: HashMap storing for each eclass, one path from the root eclass. The path+///         is represented by a list of (eclass, enode) that composes the path+/// - `cycles`: list of cycles. 
Each cycle is a list of nodes. A cycle of 1->2->4->3->1 will be+///         stored as [1,2,4,3]+fn get_cycles_rec(+    egraph: &EGraph<Mdl, TensorAnalysis>,+    eclass: Id,+    path_to_here: Vec<(Id, Mdl)>,+    id_to_class: &HashMap<Id, &EClass<Mdl, ValTnsr>>,+    paths_from_root: &mut HashMap<Id, Vec<(Id, Mdl)>>,+    cycles: &mut Vec<Vec<Mdl>>,+) {+    assert!(!paths_from_root.contains_key(&eclass));+    paths_from_root.insert(eclass, path_to_here.clone());++    let class = id_to_class.get(&eclass).unwrap();+    for node in class.iter() {+        if egraph.analysis.blacklist_nodes.contains(node) {+            continue;+        }+        for child in node.children().iter() {+            if !paths_from_root.contains_key(child) {+                // Haven't visited, so visit+                let mut path_to_child = path_to_here.clone();
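The cycle-removal scheme in the diff above (DFS from the root recording one path per class, collect cycles, blacklist the latest-added node in each cycle, repeat until acyclic) can be sketched on a plain adjacency-list graph. This is an illustrative sketch, not the egg-based code under review: `find_cycles` and `remove_cycles_by_order` are hypothetical names, and plain nodes stand in for the enodes that the real code blacklists.

```python
def find_cycles(graph, root, blacklist):
    # DFS from root, keeping one root-to-node path per visited node;
    # an edge back into the current path closes a cycle.
    cycles = []
    path_to = {root: [root]}

    def dfs(node, path):
        for child in graph.get(node, []):
            if child in blacklist:
                continue  # treat blacklisted nodes as deleted
            if child not in path_to:
                path_to[child] = path + [child]
                dfs(child, path + [child])
            elif child in path:
                # Cycle: the segment of the current path starting at `child`
                cycles.append(path[path.index(child):])

    dfs(root, [root])
    return cycles

def remove_cycles_by_order(graph, root, added_order):
    # Repeat: find cycles, blacklist the latest-added node of each, until acyclic.
    blacklist = set()
    while True:
        cycles = find_cycles(graph, root, blacklist)
        if not cycles:
            return blacklist
        for cycle in cycles:
            if not any(n in blacklist for n in cycle):
                # Pick the node added latest; pre-existing nodes default to -1,
                # mirroring resolve_cycle's map_or(-1, ...) so they are never removed.
                blacklist.add(max(cycle, key=lambda n: added_order.get(n, -1)))
```

For example, in the graph `a -> b -> c -> a` where `c` was added last, the loop blacklists `c` and terminates, leaving the original `a -> b` edges intact.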

Put a TODO for later in the interest of time

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

+from __future__ import print_function
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import json
+import scipy
+import scipy.stats
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Analysis script, get the statistics we want')
+    parser.add_argument('--mode', type=str, default='runtime',
+        help='Mode of analysis')
+    parser.add_argument('--file', type=str, default='data.txt',
+        help='File for the input data to analyze')
+
+    return parser.parse_args()
+
+# Plot speedup and optimizer time together in a same bar plot
+def plot_runtime_and_speed(benchmark_name, result):
+    width = 0.8
+    x_locs = [0, 1, 2, 3, 6, 7]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
+
+    fig, ax1 = plt.subplots()
+    ax1.bar(x_locs[0], result['orig_runtime'], width=width, label='Original', color=colors[0])
+    ax1.bar(x_locs[1], result['taso'], width=width, label='TASO', color=colors[1])
+    ax1.bar(x_locs[2], result['greedy'], width=width, label='Greedy', color=colors[2])
+    ax1.bar(x_locs[3], result['ilp'], width=width, label='ILP', color=colors[3])
+
+    runtimes = [result['orig_runtime'], result['taso'], result['greedy'], result['ilp']]
+
+    ax1.set_ylabel('Graph runtime (milliseconds)')
+
+    ax2 = ax1.twinx()
+    ax2.bar(x_locs[4], result['taso_total_time'], width=width, label='TASO total', color=colors[4])
+    ax2.bar(x_locs[4], result['taso_best_time'], width=width, label='TASO best', color=colors[5])
+    ax2.bar(x_locs[5], result['ilp_time'], width=width, label='Sat.+ILP', color=colors[6])
+
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+
+    ax2.set_ylabel('Optimizer time (seconds)')
+    ax1.set_xlabel(benchmark_name)
+
+    lines, labels = ax1.get_legend_handles_labels()
+    lines2, labels2 = ax2.get_legend_handles_labels()
+    ax2.legend(lines + lines2, labels + labels2, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+
+    plt.savefig("{}.png".format(benchmark_name), bbox_inches='tight')
+    plt.close()
+
+def plot_runtime_and_speed_2(benchmark_name, result):
+
+    width = 0.8
+    x_locs = [0, 1, 2, 3]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c']
+
+    fig, ax1 = plt.subplots()
+    ax1.bar(x_locs[0], result['orig_runtime'], width=width, label='Original', color=colors[0])
+    ax1.bar(x_locs[1], result['taso'], width=width, label='TASO', color=colors[1])
+    ax1.bar(x_locs[2], result['greedy'], width=width, label='Sat.+Greedy', color=colors[2])
+    ax1.bar(x_locs[3], result['ilp'], width=width, label='Sat.+ILP', color=colors[3])
+
+    runtimes = [result['orig_runtime'], result['taso'], result['greedy'], result['ilp']]
+
+    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+    ax1.set_ylabel('Graph runtime (milliseconds)')
+    ax1.set_xlabel(benchmark_name)
+
+
+    plt.savefig("{}_runtime.png".format(benchmark_name), bbox_inches='tight')
+    plt.close()
+
+    x_locs = [0, 1]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c']
+
+    fig, ax2 = plt.subplots()
+    ax2.bar(x_locs[0], result['taso_total_time'], width=width, label='TASO total', color=colors[0])
+    ax2.bar(x_locs[0], result['taso_best_time'], width=width, label='TASO best', color=colors[1])
+    ax2.bar(x_locs[1], result['ilp_time'], width=width, label='Sat.+ILP', color=colors[2])
+
+    ax2.set_ylabel('Optimizer time (seconds)')
+    ax2.set_xlabel(benchmark_name)
+    fig = plt.gcf()
+    fig.set_size_inches(3, 5)
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+
+    ax2.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+
+    plt.savefig("{}_optimizer.png".format(benchmark_name), bbox_inches='tight')
+    plt.close()
+
+def runtime_stats(args):
+    with open(args.file, 'r') as f:
+        content = f.readlines()
+
+    start_times = []
+    ext_times = []
+    for line in content:
+        times = line.split('\t')
+        start_times.append(float(times[0]))
+        ext_times.append(float(times[1]))
+
+    start_mean = np.mean(start_times)
+    start_std = np.std(start_times)
+    ext_mean = np.mean(ext_times)
+    ext_std = np.std(ext_times)
+    print("Start graph runtime: mean {}, std {}".format(start_mean, start_std))
+    print("Extracted graph runtime: mean {}, std {}".format(ext_mean, ext_std))
+
+def plot_bars(args):
+    # Results for the spotlight talk was manually input, since we don't have the pipeline to store results then
+    results = {
+        "bert": {
+            "orig_runtime": 1.8964,
+            "taso": 1.7415,
+            "greedy": 1.8903,
+            "ilp": 1.7410,
+            "taso_total_time": 13.98,
+            "taso_best_time": 3.410,
+            "ilp_time": 3.022,
+        },
+        "nasrnn": {
+            "orig_runtime": 1.8601,
+            "taso": 1.2890,
+            "greedy": 1.1446,
+            "ilp": 1.1106,
+            "taso_total_time": 175.4,
+            "taso_best_time": 121.1,
+            "ilp_time": 28.47,
+        },
+        "resnext50": {
+            "orig_runtime": 6.0775,
+            "taso": 5.8144,
+            "greedy": 5.5850,
+            "ilp": 5.5704,
+            "taso_total_time": 25.00,
+            "taso_best_time": 5.909,
+            "ilp_time": 1.314,
+        }
+    }
+
+    plt.rcParams.update({'font.size': 16})
+
+    for (benchmark, result) in results.items():
+        plot_runtime_and_speed_2(benchmark, result)
+
+
+def speedup_bar(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_1_stats.txt".format(benchmark))
+    taso_runtime_file = os.path.join(taso_root, "examples/{}_time.txt".format(benchmark))
+
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_runtimes = []
+    for res in egg_results[-5:]:
+        egg_runtimes.append(res['optimized'])
+
+    with open(taso_runtime_file, 'r') as f:
+        content = f.readlines()
+
+    orig_runtimes = []
+    taso_runtimes = []
+    for line in content[-5:]:
+        times = line.split('\t')
+        orig_runtimes.append(float(times[0]))
+        taso_runtimes.append(float(times[1]))
+
+    # Get original runtime mean, TASO mean and ste, egg mean and ste
+    orig_mean = np.mean(orig_runtimes)
+    taso_speedup = [(orig_mean/x - 1) * 100 for x in taso_runtimes]
+    egg_speedup = [(orig_mean/x - 1) * 100 for x in egg_runtimes]
+    taso_mean = np.mean(taso_speedup)
+    egg_mean = np.mean(egg_speedup)
+    taso_ste = scipy.stats.sem(taso_speedup)
+    egg_ste = scipy.stats.sem(egg_speedup)
+
+    taso_mean_time = np.mean(taso_runtimes)
+
+    print("{}: orig {} taso {}".format(benchmark, orig_mean, taso_mean_time))
+
+    # Plot bar and save
+    width = 0.8
+    x_locs = [0, 1]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'r']
+
+    fig, ax1 = plt.subplots()
+    ax1.bar(x_locs[0], taso_mean, width=width, yerr=taso_ste, ecolor='m', capsize=2.0, label='TASO', color=colors[0])
+    ax1.bar(x_locs[1], egg_mean, width=width, yerr=egg_ste, ecolor='m', capsize=2.0, label='Sat.+ILP', color=colors[1])
+
+    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=2, fancybox=True, shadow=True, prop={'size': 14})
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+    ax1.set_ylabel('Speed up percentage')
+    ax1.set_xlabel(benchmark)
+
+    fig = plt.gcf()
+    fig.set_size_inches(2, 5)
+
+    plt.savefig("{}_speedup.png".format(benchmark), bbox_inches='tight')
+    plt.close()
+
+def optimizer_time_bar(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_2_stats.txt".format(benchmark))
+    taso_stats_file = os.path.join(taso_root, "examples/{}_stats.txt".format(benchmark))
+
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_times = []
+    egg_sat_times = []
+    egg_ext_times = []
+    for res in egg_results[-5:]:
+        egg_times.append(res['extraction'] + res['saturation'])
+        egg_sat_times.append(res['saturation'])
+        egg_ext_times.append(res['extraction'])
+
+    with open(taso_stats_file, 'r') as f:
+        content = f.readlines()
+
+    taso_totals = []
+    taso_bests = []
+    for line in content[-5:]:
+        elements = line.split(' ')
+        taso_totals.append(float(elements[3][:-1]))
+        taso_bests.append(float(elements[1][:-1]))
+
+    sat_time_mean = np.mean(egg_sat_times)
+    ext_time_mean = np.mean(egg_ext_times)
+
+    print("{}, sat time {}, ext time {}".format(benchmark, sat_time_mean, ext_time_mean))
+
+    width = 0.8
+    x_locs = [0, 1]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c']
+
+    egg_time = np.mean(egg_times)
+    taso_total = np.mean(taso_totals)
+    taso_best = np.mean(taso_bests)
+
+    fig, ax2 = plt.subplots()
+    ax2.bar(x_locs[0], taso_total, width=width, label='TASO total', color=colors[0])
+    ax2.bar(x_locs[0], taso_best, width=width, label='TASO best', color=colors[1])
+    ax2.bar(x_locs[1], egg_time, width=width, label='Sat.+ILP', color=colors[2])
+
+    ax2.set_ylabel('Optimizer time (seconds)')
+    ax2.set_xlabel(benchmark)
+    fig = plt.gcf()
+    fig.set_size_inches(3, 5)
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+
+    #ax2.legend(lines + lines2, labels + labels2, fontsize=10)
+    ax2.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+
+    plt.savefig("{}_optim_time.png".format(benchmark), bbox_inches='tight')
+    plt.close()
+
+
+def equivalent_graphs(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_1_stats.txt".format(benchmark))
+    taso_stats_file = os.path.join(taso_root, "examples/{}_stats.txt".format(benchmark))
+
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_equiv = []
+    for res in egg_results[-5:]:
+        egg_equiv.append(res['programs'])
+
+    with open(taso_stats_file, 'r') as f:
+        content = f.readlines()
+
+    taso_equiv = []
+    for line in content[-5:]:
+        elements = line.split(' ')
+        taso_equiv.append(int(elements[-1])+100)
+
+    egg_mean = np.mean(egg_equiv)
+    taso_mean = np.mean(taso_equiv)
+
+    print("{}: egg (power of 2) {}, taso {}".format(benchmark, egg_mean, taso_mean))
+
+
+def multi_trend(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    taso_runtime_file = os.path.join(taso_root, "examples/{}_time.txt".format(benchmark))
+
+    with open(taso_runtime_file, 'r') as f:
+        content = f.readlines()
+
+    orig_runtimes = []
+    for line in content[-5:]:
+        times = line.split('\t')
+        orig_runtimes.append(float(times[0]))
+    orig_mean = np.mean(orig_runtimes)
+
+    # iter=1
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_1_stats.txt".format(benchmark))
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_runtimes = []
+    egg_sat_times = []
+    egg_ext_times = []
+    egg_n_nodes = []
+    for res in egg_results[-5:]:
+        egg_runtimes.append(res['optimized'])
+        egg_sat_times.append(res['saturation'])
+        egg_ext_times.append(res['extraction'])
+        egg_n_nodes.append(res['nodes'])
+
+    mean_iter_1 = np.mean(egg_runtimes)
+    mean_sat_iter_1 = np.mean(egg_sat_times)
+    mean_ext_iter_1 = np.mean(egg_ext_times)
+    mean_nodes_iter_1 = np.mean(egg_n_nodes)
+
+    # iter=2
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_2_stats.txt".format(benchmark))
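The speedup bars in `speedup_bar` above reduce to the mean original runtime, a per-run speedup percentage `(orig_mean / t - 1) * 100`, and its standard error. A stdlib-only restatement of that statistic (the function name is illustrative; `stdev / sqrt(n)` matches `scipy.stats.sem`'s default sample standard error):

```python
from math import sqrt
from statistics import mean, stdev

def speedup_percent(orig_runtimes, opt_runtimes):
    # Speedup of each optimized run relative to the mean original runtime,
    # in percent, plus the standard error of the mean of those speedups.
    orig_mean = mean(orig_runtimes)
    speedups = [(orig_mean / t - 1) * 100 for t in opt_runtimes]
    sem = stdev(speedups) / sqrt(len(speedups)) if len(speedups) > 1 else 0.0
    return mean(speedups), sem
```

For example, optimized runs of 1.0 ms against original runs averaging 2.0 ms give a 100% mean speedup.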

Done

yycdavid

comment created time in a month

(i, res) in results.iter().enumerate() {+                new_e_ch[i] = res.1;+            }++            let mut nodes_in_pat = HashSet::<Mdl>::new();+            for res in results.iter() {+                for node in res.0.iter() {+                    nodes_in_pat.insert(node.clone());+                }+            }+            nodes_in_pat.insert(new_e.clone());++            // Add to order list+            if !existing_nodes.contains(&new_e) {+                egraph.analysis.newly_added.push(new_e.clone());+            }++            (nodes_in_pat, egraph.lookup(new_e).unwrap())+        }+    }+}++/// Remove cycles in EGraph by adding nodes to blacklist+///+/// This function works by:+///     - Make a pass over egraph to get a set of cycles+///     - For each cycle, pick the node that got added latest and add it to blacklist+///     - Repeat until no cycles are left+fn remove_cycles(+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    added_node_to_order: &HashMap<Mdl, usize>,+    root: Id,+) {+    loop {+        let mut paths_from_root = HashMap::<Id, Vec<(Id, Mdl)>>::new();+        let mut cycles = Vec::<Vec<Mdl>>::new();++        get_cycles(egraph, root, &mut paths_from_root, &mut cycles);++        if cycles.len() == 0 {+            break;+        }+        for cycle in cycles.iter() {+            resolve_cycle(egraph, cycle, added_node_to_order);+        }+    }+}++/// Resolve cycle by adding node to blacklist+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `cycle`: list of nodes within the cycle+/// - `added_node_to_order`: HashMap, map from node to the order that it was added into the egraph+fn resolve_cycle(+    egraph: &mut EGraph<Mdl, TensorAnalysis>,+    cycle: &[Mdl],+    added_node_to_order: &HashMap<Mdl, usize>,+) {+    // Check if any node in cycle is already in blacklist+    let already_solved = cycle+        .iter()+        .any(|node| egraph.analysis.blacklist_nodes.contains(node));+    if !already_solved {+        
assert!(cycle.len() > 0);+        let (ord, n) = cycle+            .iter()+            .map(|node| {+                let order = added_node_to_order+                    .get(node)+                    .map_or(-1, |index| *index as i32);+                (order, node.clone())+            })+            .max_by_key(|(o, _)| *o)+            .unwrap();+        assert!(ord >= 0);+        egraph.analysis.blacklist_nodes.insert(n.clone());+    }+}++/// Traverse the EGraph and get a set of cycles (reachable from root)+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `root`: Id of root eclass+/// - `paths_from_root`: HashMap storing for each eclass, one path from the root eclass. The path+///         is represented by a list of (eclass, enode) that composes the path+/// - `cycles`: list of cycles. Each cycle is a list of nodes. A cycle of 1->2->4->3->1 will be+///         stored as [1,2,4,3]+fn get_cycles(+    egraph: &EGraph<Mdl, TensorAnalysis>,+    root: Id,+    paths_from_root: &mut HashMap<Id, Vec<(Id, Mdl)>>,+    cycles: &mut Vec<Vec<Mdl>>,+) {+    // Get a map from Id to the eclass objects, since egg doesn't provide accessing eclass from Id+    let id_to_class: HashMap<Id, &EClass<Mdl, ValTnsr>> =+        egraph.classes().map(|class| (class.id, class)).collect();++    get_cycles_rec(+        egraph,+        root,+        /*path_to_here=*/ Vec::<(Id, Mdl)>::new(),+        &id_to_class,+        paths_from_root,+        cycles,+    );+}++/// Traverse the EGraph in DFS order, update paths_from_root and cycles on the fly+///+/// # Parameters+///+/// - `egraph`: E-graph of interest+/// - `eclass`: The current eclass that we are visiting+/// - `path_to_here`: A path from root to this eclass+/// - `id_to_class`: Map from eclass ID to the eclass objects+/// - `paths_from_root`: HashMap storing for each eclass, one path from the root eclass. The path+///         is represented by a list of (eclass, enode) that composes the path+/// - `cycles`: list of cycles. 
Each cycle is a list of nodes. A cycle of 1->2->4->3->1 will be+///         stored as [1,2,4,3]+fn get_cycles_rec(+    egraph: &EGraph<Mdl, TensorAnalysis>,+    eclass: Id,+    path_to_here: Vec<(Id, Mdl)>,+    id_to_class: &HashMap<Id, &EClass<Mdl, ValTnsr>>,+    paths_from_root: &mut HashMap<Id, Vec<(Id, Mdl)>>,+    cycles: &mut Vec<Vec<Mdl>>,+) {+    assert!(!paths_from_root.contains_key(&eclass));+    paths_from_root.insert(eclass, path_to_here.clone());++    let class = id_to_class.get(&eclass).unwrap();+    for node in class.iter() {+        if egraph.analysis.blacklist_nodes.contains(node) {+            continue;+        }+        for child in node.children().iter() {+            if !paths_from_root.contains_key(child) {+                // Haven't visited, so visit+                let mut path_to_child = path_to_here.clone();+                path_to_child.push((eclass, node.clone()));

I think so, this node is a reference to the node in the EGraph

yycdavid

comment created time in a month
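The cycle-removal loop quoted in the diff above (collect cycles reachable from the root, blacklist the latest-added element on each cycle, repeat until none remain) can be sketched on a toy directed graph. This is a minimal illustration, not the project's code: node ids are plain `usize`, and the sketch blacklists the latest-added *edge* on each cycle, standing in for the e-node blacklist maintained by `resolve_cycle`. All names (`Edge`, `find_cycle`, `remove_cycles_sketch`) are hypothetical.

```rust
use std::collections::{HashMap, HashSet};

// (from, to, order in which the edge was added). Illustrative type only.
type Edge = (usize, usize, usize);

/// DFS from `node`, keeping the current path from the root; returns the
/// first cycle found that avoids blacklisted edges (mirrors the path
/// tracking in get_cycles_rec).
fn find_cycle(
    adj: &HashMap<usize, Vec<(usize, usize)>>, // node -> [(child, order)]
    blacklist: &HashSet<Edge>,
    node: usize,
    on_path: &mut Vec<Edge>,
) -> Option<Vec<Edge>> {
    for &(child, ord) in adj.get(&node).into_iter().flatten() {
        let edge = (node, child, ord);
        if blacklist.contains(&edge) {
            continue;
        }
        // If `child` already starts an edge on the current path, the suffix
        // of the path from that point, plus this edge, forms a cycle.
        if let Some(pos) = on_path.iter().position(|&(from, _, _)| from == child) {
            let mut cycle = on_path[pos..].to_vec();
            cycle.push(edge);
            return Some(cycle);
        }
        on_path.push(edge);
        if let Some(cycle) = find_cycle(adj, blacklist, child, on_path) {
            return Some(cycle);
        }
        on_path.pop();
    }
    None
}

/// Repeatedly blacklist the latest-added edge of some cycle until the graph
/// reachable from `root` is acyclic (the resolve_cycle tie-breaking rule).
fn remove_cycles_sketch(
    adj: &HashMap<usize, Vec<(usize, usize)>>,
    root: usize,
) -> HashSet<Edge> {
    let mut blacklist = HashSet::new();
    while let Some(cycle) = find_cycle(adj, &blacklist, root, &mut Vec::new()) {
        let latest = *cycle.iter().max_by_key(|&&(_, _, ord)| ord).unwrap();
        blacklist.insert(latest);
    }
    blacklist
}

fn main() {
    // 0 -> 1 -> 2 -> 0, where the back edge 2 -> 0 was added last (order 3).
    let mut adj: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
    adj.insert(0, vec![(1, 1)]);
    adj.insert(1, vec![(2, 2)]);
    adj.insert(2, vec![(0, 3)]);
    let blacklist = remove_cycles_sketch(&adj, 0);
    // Only the latest-added edge on the cycle is blacklisted.
    let expected: HashSet<Edge> = [(2, 0, 3)].into_iter().collect();
    assert_eq!(blacklist, expected);
    println!("blacklisted edges: {:?}", blacklist);
}
```

Blacklisting the latest-added element is a heuristic: earlier-added nodes came from the original graph or earlier rewrites, so discarding the newest addition on each cycle removes as little of the explored search space as possible.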

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

+fn get_cycles_rec(
+    egraph: &EGraph<Mdl, TensorAnalysis>,
+    eclass: Id,
+    path_to_here: Vec<(Id, Mdl)>,
+    id_to_class: &HashMap<Id, &EClass<Mdl, ValTnsr>>,
+    paths_from_root: &mut HashMap<Id, Vec<(Id, Mdl)>>,

Changed to HashSet

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

+/// Remove cycles in EGraph by adding nodes to blacklist
+///
+/// This function works by:
+///     - Make a pass over egraph to get a set of cycles
+///     - For each cycle, pick the node that got added latest and add it to blacklist
+///     - Repeat until no cycles are left
+fn remove_cycles(

Renamed

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

+                            // Add the newly added nodes to the ordering list
+                            if self.filter_after {
+                                let existing_1 = existing_1.unwrap();
                  let existing_2 = existing_2.unwrap();+                                let (nodes_in_1, _) = add_newly_added(+                                    rule.2.ast.as_ref(),+                                    &mut runner.egraph,+                                    &merged_subst,+                                    &existing_1,+                                );+                                let existing_2_updated: HashSet<Mdl> = existing_2+                                    .iter()+                                    .chain(nodes_in_1.iter())+                                    .map(|node| node.clone())+                                    .collect();+                                add_newly_added(+                                    rule.3.ast.as_ref(),+                                    &mut runner.egraph,+                                    &merged_subst,+                                    &existing_2_updated,+                                );+                            }++                            runner.egraph.union(id_1, match_1.eclass);                             runner.egraph.union(id_2, match_2.eclass);+                        } else {

Done

yycdavid

comment created time in a month
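The cycle check in the hunk above can be illustrated with a small standalone sketch (Python, hypothetical names; not the tamago code): given a precomputed descendants map over a DAG of e-classes, adding an edge `src -> dst` creates a cycle exactly when `src` is already reachable from `dst`.

```python
def descendants(graph, root, memo):
    """All nodes reachable from root. graph: dict node -> list of children (a DAG)."""
    if root in memo:
        return memo[root]
    result = set()
    for child in graph.get(root, []):
        result.add(child)
        result |= descendants(graph, child, memo)
    memo[root] = result
    return result

def would_create_cycle(graph, src, dst, memo):
    """Adding an edge src -> dst creates a cycle iff src is already
    reachable from dst (or src == dst)."""
    return src == dst or src in descendants(graph, dst, memo)
```

The memoized descendants map plays the role of the pre-collected descendents info used by the pre-filtering path.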


Pull request review comment: yycdavid/tamago

Dev branch

 impl MultiPatterns {
                 if compatible(&subst_1_dec, &subst_2_dec, &map_1.var_map) {
                     // If so, merge two substitutions
                     let merged_subst = merge_subst(subst_1_dec, subst_2_dec, &map_1.var_map);
+                    // Check if any source pattern contains blacklisted nodes
+                    if self.filter_after {
+                        if contains_blacklist(
+                            rule.0.ast.as_ref(),
+                            &mut runner.egraph,
+                            &merged_subst,
+                        )
+                        .0 || contains_blacklist(
+                            rule.1.ast.as_ref(),
+                            &mut runner.egraph,
+                            &merged_subst,
+                        )
+                        .0
+                        {
+                            continue;
+                        }
+                    }
                     // check_pat on both dst patterns
-                    if check_pat(rule.2.ast.as_ref(), &mut runner.egraph, &merged_subst).0 {
-                        if check_pat(rule.3.ast.as_ref(), &mut runner.egraph, &merged_subst).0 {
+                    let (valid_1, _, _, existing_1) = check_pat(
+                        rule.2.ast.as_ref(),
+                        &mut runner.egraph,
+                        &merged_subst,
+                        /*get_exist_nodes=*/ self.filter_after,
+                    );
+                    let (valid_2, _, _, existing_2) = check_pat(
+                        rule.3.ast.as_ref(),
+                        &mut runner.egraph,
+                        &merged_subst,
+                        /*get_exist_nodes=*/ self.filter_after,
+                    );
+                    if valid_1 && valid_2 {
+                        let cycle_check_passed = if self.no_cycle {
+                            if self.filter_after {

Yes, I do pre-check before applying a rule. Here filter_after refers to the fact that we have a post-processing step that adds nodes to the filtered list.

yycdavid

comment created time in a month
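A minimal sketch of the filter_after flow described above (hypothetical names, not the actual tamago API): a match touching any blacklisted node is skipped up front, and nodes produced by an applied rule are recorded so the post-processing step can filter them later.

```python
def try_apply(match_nodes, produced_nodes, blacklist, newly_added):
    """Pre-check: skip this match if it touches a blacklisted node.
    Post-step: record the nodes the rule produced, so a later filtering
    pass (the post-processing step mentioned above) can see them."""
    if blacklist & match_nodes:   # pre-check against the blacklist
        return False
    for n in produced_nodes:      # record newly added nodes, in order
        if n not in newly_added:
            newly_added.append(n)
    return True
```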


Pull request review comment: yycdavid/tamago

Dev branch

 impl MultiPatterns {
                 if compatible(&subst_1_dec, &subst_2_dec, &map_1.var_map) {
                     // If so, merge two substitutions
                     let merged_subst = merge_subst(subst_1_dec, subst_2_dec, &map_1.var_map);
+                    // Check if any source pattern contains blacklisted nodes
+                    if self.filter_after {
+                        if contains_blacklist(

Done

yycdavid

comment created time in a month


Pull request review comment: yycdavid/tamago

Dev branch

 impl Applier<Mdl, TensorAnalysis> for CheckApply {
     }
 }
 
+/// Check if the matched graph of the pattern contains any blacklisted nodes
+///
+/// # Returns
+///
+/// A tuple of (bool, Option<Id>) where
+///
+/// - bool: true if the nodes in this pattern contains some node in blacklist
+/// - Option<Id>: if the nodes in this pattern do not contain blacklisted
+///     nodes, then this is the Id of the matched EClass of the root of this pattern (pat.last())
+fn contains_blacklist(
+    pat: &[ENodeOrVar<Mdl>],
+    egraph: &mut EGraph<Mdl, TensorAnalysis>,
+    subst: &Subst,
+) -> (bool, Option<Id>) {
+    match pat.last().unwrap() {
+        ENodeOrVar::Var(w) => (false, Some(subst[*w])),
+        ENodeOrVar::ENode(e) => {
+            let children = e.children();
+            let results: Vec<(bool, Option<Id>)> = children
+                .iter()
+                .map(|child| contains_blacklist(&pat[..usize::from(*child) + 1], egraph, subst))
+                .collect();
+
+            let mut contains = false;
+            for res in &results {
+                if res.0 {
+                    contains = true;
+                }
+            }
+
+            if contains {
+                (true, None)
+            } else {
+                let mut new_e = e.clone();
+                let new_e_ch = new_e.children_mut();
+                for (i, res) in results.iter().enumerate() {
+                    if let Some(id) = res.1 {
+                        new_e_ch[i] = id;
+                    } else {
+                        // This placed shouldn't be reached in any case. The pat and subst passed

Done

yycdavid

comment created time in a month


Pull request review comment: yycdavid/tamago

Dev branch

 use egg::*;
 use root::taso::*;
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::convert::TryInto;
 use std::time::{Duration, Instant};
 
-/// Custom struct implementing our cost function
+/// Wrapper class for egg's cost function
 pub struct TensorCost<'a> {
     pub egraph: &'a EGraph<Mdl, TensorAnalysis>,
+    pub cost_model: &'a CostModel,
 }
 
 impl CostFunction<Mdl> for TensorCost<'_> {
     type Cost = f32;
     /// Getting total cost for the subtree rooted at enode. See egg::CostFunction
     /// trait for more information on interface.
     fn cost<C: FnMut(Id) -> Self::Cost>(&mut self, enode: &Mdl, mut costs: C) -> Self::Cost {
-        let self_cost = get_self_cost(self.egraph, enode);
+        let self_cost = self.cost_model.get_self_cost(self.egraph, enode);
         enode.fold(self_cost, |sum, id| sum + costs(id))
     }
 }
 
-/// Gets cost for the enode itself.
-///
-/// This function gets the cost by calling TASO's get_or_create_{some_op}()
-/// functions with the tensor information stored in metadata. TASO side stores
-/// hashmaps for OpBase objects. So here TASO side will simply lookup previously
-/// created ops (with previously measured runtime).
-///
-/// # Parameters
-///
-/// - `egraph`: E-graph of interest
-/// - `enode`: enode to get cost for
-///
-/// # Returns
-///
-/// Cost for this enode.
-fn get_self_cost(egraph: &EGraph<Mdl, TensorAnalysis>, enode: &Mdl) -> f32 {
-    let x = |i: &Id| &egraph[*i].data;
-    let mut g = egraph.analysis.graph.borrow_mut();
-    match enode {
-        Mdl::Num(_)
-        | Mdl::Var(_)
-        | Mdl::Input(_)
-        | Mdl::Weight(_)
-        | Mdl::Merge(_)
-        | Mdl::Split0(_)
-        | Mdl::Split1(_)
-        | Mdl::Reshape(_)
-        | Mdl::Transpose(_) => 0.0,
-
-        Mdl::Relu(_a) => {
-            // Check types
-            let a_t_data = x(_a);
-            assert!(a_t_data.dtype == DataKind::Tnsr);
-
-            unsafe {
-                // Get op
-                let op = (*g.model).get_or_create_activation(*a_t_data.meta, OpType_OP_RELU, true);
-                assert!(op != Op_INVALID_OP);
-                (*op.ptr).runtime.clone()
-            }
+/// Class for our cost model
+pub struct CostModel {
+    /// To have zero cost for all weight op only
+    ignore_all_weight_only: bool,
+    /// Discount factor for all weight ops
+    all_weight_discount: f32,
+}
+
+impl CostModel {
+    pub fn with_setting(ignore_all_weight_only: bool) -> Self {
+        CostModel {
+            ignore_all_weight_only: ignore_all_weight_only,
+            all_weight_discount: 1.0,
         }
+    }
 
-        Mdl::Tanh(_a) => {
-            // Check types
-            let a_t_data = x(_a);
-            assert!(a_t_data.dtype == DataKind::Tnsr);
+    /// Gets cost for the enode itself.
+    ///
+    /// This function gets the cost by calling TASO's get_or_create_{some_op}()
+    /// functions with the tensor information stored in metadata. TASO side stores
+    /// hashmaps for OpBase objects. So here TASO side will simply lookup previously
+    /// created ops (with previously measured runtime).
+    ///
+    /// # Parameters
+    ///
+    /// - `egraph`: E-graph of interest
+    /// - `enode`: enode to get cost for
+    ///
+    /// # Returns
+    ///
+    /// Cost for this enode.
+    pub fn get_self_cost(&self, egraph: &EGraph<Mdl, TensorAnalysis>, enode: &Mdl) -> f32 {
+        let x = |i: &Id| &egraph[*i].data;
+        let mut g = egraph.analysis.graph.borrow_mut();
+        match enode {
+            Mdl::Num(_)
+            | Mdl::Var(_)
+            | Mdl::Input(_)
+            | Mdl::Weight(_)
+            | Mdl::Merge(_)
+            | Mdl::Split0(_)
+            | Mdl::Split1(_)
+            | Mdl::Reshape(_)
+            | Mdl::Transpose(_)
+            | Mdl::Noop(_) => 0.0,
+
+            Mdl::Relu(_a) => {
+                // Check types
+                let a_t_data = x(_a);
+                assert!(a_t_data.dtype == DataKind::Tnsr);
+
+                let runtime = unsafe {
+                    // Get op
+                    let op =
+                        (*g.model).get_or_create_activation(*a_t_data.meta, OpType_OP_RELU, true);
+                    assert!(op != Op_INVALID_OP);
+                    (*op.ptr).runtime.clone()
+                };
+
+                if self.ignore_all_weight_only && x(_a).all_weights {
+                    self.all_weight_discount * runtime

Currently it's just 1.0. We can add a command-line flag for it if we decide to experiment with this.

yycdavid

comment created time in a month
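For reference, a hedged sketch of the discount branch in the diff above (a Python rendering with assumed parameter names, not the Rust code itself): ops whose inputs are all weights can be discounted, and with all_weight_discount at 1.0 the branch is currently a no-op.

```python
def op_cost(runtime, all_weights, ignore_all_weight_only, discount=1.0):
    """Sketch of CostModel::get_self_cost's discount logic.

    runtime: measured op runtime; all_weights: whether every input of
    the op is a weight; discount: all_weight_discount (currently 1.0)."""
    if ignore_all_weight_only and all_weights:
        return discount * runtime  # discounted cost for all-weight ops
    return runtime
```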


Pull request review comment: yycdavid/tamago

Dev branch

+from __future__ import print_function
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import json
+import scipy
+import scipy.stats
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Analysis script, get the statistics we want')
+    parser.add_argument('--mode', type=str, default='runtime',
+        help='Mode of analysis')
+    parser.add_argument('--file', type=str, default='data.txt',
+        help='File for the input data to analyze')
+
+    return parser.parse_args()
+
+# Plot speedup and optimizer time together in a same bar plot
+def plot_runtime_and_speed(benchmark_name, result):
+    width = 0.8
+    x_locs = [0, 1, 2, 3, 6, 7]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
+
+    fig, ax1 = plt.subplots()
+    ax1.bar(x_locs[0], result['orig_runtime'], width=width, label='Original', color=colors[0])
+    ax1.bar(x_locs[1], result['taso'], width=width, label='TASO', color=colors[1])
+    ax1.bar(x_locs[2], result['greedy'], width=width, label='Greedy', color=colors[2])
+    ax1.bar(x_locs[3], result['ilp'], width=width, label='ILP', color=colors[3])
+
+    runtimes = [result['orig_runtime'], result['taso'], result['greedy'], result['ilp']]
+
+    ax1.set_ylabel('Graph runtime (milliseconds)')
+
+    ax2 = ax1.twinx()
+    ax2.bar(x_locs[4], result['taso_total_time'], width=width, label='TASO total', color=colors[4])
+    ax2.bar(x_locs[4], result['taso_best_time'], width=width, label='TASO best', color=colors[5])
+    ax2.bar(x_locs[5], result['ilp_time'], width=width, label='Sat.+ILP', color=colors[6])
+
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+
+    ax2.set_ylabel('Optimizer time (seconds)')
+    ax1.set_xlabel(benchmark_name)
+
+    lines, labels = ax1.get_legend_handles_labels()
+    lines2, labels2 = ax2.get_legend_handles_labels()
+    ax2.legend(lines + lines2, labels + labels2, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+
+    plt.savefig("{}.png".format(benchmark_name), bbox_inches='tight')
+    plt.close()
+
+def plot_runtime_and_speed_2(benchmark_name, result):
+
+    width = 0.8
+    x_locs = [0, 1, 2, 3]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c']
+
+    fig, ax1 = plt.subplots()
+    ax1.bar(x_locs[0], result['orig_runtime'], width=width, label='Original', color=colors[0])
+    ax1.bar(x_locs[1], result['taso'], width=width, label='TASO', color=colors[1])
+    ax1.bar(x_locs[2], result['greedy'], width=width, label='Sat.+Greedy', color=colors[2])
+    ax1.bar(x_locs[3], result['ilp'], width=width, label='Sat.+ILP', color=colors[3])
+
+    runtimes = [result['orig_runtime'], result['taso'], result['greedy'], result['ilp']]
+
+    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+    ax1.set_ylabel('Graph runtime (milliseconds)')
+    ax1.set_xlabel(benchmark_name)
+
+
+    plt.savefig("{}_runtime.png".format(benchmark_name), bbox_inches='tight')
+    plt.close()
+
+    x_locs = [0, 1]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c']
+
+    fig, ax2 = plt.subplots()
+    ax2.bar(x_locs[0], result['taso_total_time'], width=width, label='TASO total', color=colors[0])
+    ax2.bar(x_locs[0], result['taso_best_time'], width=width, label='TASO best', color=colors[1])
+    ax2.bar(x_locs[1], result['ilp_time'], width=width, label='Sat.+ILP', color=colors[2])
+
+    ax2.set_ylabel('Optimizer time (seconds)')
+    ax2.set_xlabel(benchmark_name)
+    fig = plt.gcf()
+    fig.set_size_inches(3, 5)
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+
+    ax2.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+
+    plt.savefig("{}_optimizer.png".format(benchmark_name), bbox_inches='tight')
+    plt.close()
+
+def runtime_stats(args):
+    with open(args.file, 'r') as f:
+        content = f.readlines()
+
+    start_times = []
+    ext_times = []
+    for line in content:
+        times = line.split('\t')
+        start_times.append(float(times[0]))
+        ext_times.append(float(times[1]))
+
+    start_mean = np.mean(start_times)
+    start_std = np.std(start_times)
+    ext_mean = np.mean(ext_times)
+    ext_std = np.std(ext_times)
+    print("Start graph runtime: mean {}, std {}".format(start_mean, start_std))
+    print("Extracted graph runtime: mean {}, std {}".format(ext_mean, ext_std))
+
+def plot_bars(args):
+    # Results for the spotlight talk was manually input, since we don't have the pipeline to store results then
+    results = {
+        "bert": {
+            "orig_runtime": 1.8964,
+            "taso": 1.7415,
+            "greedy": 1.8903,
+            "ilp": 1.7410,
+            "taso_total_time": 13.98,
+            "taso_best_time": 3.410,
+            "ilp_time": 3.022,
+        },
+        "nasrnn": {
+            "orig_runtime": 1.8601,
+            "taso": 1.2890,
+            "greedy": 1.1446,
+            "ilp": 1.1106,
+            "taso_total_time": 175.4,
+            "taso_best_time": 121.1,
+            "ilp_time": 28.47,
+        },
+        "resnext50": {
+            "orig_runtime": 6.0775,
+            "taso": 5.8144,
+            "greedy": 5.5850,
+            "ilp": 5.5704,
+            "taso_total_time": 25.00,
+            "taso_best_time": 5.909,
+            "ilp_time": 1.314,
+        }
+    }
+
+    plt.rcParams.update({'font.size': 16})
+
+    for (benchmark, result) in results.items():
+        plot_runtime_and_speed_2(benchmark, result)
+
+
+def speedup_bar(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_1_stats.txt".format(benchmark))
+    taso_runtime_file = os.path.join(taso_root, "examples/{}_time.txt".format(benchmark))
+
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_runtimes = []
+    for res in egg_results[-5:]:
+        egg_runtimes.append(res['optimized'])
+
+    with open(taso_runtime_file, 'r') as f:
+        content = f.readlines()
+
+    orig_runtimes = []
+    taso_runtimes = []
+    for line in content[-5:]:
+        times = line.split('\t')
+        orig_runtimes.append(float(times[0]))
+        taso_runtimes.append(float(times[1]))
+
+    # Get original runtime mean, TASO mean and ste, egg mean and ste
+    orig_mean = np.mean(orig_runtimes)
+    taso_speedup = [(orig_mean/x - 1) * 100 for x in taso_runtimes]
+    egg_speedup = [(orig_mean/x - 1) * 100 for x in egg_runtimes]
+    taso_mean = np.mean(taso_speedup)
+    egg_mean = np.mean(egg_speedup)
+    taso_ste = scipy.stats.sem(taso_speedup)
+    egg_ste = scipy.stats.sem(egg_speedup)
+
+    taso_mean_time = np.mean(taso_runtimes)
+
+    print("{}: orig {} taso {}".format(benchmark, orig_mean, taso_mean_time))
+
+    # Plot bar and save
+    width = 0.8
+    x_locs = [0, 1]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'r']
+
+    fig, ax1 = plt.subplots()
+    ax1.bar(x_locs[0], taso_mean, width=width, yerr=taso_ste, ecolor='m', capsize=2.0, label='TASO', color=colors[0])
+    ax1.bar(x_locs[1], egg_mean, width=width, yerr=egg_ste, ecolor='m', capsize=2.0, label='Sat.+ILP', color=colors[1])
+
+    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=2, fancybox=True, shadow=True, prop={'size': 14})
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+    ax1.set_ylabel('Speed up percentage')
+    ax1.set_xlabel(benchmark)
+
+    fig = plt.gcf()
+    fig.set_size_inches(2, 5)
+
+    plt.savefig("{}_speedup.png".format(benchmark), bbox_inches='tight')
+    plt.close()
+
+def optimizer_time_bar(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_2_stats.txt".format(benchmark))
+    taso_stats_file = os.path.join(taso_root, "examples/{}_stats.txt".format(benchmark))
+
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_times = []
+    egg_sat_times = []
+    egg_ext_times = []
+    for res in egg_results[-5:]:
+        egg_times.append(res['extraction'] + res['saturation'])
+        egg_sat_times.append(res['saturation'])
+        egg_ext_times.append(res['extraction'])
+
+    with open(taso_stats_file, 'r') as f:
+        content = f.readlines()
+
+    taso_totals = []
+    taso_bests = []
+    for line in content[-5:]:
+        elements = line.split(' ')
+        taso_totals.append(float(elements[3][:-1]))
+        taso_bests.append(float(elements[1][:-1]))
+
+    sat_time_mean = np.mean(egg_sat_times)
+    ext_time_mean = np.mean(egg_ext_times)
+
+    print("{}, sat time {}, ext time {}".format(benchmark, sat_time_mean, ext_time_mean))
+
+    width = 0.8
+    x_locs = [0, 1]
+    x_locs = [a + width/2 for a in x_locs]
+    colors = ['b', 'g', 'r', 'c']
+
+    egg_time = np.mean(egg_times)
+    taso_total = np.mean(taso_totals)
+    taso_best = np.mean(taso_bests)
+
+    fig, ax2 = plt.subplots()
+    ax2.bar(x_locs[0], taso_total, width=width, label='TASO total', color=colors[0])
+    ax2.bar(x_locs[0], taso_best, width=width, label='TASO best', color=colors[1])
+    ax2.bar(x_locs[1], egg_time, width=width, label='Sat.+ILP', color=colors[2])
+
+    ax2.set_ylabel('Optimizer time (seconds)')
+    ax2.set_xlabel(benchmark)
+    fig = plt.gcf()
+    fig.set_size_inches(3, 5)
+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])
+
+    #ax2.legend(lines + lines2, labels + labels2, fontsize=10)
+    ax2.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})
+
+    plt.savefig("{}_optim_time.png".format(benchmark), bbox_inches='tight')
+    plt.close()
+
+
+def equivalent_graphs(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_1_stats.txt".format(benchmark))
+    taso_stats_file = os.path.join(taso_root, "examples/{}_stats.txt".format(benchmark))
+
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_equiv = []
+    for res in egg_results[-5:]:
+        egg_equiv.append(res['programs'])
+
+    with open(taso_stats_file, 'r') as f:
+        content = f.readlines()
+
+    taso_equiv = []
+    for line in content[-5:]:
+        elements = line.split(' ')
+        taso_equiv.append(int(elements[-1])+100)
+
+    egg_mean = np.mean(egg_equiv)
+    taso_mean = np.mean(taso_equiv)
+
+    print("{}: egg (power of 2) {}, taso {}".format(benchmark, egg_mean, taso_mean))
+
+
+def multi_trend(benchmark):
+    # Read in results
+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")
+
+    taso_runtime_file = os.path.join(taso_root, "examples/{}_time.txt".format(benchmark))
+
+    with open(taso_runtime_file, 'r') as f:
+        content = f.readlines()
+
+    orig_runtimes = []
+    for line in content[-5:]:
+        times = line.split('\t')
+        orig_runtimes.append(float(times[0]))
+    orig_mean = np.mean(orig_runtimes)
+
+    # iter=1
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_1_stats.txt".format(benchmark))
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_runtimes = []
+    egg_sat_times = []
+    egg_ext_times = []
+    egg_n_nodes = []
+    for res in egg_results[-5:]:
+        egg_runtimes.append(res['optimized'])
+        egg_sat_times.append(res['saturation'])
+        egg_ext_times.append(res['extraction'])
+        egg_n_nodes.append(res['nodes'])
+
+    mean_iter_1 = np.mean(egg_runtimes)
+    mean_sat_iter_1 = np.mean(egg_sat_times)
+    mean_ext_iter_1 = np.mean(egg_ext_times)
+    mean_nodes_iter_1 = np.mean(egg_n_nodes)
+
+    # iter=2
+    egg_stats_file = os.path.join(tamago_root, "tmp/{}_2_stats.txt".format(benchmark))
+    with open(egg_stats_file, 'r') as egg_f:
+        egg_results = egg_f.readlines()
+
+    egg_results = [json.loads(x) for x in egg_results]
+    egg_runtimes = []
+    egg_sat_times = []
+    egg_ext_times = []
+    egg_n_nodes = []
+    for res in egg_results[-5:]:
+        egg_runtimes.append(res['optimized'])
+        egg_sat_times.append(res['saturation'])
+        egg_ext_times.append(res['extraction'])
+        egg_n_nodes.append(res['nodes'])
+
+    mean_iter_2 = np.mean(egg_runtimes)
+    mean_sat_iter_2 = np.mean(egg_sat_times)
+    mean_ext_iter_2 = np.mean(egg_ext_times)
+    mean_nodes_iter_2 = np.mean(egg_n_nodes)
+
+    # iter=3
+    if benchmark == 'resnext50':
+        egg_stats_file = os.path.join(tamago_root, "tmp/{}_3_stats.txt".format(benchmark))
+        with open(egg_stats_file, 'r') as egg_f:
+            egg_results = egg_f.readlines()
+
+        egg_results = [json.loads(x) for x in egg_results]
+        egg_runtimes = []
+        egg_sat_times = []
+        egg_ext_times = []
+        egg_n_nodes = []
+        for res in egg_results[-5:]:
+            egg_runtimes.append(res['optimized'])
+            egg_sat_times.append(res['saturation'])
+            egg_ext_times.append(res['extraction'])
+            egg_n_nodes.append(res['nodes'])
+
+        mean_iter_3 = np.mean(egg_runtimes)
+        mean_sat_iter_3 = np.mean(egg_sat_times)
+        mean_ext_iter_3 = np.mean(egg_ext_times)
+        mean_nodes_iter_3 = np.mean(egg_n_nodes)
+
+    # The number of nodes for these three in iter 3 is manually recorded, since the ILP solver
+    # times out, and the results are not saved in files
+    elif benchmark == 'bert':
+        mean_iter_3 = -1
+        mean_nodes_iter_3 = 842044
+
+    elif benchmark == 'nasrnn':
+        mean_iter_3 = -1
+        mean_nodes_iter_3 = 10177140
+
+    elif benchmark == 'nasneta':
+        mean_iter_3 = -1
+        mean_nodes_iter_3 = 11114360
+
+    # Plot runtime & optimizer time v.s. iter
+    n_iter = [1,2,3]
+    speedup = [orig_mean/mean_iter_1, orig_mean/mean_iter_2]
+    optimizer_time = [mean_sat_iter_1+mean_ext_iter_1, mean_sat_iter_2+mean_ext_iter_2]
+    if mean_iter_3 > 0:
+        speedup.append(orig_mean/mean_iter_3)
+        optimizer_time.append(mean_sat_iter_3+mean_ext_iter_3)
+
+    speedup = [(i-1)*100 for i in speedup]
+
+    fig = plt.figure()
+    ax1 = fig.add_subplot(111)
+    color = 'tab:red'
+    ax1.set_xlabel('#iter of multi pattern rewrites')
+    ax1.set_ylabel('Speedup percentage', color=color)
+    lns1 = ax1.plot(n_iter[:len(speedup)], speedup, marker='s', color=color, label='Speedup')
+
+    plt.xticks(n_iter, ['{}'.format(i) for i in n_iter])
+
+    ax2 = ax1.twinx()
+
+    color = 'tab:blue'
+    ax2.set_ylabel('Optimizer time (seconds)', color=color)
+    lns2 = ax2.plot(n_iter[:len(speedup)], optimizer_time, marker='^', color=color, label='Optimizer time')
+
+    if len(speedup) < 3:
+        ax2.scatter(n_iter[-1], 3600, marker='x', color='b')
+
+    lns = lns1+lns2
+    labs = [l.get_label() for l in lns]
+    ax1.legend(lns, labs, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=2, fancybox=True, shadow=True)
+
+    plt.savefig("{}_trend_time.png".format(benchmark), bbox_inches='tight')
+    plt.close()
+
+    # Plot nodes v.s. iter
+    n_iter = [1,2,3]
+    nodes = [mean_nodes_iter_1, mean_nodes_iter_2, mean_nodes_iter_3]
+
+    fig = plt.figure()
+    ax1 = fig.add_subplot(111)
+    color = 'tab:green'
+    ax1.set_xlabel('#iter of multi pattern rewrites')
+    ax1.set_ylabel('#enodes', color=color)
+    lns1 = ax1.plot(n_iter, nodes, marker='s', color=color)
+
+    plt.xticks(n_iter, ['{}'.format(i) for i in n_iter])
+    plt.savefig("{}_trend_nodes.png".format(benchmark), bbox_inches='tight')
+    plt.close()
+
+
+def plot_speedup(args):
+    plt.rcParams.update({'font.size': 18})
+    for benchmark in ['nasrnn', 'bert', 'resnext50', 'nasneta']:

Done

yycdavid

comment created time in a month


Pull request review comment: yycdavid/tamago

Dev branch

egg_f:+        egg_results = egg_f.readlines()++    egg_results = [json.loads(x) for x in egg_results]+    egg_runtimes = []+    for res in egg_results[-5:]:+        egg_runtimes.append(res['optimized'])++    with open(taso_runtime_file, 'r') as f:+        content = f.readlines()++    orig_runtimes = []+    taso_runtimes = []+    for line in content[-5:]:+        times = line.split('\t')+        orig_runtimes.append(float(times[0]))+        taso_runtimes.append(float(times[1]))++    # Get original runtime mean, TASO mean and ste, egg mean and ste+    orig_mean = np.mean(orig_runtimes)+    taso_speedup = [(orig_mean/x - 1) * 100 for x in taso_runtimes]+    egg_speedup = [(orig_mean/x - 1) * 100 for x in egg_runtimes]+    taso_mean = np.mean(taso_speedup)+    egg_mean = np.mean(egg_speedup)+    taso_ste = scipy.stats.sem(taso_speedup)+    egg_ste = scipy.stats.sem(egg_speedup)++    taso_mean_time = np.mean(taso_runtimes)++    print("{}: orig {} taso {}".format(benchmark, orig_mean, taso_mean_time))++    # Plot bar and save+    width = 0.8+    x_locs = [0, 1]+    x_locs = [a + width/2 for a in x_locs]+    colors = ['b', 'r']++    fig, ax1 = plt.subplots()+    ax1.bar(x_locs[0], taso_mean, width=width, yerr=taso_ste, ecolor='m', capsize=2.0, label='TASO', color=colors[0])+    ax1.bar(x_locs[1], egg_mean, width=width, yerr=egg_ste, ecolor='m', capsize=2.0, label='Sat.+ILP', color=colors[1])++    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=2, fancybox=True, shadow=True, prop={'size': 14})+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])+    ax1.set_ylabel('Speed up percentage')+    ax1.set_xlabel(benchmark)++    fig = plt.gcf()+    fig.set_size_inches(2, 5)++    plt.savefig("{}_speedup.png".format(benchmark), bbox_inches='tight')+    plt.close()++def optimizer_time_bar(benchmark):+    # Read in results+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))+    taso_root = os.path.join(os.path.dirname(tamago_root), 
"TASO")++    egg_stats_file = os.path.join(tamago_root, "tmp/{}_2_stats.txt".format(benchmark))+    taso_stats_file = os.path.join(taso_root, "examples/{}_stats.txt".format(benchmark))++    with open(egg_stats_file, 'r') as egg_f:+        egg_results = egg_f.readlines()++    egg_results = [json.loads(x) for x in egg_results]+    egg_times = []+    egg_sat_times = []+    egg_ext_times = []+    for res in egg_results[-5:]:+        egg_times.append(res['extraction'] + res['saturation'])+        egg_sat_times.append(res['saturation'])+        egg_ext_times.append(res['extraction'])++    with open(taso_stats_file, 'r') as f:+        content = f.readlines()++    taso_totals = []+    taso_bests = []+    for line in content[-5:]:+        elements = line.split(' ')+        taso_totals.append(float(elements[3][:-1]))+        taso_bests.append(float(elements[1][:-1]))++    sat_time_mean = np.mean(egg_sat_times)+    ext_time_mean = np.mean(egg_ext_times)++    print("{}, sat time {}, ext time {}".format(benchmark, sat_time_mean, ext_time_mean))++    width = 0.8+    x_locs = [0, 1]+    x_locs = [a + width/2 for a in x_locs]+    colors = ['b', 'g', 'r', 'c']++    egg_time = np.mean(egg_times)+    taso_total = np.mean(taso_totals)+    taso_best = np.mean(taso_bests)++    fig, ax2 = plt.subplots()+    ax2.bar(x_locs[0], taso_total, width=width, label='TASO total', color=colors[0])+    ax2.bar(x_locs[0], taso_best, width=width, label='TASO best', color=colors[1])+    ax2.bar(x_locs[1], egg_time, width=width, label='Sat.+ILP', color=colors[2])++    ax2.set_ylabel('Optimizer time (seconds)')+    ax2.set_xlabel(benchmark)+    fig = plt.gcf()+    fig.set_size_inches(3, 5)+    plt.xticks(x_locs, ['' for _ in range(len(x_locs))])++    #ax2.legend(lines + lines2, labels + labels2, fontsize=10)+    ax2.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True, shadow=True, prop={'size': 12})++    plt.savefig("{}_optim_time.png".format(benchmark), bbox_inches='tight')+    
plt.close()+    ++def equivalent_graphs(benchmark):+    # Read in results+    tamago_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))+    taso_root = os.path.join(os.path.dirname(tamago_root), "TASO")++    egg_stats_file = os.path.join(tamago_root, "tmp/{}_1_stats.txt".format(benchmark))+    taso_stats_file = os.path.join(taso_root, "examples/{}_stats.txt".format(benchmark))++    with open(egg_stats_file, 'r') as egg_f:+        egg_results = egg_f.readlines()++    egg_results = [json.loads(x) for x in egg_results]+    egg_equiv = []+    for res in egg_results[-5:]:+        egg_equiv.append(res['programs'])++    with open(taso_stats_file, 'r') as f:+        content = f.readlines()++    taso_equiv = []+    for line in content[-5:]:+        elements = line.split(' ')+        taso_equiv.append(int(elements[-1])+100)++    egg_mean = np.mean(egg_equiv)+    taso_mean = np.mean(taso_equiv)++    print("{}: egg (power of 2) {}, taso {}".format(benchmark, egg_mean, taso_mean))+++def multi_trend(benchmark):

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

    # Get original runtime mean, TASO mean and ste, egg mean and ste
    orig_mean = np.mean(orig_runtimes)
    taso_speedup = [(orig_mean/x - 1) * 100 for x in taso_runtimes]

This is the speedup percentage: if the original graph runs in 2.0 sec and the optimized graph runs in 1.0 sec, that is a 100% speedup.
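The formula from the reviewed script, `(orig_mean / x - 1) * 100`, can be checked with a minimal stdlib sketch; `speedup_percent` is a hypothetical helper name, not a function in the script itself.

```python
from statistics import mean

def speedup_percent(orig_runtimes, opt_runtimes):
    """Percentage speedup of each optimized run relative to the mean
    original runtime, matching (orig_mean / x - 1) * 100 from the
    reviewed script: halving the runtime counts as a 100% speedup."""
    orig_mean = mean(orig_runtimes)
    return [(orig_mean / x - 1) * 100 for x in opt_runtimes]

# Original graph runs in 2.0 sec, optimized in 1.0 sec -> 100% speedup.
print(speedup_percent([2.0, 2.0], [1.0]))  # [100.0]
```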

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

    egg_results = [json.loads(x) for x in egg_results]
    egg_runtimes = []
    for res in egg_results[-5:]:

Just to get stats with the same number of runs for each benchmark; I may have run some benchmarks more times than others.
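The `[-5:]` slice above keeps only the most recent runs so every benchmark contributes the same sample size. A small stdlib sketch of that pattern, including the standard error used for the error bars (`summarize_last_runs` is an illustrative name; the SEM formula matches `scipy.stats.sem`'s default of sample stdev over sqrt(n)):

```python
from math import sqrt
from statistics import mean, stdev

def summarize_last_runs(results, n=5):
    """Mean and standard error over only the last n runs.

    Slicing with results[-n:] equalizes the sample size across
    benchmarks even when some were re-run more often."""
    runs = results[-n:]
    return mean(runs), stdev(runs) / sqrt(len(runs))

# An early outlier run (10.0) is dropped by the slice.
print(summarize_last_runs([10.0, 1.0, 1.0, 1.0, 1.0, 1.0]))  # (1.0, 0.0)
```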

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

def plot_runtime_and_speed_2(benchmark_name, result):

Removed

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

def plot_bars(args):
    # Results for the spotlight talk was manually input, since we don't have the pipeline to store results then

Removed

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

from __future__ import print_function

Done

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

num_passes=5
node_limit=50000
iter_limit=15
time_limit=3000
iter_multi=1
ilp_time_sec=3600

for pass in $(seq 0 $(expr $num_passes - 1))
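The loop header quoted above combines `seq` with `expr` for the upper bound; a minimal POSIX sketch of the same pass loop, where only `num_passes` is carried over from the experiment script and the loop body is a placeholder:

```shell
#!/bin/sh
num_passes=5

# $(expr $num_passes - 1) works, but POSIX arithmetic expansion
# $((num_passes - 1)) does the subtraction without a subprocess.
for pass in $(seq 0 $((num_passes - 1)))
do
    echo "running pass $pass"
done
```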

Changed

yycdavid

comment created time in a month

PullRequestReviewEvent

Pull request review comment yycdavid/tamago

Dev branch

 def main():
                 else:
                     solver.Add(t[g[i]] - t[m] - epsilon + A * (1 - x[i]) >= 0)

+    # Blacklist constraints
+    for j in blacklist_i:
+        solver.Add(x[j] == 0)
+
     # Define objective
     obj_expr = [costs[j] * x[j] for j in range(num_nodes)]
     solver.Minimize(sum(obj_expr))

+    # Set initial solutions
+    if args.initialize:
+        print("Initialize with greedy")
+        with open('./tmp/init_sol.json') as f:
+            sol_data = json.load(f)
+
+        i_list = sol_data['i_list']

'i' refers to the index in x_i
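The effect of the blacklist constraint `x[j] == 0` — a blacklisted node index can never appear in the extracted graph — can be illustrated without an ILP solver. This pure-Python brute force is only an analogue of the extraction step; the names (`costs`, `required_classes`, `blacklist`) and the one-node-per-class data layout are illustrative assumptions, not the script's actual formulation.

```python
from itertools import product

def min_cost_selection(costs, required_classes, blacklist):
    """Pick one node index per equivalence class, minimizing total cost,
    with blacklisted indices forced out (the analogue of x_i == 0)."""
    best = None
    for choice in product(*required_classes):
        if any(i in blacklist for i in choice):
            continue  # violates x_i == 0 for a blacklisted i
        cost = sum(costs[i] for i in choice)
        if best is None or cost < best[0]:
            best = (cost, choice)
    return best

# Two classes; node 0 is cheapest for the first class but blacklisted,
# so the selection falls back to node 1.
costs = [1.0, 2.0, 3.0]
classes = [[0, 1], [2]]
print(min_cost_selection(costs, classes, blacklist={0}))  # (5.0, (1, 2))
```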

yycdavid

comment created time in a month

PullRequestReviewEvent

push eventyycdavid/tamago

yycdavid

commit sha 653e4269cdc2061c231218ea1e3ee45f33cd07b4

Update Readme and docker

view details

push time in 2 months

push eventyycdavid/egg

yycdavid

commit sha 12cc1ee7731d37fe91901c81f59678fa1d08a2bb

Expose costs in Extractor

view details

push time in 2 months

PR opened yycdavid/tamago

Reviewers
Dev branch

Change list:

  • Initialize ILP with greedy solution
  • Noop to combine outputs
  • Cycle filtering (naive, and efficient)
  • Correct cost model (zero cost for all-weight ops)
  • Get stats and plot results
  • Add nasneta, and new ops involved
  • Shape inference in input interface
+2886 -504

0 comment

19 changed files

pr created time in 2 months

push eventyycdavid/tamago

yycdavid

commit sha e0686709bae1ce09f23fe7893c448b9c272c6673

Comments for stats script

view details

yycdavid

commit sha 7620eb82d2a50855ddad0c166bf396658b94dfff

Experiment scripts

view details

yycdavid

commit sha 88981bb3f5692259ca8a686addec337877f0ebc1

Add issues and records

view details

yycdavid

commit sha 21d689ac0194ced162e9d76f17ba7e0f6d805372

Documentation for rewrite

view details

yycdavid

commit sha 7e776fedd060980f01c4249d1624fcc0c9c6595d

Documentation for optimize

view details

yycdavid

commit sha bf8dd37ca2d4780d781fdea2ed7d90a1f7a68df5

Docs of main, formating

view details

yycdavid

commit sha 3e3c7e427b42fe87f1a43697494ce7e98797cab4

Update record

view details

push time in 2 months

push eventyycdavid/tamago

yycdavid

commit sha 03a92e336ea6b62f358495357e1670a5ce250e14

Update main

view details

yycdavid

commit sha 21ac2d0ff3b033c3091d59fd04f70fb545dfd498

Add shape inference to input interface

view details

yycdavid

commit sha 8bcf97e232f202ac057b81e1d1385ba0e1f070fe

Add nasneta, avgpool2d

view details

yycdavid

commit sha 0656d131f16e29ca8863c5c31e1a0ff0a08a4a78

Concat multi

view details

yycdavid

commit sha 7a55519ba57d3868dfbfe9f7434bb440a58bae76

Add flag for node limit

view details

yycdavid

commit sha 8862e9a493064b8204b4b429d72533244f8a0a9b

Output stats to file

view details

yycdavid

commit sha e5a81fade3544ac35f96a8181e7f7dbb9ac0d8a9

Number of equivalent programs

view details

yycdavid

commit sha 8fae494d6a1ab4e5e8dc00591a3ab088b05a2a47

Running script

view details

yycdavid

commit sha 0098d96f4c2e7491a0b0f954f47de5ae08b67093

Ignore unfound nodes in contain blacklist check

view details

yycdavid

commit sha 3d7ec3c06f0394961675d7fc6af32ce5531257aa

Scripts to get stats

view details

yycdavid

commit sha e67ee7b42ae548de22a8281cba098c50080bab02

Script to plot trend

view details

push time in 2 months

push eventyycdavid/tamago

yycdavid

commit sha 6a85bf4613f05a6a7ba561f57615b1790ffbba40

Record order in single rules

view details

yycdavid

commit sha 1e7fe9fa178e823691df490315c9ee8839fae9a4

Filter cycle for single rules, and have a global added order

view details

yycdavid

commit sha 83e8de48f6edaac26c3a94ea128c961efbbcdbc0

Pass blacklist nodes to ILP

view details

push time in 2 months

push eventyycdavid/tamago

yycdavid

commit sha 58ff0c296b29460624540526c4940d8eab134f75

Bar plots on stats

view details

yycdavid

commit sha a569039e82d9cda46f542e11ae7f55bd53a6d4de

Update stats plot

view details

yycdavid

commit sha d7a24749e89e522a4ed0df9fc379497a365f4bb2

Pre-filter, add blacklist

view details

yycdavid

commit sha f240bd4937af49835437f71c837380ff9446c13e

Check blacklist before apply

view details

yycdavid

commit sha 2563c0717cd3481bedf43c6dfe44cf5c019994f3

Check blacklist in multi

view details

yycdavid

commit sha 904020ea442aeec9903faee342b2eb71b6e77049

Check blacklist in get descendents

view details

yycdavid

commit sha 97d57ab80de60a0d7ca44e71b33c6abfaf72bb0f

Get existing nodes before apply

view details

yycdavid

commit sha 556e4610f49fa673495b1b779e1f7020ea000eb6

Record order of nodes added

view details

yycdavid

commit sha 7260ce34f31518582c7fd2ff6a34c79a78632d71

Update added nodes at the end of run one

view details

yycdavid

commit sha 1282be627a9cc8989a3e0d09680b0225813fa09e

Get cycles

view details

yycdavid

commit sha 349c84f04c11fe668b5c8d194935602ce7926aba

Resolve cycles by adding to blacklist

view details

push time in 2 months

push event yycdavid/tamago

yycdavid

commit sha bda00b776bfcb1cb41cc0c24e0409e5a140e70bf

Add flag for filter cycle and multi-pattern iterations

view details

yycdavid

commit sha 08e3d8a3cc5c82da0d347f6bf46c74493c7b341b

Filter cycle in multi rules

view details

yycdavid

commit sha 63ece8d1c2c861cb25cf801546f8b4e757aa818e

Apply half for symmetric multi pattern rules

view details

yycdavid

commit sha bb86f457f40b50ee575ccdf619759fa023ff2648

Only zero cost for all weight concat

view details

yycdavid

commit sha 9c767e0f10bfa71d67bf7df5980e51777d57183a

Refactor cost model code structure

view details

yycdavid

commit sha 05ecf37aeb9b9fda2f42fefe9c056abaf0c13137

Discount cost for all weight ops

view details

push time in 2 months

MemberEvent

push event yycdavid/tamago

Yichen Yang

commit sha 22a09bb986a62e3d78693bb00d39b2617fde94b5

Merge pull request #6 from yycdavid/dev ILP extraction

view details

yycdavid

commit sha 625f4a3581f85b61cf1f7f5bccc9bc4f3e934273

Compute stats for runtime

view details

yycdavid

commit sha 329cd6a4d2165216c1e7ec47d7478b5346474949

Save greedy solution for ILP

view details

yycdavid

commit sha 5bcfb8731e876c09871d402ab62e8aee49516978

Initialize ILP with greedy solution

view details

yycdavid

commit sha 07cccb3dad99e713808b8dbcf74153f8bdfce2ea

Add flag to set number of threads and time limit for ILP

view details

yycdavid

commit sha f2bfeb0271005ed21b18d2c59abf84dbe9930611

Add noop to combine outputs

view details

yycdavid

commit sha f806a4e35b404c4cc46449011a5d00378304e0c6

Fix bert combine output

view details

yycdavid

commit sha fcf21ce4995f6fa3e0d1993681ac3a096fcf0f3a

Update gitignore

view details

push time in 2 months

push event yycdavid/tamago

yycdavid

commit sha 86a2617d54b19949bb854f3eff244282f5bb1ac2

Support transpose and reshape

view details

yycdavid

commit sha aa08529ff0462f4ba5d6c80a2271dfa57ba4127d

Add bert

view details

yycdavid

commit sha 5f9d35adab6764edf41d1982700e7a73868ac2e7

Limit iteration to apply of multi-patterns

view details

yycdavid

commit sha 3bd36ab742ffe26c648f425a09a8f8bb97895e82

Get statistics of the egraph after saturation phase

view details

yycdavid

commit sha b7444c56aaa0437922fa92483ad25ef6326124ed

New docker env for or-tools and SCIP

view details

yycdavid

commit sha 7ab940654a28015f0630eb13dc64d4b3d21b5ca0

Save ILP info to json

view details

yycdavid

commit sha 78a6f9bc58c54a50fb806d88f52a8d98d4ad965a

Create ILP and solve

view details

yycdavid

commit sha 0c77cd2bbd8add527954e193f5cb9e9ecaa200ef

Read solved results and construct optimized graph

view details

yycdavid

commit sha 07332c43129cca94a36e871376f62541164eb020

Add option to use int variable for topological order

view details

yycdavid

commit sha e61a9dab15198bf9f741a7b9623f27055d35a6db

Add option to add sum to 1 constraint for each class

view details

yycdavid

commit sha d8b80b21dbbd3e71ca4f1348c0baf6f80a4d1910

Add option to have no ordering constraints in ILP

view details

yycdavid

commit sha 5067685920a3e27914d5dbe55c0cd0bce399b3c2

Add documentations, format

view details

yycdavid

commit sha b93cffe5327d2108f711be1e0c2625ccaba2ba18

Update

view details

Yichen Yang

commit sha 22a09bb986a62e3d78693bb00d39b2617fde94b5

Merge pull request #6 from yycdavid/dev ILP extraction

view details

push time in 2 months

PR merged yycdavid/tamago

Reviewers
ILP extraction

Hi all, please review this pull request. Thanks! Change list:

  • Support transpose and reshape
  • Add bert
  • ILP extraction related
+769 -54

0 comments

13 changed files

yycdavid

pr closed time in 2 months

push event yycdavid/tamago

yycdavid

commit sha b93cffe5327d2108f711be1e0c2625ccaba2ba18

Update

view details

push time in 2 months

Pull request review comment yycdavid/tamago

ILP extraction

 impl MultiPatterns {
     /// it checks and applies the dst patterns. It won't apply if src_1 and src_2 matches with
     /// the same eclass. It always returns Ok()
     pub fn run_one(&self, runner: &mut Runner<Mdl, TensorAnalysis, ()>) -> Result<(), String> {
-        // Construct Vec to store matches for each canonicalized pattern
-        let matches: Vec<Vec<SearchMatches>> = self
-            .canonical_src_pat
-            .iter()
-            .map(|x| x.search(&runner.egraph))
-            .collect();
-
-        // For each multi rule
-        for (i, rule) in self.rules.iter().enumerate() {
-            let map_1 = &self.src_pat_maps[i].0;
-            let map_2 = &self.src_pat_maps[i].1;
-            let matches_1 = &matches[map_1.index];
-            let matches_2 = &matches[map_2.index];
-            for match_1 in matches_1 {
-                for match_2 in matches_2 {
-                    if match_1.eclass == match_2.eclass {
-                        // We don't want to apply multi-pattern rules on the same eclass
-                        continue;
+        if runner.iterations.len() < 1 {

This is where I limit when the multi-pattern rules are applied. In the current version, the multi-pattern rules are applied only in the first iteration. I can change this to an argument flag later.

yycdavid

comment created time in 2 months
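The gating described in this comment can be sketched without egg itself. Below is a minimal stand-alone simulation (all names are hypothetical, not the actual tamago code): in the real runner, an egg hook runs at the start of every equality-saturation iteration and applies the multi-pattern pass via `MultiPatterns::run_one` only while `runner.iterations` is still short; `MAX_MULTI_ITER` stands in for the argument flag mentioned above.

```rust
// Hypothetical stand-in for the later argument flag: how many iterations
// may apply the multi-pattern pass.
const MAX_MULTI_ITER: usize = 1;

// Simulates a saturation run of `total_iters` iterations and returns how
// many times the multi-pattern pass fired.
fn run_saturation(total_iters: usize) -> usize {
    let mut multi_applications = 0;
    let mut iterations: Vec<usize> = Vec::new(); // stand-in for runner.iterations
    for i in 0..total_iters {
        // The hook's check, mirroring `runner.iterations.len() < 1`:
        if iterations.len() < MAX_MULTI_ITER {
            multi_applications += 1; // MultiPatterns::run_one would go here
        }
        // ... single-pattern rewrites for this iteration would run here ...
        iterations.push(i); // egg records one entry per completed iteration
    }
    multi_applications
}

fn main() {
    // With the default of 1, the multi-pattern pass fires only in iteration 0.
    assert_eq!(run_saturation(5), 1);
}
```

Because `iterations` only grows after each completed iteration, the check is equivalent to "apply during the first `MAX_MULTI_ITER` iterations".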

Pull request review comment yycdavid/tamago

ILP extraction

 fn get_self_cost(egraph: &EGraph<Mdl, TensorAnalysis>, enode: &Mdl) -> f32 {
         }
     }
 }
+
+/// Prepare the data for formulating the ILP
+///
+/// # Returns
+///
+/// - `m_id_map`: list of EClass Id's each index m refers to
+/// - `e_m`: each entry is the list of nodes i within eclass m
+/// - `h_i`: each entry is the list of children EClass indices for node i
+/// - `cost_i`: self cost for each node i
+/// - `g_i`: which EClass index does node i belong to
+/// - `root_m`: EClass index of the root eclass
+/// - `i_to_nodes`: Vector of enodes, ordered by index i
+pub fn prep_ilp_data(
+    egraph: &EGraph<Mdl, TensorAnalysis>,
+    root: Id,
+) -> (
+    Vec<Id>,
+    Vec<Vec<usize>>,
+    Vec<Vec<usize>>,
+    Vec<f32>,
+    Vec<usize>,
+    usize,
+    Vec<Mdl>,
+) {
+    let m_id_map: Vec<Id> = egraph.classes().map(|c| egraph.find(c.id)).collect();
+    assert!(m_id_map.len() == egraph.number_of_classes());
+    let id_m_map: HashMap<Id, usize> = m_id_map
+        .iter()
+        .enumerate()
+        .map(|(i, id)| (*id, i))
+        .collect();
+
+    let num_classes = egraph.number_of_classes();
+    let num_nodes = egraph.total_size();
+    let mut i_to_nodes: Vec<Mdl> = Vec::with_capacity(num_nodes);
+    let mut e_m: Vec<Vec<usize>> = vec![Vec::new(); num_classes];
+    let mut h_i: Vec<Vec<usize>> = Vec::with_capacity(num_nodes);
+    let mut cost_i: Vec<f32> = Vec::with_capacity(num_nodes);
+    let mut g_i: Vec<usize> = Vec::with_capacity(num_nodes);
+
+    let mut i = 0;
+    for class in egraph.classes() {
+        let m = *id_m_map.get(&egraph.find(class.id)).unwrap();
+        for node in class.iter() {
+            i_to_nodes.push(node.clone());
+            e_m[m].push(i);
+            h_i.push(
+                node.children()
+                    .iter()
+                    .map(|id| *id_m_map.get(&egraph.find(*id)).unwrap())
+                    .collect(),
+            );
+            cost_i.push(get_self_cost(egraph, node));
+            g_i.push(m);
+            i += 1;
+        }
+    }
+
+    let root_m = *id_m_map.get(&egraph.find(root)).unwrap();
+
+    (m_id_map, e_m, h_i, cost_i, g_i, root_m, i_to_nodes)
+}
+
+/// Struct for storing the solved results from ILP
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SolvedResults {
+    /// The solved values for the variables associated with each node
+    pub solved_x: Vec<i32>,
+    /// The minimum total cost found
+    pub cost: f32,
+}
+
+/// Construct the RecExpr of the optimized graph extracted
+///
+/// This function does the construction recursively with memoization. Calling it with eclass=root
+/// will construct the whole extracted graph
+///
+/// # Parameters
+///
+/// - `node_picked`: hashmap storing which node is picked for each EClass ID
+/// - `expr`: the RecExpr storing the optimized graph, it is constructed within this function
+/// - `eclass`: The EClass ID that we aim to construct as root
+/// - `added_memo`: Map from EClass ID to RecExpr ID. Storing the eclasses that were already added
+/// - `egraph`: E-graph of interest
+///
+/// # Returns
+///
+/// - The ID in the RecExpr for the eclass passed in as argument

The output RecExpr ('expr') contains the whole extracted graph. This function returns the ID into 'expr' (essentially the index, since a RecExpr is internally stored as an array) that corresponds to the 'eclass' passed in as an argument. So calling it with eclass=root returns the ID of the root in 'expr'. I updated the comments.

yycdavid

comment created time in 2 months
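The memoized construction described in this comment can be sketched with simplified, hypothetical types: here an e-node is just an operator name plus child e-class ids, and the "RecExpr" is a flat Vec whose entries only reference earlier indices. This is an illustration of the recursion, not the actual tamago code:

```rust
use std::collections::HashMap;

type EClass = usize;
// Hypothetical e-node: operator name plus the e-class ids of its children.
type ENode = (&'static str, Vec<EClass>);

// Recursively materialize `eclass` into `expr`, memoizing via `added_memo`
// so each e-class is added exactly once. Returns the index of `eclass`
// within `expr` (the "RecExpr Id").
fn construct(
    node_picked: &HashMap<EClass, ENode>,
    expr: &mut Vec<(&'static str, Vec<usize>)>,
    eclass: EClass,
    added_memo: &mut HashMap<EClass, usize>,
) -> usize {
    if let Some(&id) = added_memo.get(&eclass) {
        return id; // this e-class was already materialized in expr
    }
    let (op, children) = node_picked[&eclass].clone();
    // Add children first so their expr indices exist before the parent.
    let child_ids: Vec<usize> = children
        .iter()
        .map(|&c| construct(node_picked, expr, c, added_memo))
        .collect();
    expr.push((op, child_ids));
    let id = expr.len() - 1;
    added_memo.insert(eclass, id);
    id
}

fn main() {
    // (add (mul x x) x): e-class 0 = x, 1 = mul, 2 = add (root).
    let mut picked = HashMap::new();
    picked.insert(0, ("x", vec![]));
    picked.insert(1, ("mul", vec![0, 0]));
    picked.insert(2, ("add", vec![1, 0]));
    let mut expr = Vec::new();
    let mut memo = HashMap::new();
    let root_id = construct(&picked, &mut expr, 2, &mut memo);
    assert_eq!(root_id, 2);    // root lands last in the flat array
    assert_eq!(expr.len(), 3); // "x" is added once thanks to memoization
}
```

The memo is what turns a shared e-class (here, the `x` used by both `mul` and `add`) into a single entry rather than a duplicated subtree.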

Pull request review comment yycdavid/tamago

ILP extraction


Great to know!

yycdavid

comment created time in 2 months

Pull request review comment yycdavid/tamago

ILP extraction


Done

yycdavid

comment created time in 2 months
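prep_ilp_data (reviewed in this PR) flattens the e-graph into index vectors, and the ILP then picks one node per needed e-class at minimum total self cost. As a hedged illustration with hypothetical names, a tiny brute-force selector over the same `e_m`/`h_i`/`cost_i` vectors shows the selection constraints the ILP enforces, ignoring the cycle/ordering constraints of the real formulation:

```rust
// Brute-force the ILP objective for a toy e-graph: choose 0/1 per e-node
// (bit i of `mask`) so that the root e-class has a picked node and every
// picked node's child e-classes do too, minimizing total self cost.
fn min_extraction_cost(
    e_m: &[Vec<usize>],  // nodes i contained in each e-class m
    h_i: &[Vec<usize>],  // child e-class indices of each node i
    cost_i: &[f32],      // self cost of each node i
    root_m: usize,       // root e-class index
) -> f32 {
    let n = cost_i.len(); // toy sizes only: 2^n candidate selections
    let mut best = f32::INFINITY;
    'outer: for mask in 0u32..(1u32 << n) {
        let picked = |i: usize| (mask >> i) & 1 == 1;
        let class_covered = |m: usize| e_m[m].iter().any(|&i| picked(i));
        if !class_covered(root_m) {
            continue; // root e-class must have a picked node
        }
        for i in 0..n {
            if picked(i) && !h_i[i].iter().all(|&m| class_covered(m)) {
                continue 'outer; // a picked node is missing a child e-class
            }
        }
        let total: f32 = (0..n).filter(|&i| picked(i)).map(|i| cost_i[i]).sum();
        best = best.min(total);
    }
    best
}

fn main() {
    // Toy e-graph: e-class 1 holds leaf node 2 (cost 1.0); root e-class 0
    // holds nodes 0 (cost 5.0) and 1 (cost 2.0), each with class 1 as child.
    let e_m = vec![vec![0, 1], vec![2]];
    let h_i = vec![vec![1], vec![1], vec![]];
    let cost_i = vec![5.0, 2.0, 1.0];
    // Best selection: node 1 plus leaf node 2, total cost 3.0.
    assert_eq!(min_extraction_cost(&e_m, &h_i, &cost_i, 0), 3.0);
}
```

The real formulation additionally uses topological-order variables to rule out cyclic selections between e-classes, which this coverage-only brute force sidesteps.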

Pull request review comment yycdavid/tamago

ILP extraction

 fn test(matches: clap::ArgMatches) {
         }
     };
 
-    let runner_start = Runner::<Mdl, TensorAnalysis, ()>::default().with_expr(&start);
-    runner_start
-        .egraph
-        .dot()
-        .to_svg("target/start.svg")
-        .unwrap();
-    let time_start = get_full_graph_runtime(&runner_start);
-    println!("Start graph runtime: {}", time_start);
+    // Get multi-pattern rules. learned_rules are the learned rules from TASO,
+    // pre_defined_multi are the hand-specified rules from TASO
+    let multi_patterns = if let Some(rule_file) = matches.value_of("multi_rules") {
+        let learned_rules =
+            read_to_string(rule_file).expect("Something went wrong reading the rule file");
+        let pre_defined_multi = PRE_DEFINED_MULTI.iter().map(|&x| x);
+        let multi_rules: Vec<&str> = learned_rules.split("\n").chain(pre_defined_multi).collect();
+        MultiPatterns::with_rules(multi_rules)
+    } else {
+        let multi_rules: Vec<&str> = PRE_DEFINED_MULTI.iter().map(|&x| x).collect();
+        MultiPatterns::with_rules(multi_rules)
+    };
+
+    // Run saturation
+    let n_sec = matches
+        .value_of("n_sec")
+        .map_or(10, |s| s.parse::<u64>().unwrap());
+    let time_limit_sec = Duration::new(n_sec, 0);
+    let iter_limit = matches
+        .value_of("n_iter")
+        .map_or(1, |s| s.parse::<usize>().unwrap());
+
+    let runner = if use_multi {
+        // This hook function (which applies the multi-pattern rules) will be called at the
+        // beginning of each iteration in equality saturation
+        Runner::<Mdl, TensorAnalysis, ()>::default()
+            .with_node_limit(100000)

Yes, right now I am not relying on this limit, so I just set it to a large number.

yycdavid

comment created time in 2 months
