1 /**
2  * Authors: Tomoya Tanjo
3  * Copyright: © 2021 Tomoya Tanjo
4  * License: Apache-2.0
5  */
6 module salad.util;
7 
8 import dyaml : Node;
9 
10 import salad.meta : idMap;
11 
/**
 * Looks up a single mapping key in `node`.
 *
 * Shorthand for the key-path overload called with the one-element path `[key]`.
 *
 * Params:
 *   node     = node to search
 *   key      = mapping key to look up
 *   default_ = fallback used when `key` is absent
 * Returns: whatever the key-path overload returns for `[key]`
 */
auto dig(T)(in Node node, string key, T default_)
{
    string[] path = [key];
    return dig(node, path, default_);
}
17 
/**
 * Follows `keys` through nested mappings of `node` and returns the node found.
 *
 * When some key along the path is missing, the fallback is returned wrapped in
 * a `Node`:
 * - if `default_` is the untyped empty array literal `[]`, an empty sequence
 *   node is returned, because a `void[]` value cannot be stored in a `Node`;
 * - otherwise `Node(default_)` is returned.
 *
 * Params:
 *   node     = root node to search
 *   keys     = mapping keys to follow, outermost first
 *   default_ = fallback value used when the path does not exist
 * Returns: the node at the end of `keys`, or the fallback as a `Node`
 */
auto dig(T)(in Node node, string[] keys, T default_)
{
    Node ret = node;
    foreach (k; keys)
    {
        if (auto n = k in ret)
        {
            ret = *n;
        }
        else
        {
            // Only the bare `[]` literal (typed `void[]`) needs special casing.
            // `is(T : void[])` would also be true for `int[]`, `Node[]`, ...
            // (every mutable array converts to `void[]`) and would silently
            // discard a non-empty default.
            static if (is(T == void[]))
            {
                return Node((Node[]).init);
            }
            else
            {
                return Node(default_);
            }
        }
    }
    return ret;
}
42 
/**
 * Follows the compile-time key path `K` through a CWL object.
 *
 * How each key `k` is resolved depends on the current value:
 * - an object is searched for its field `k ~ "_"`;
 * - a sum type is dispatched on the value it currently holds;
 * - an array is searched for the element whose `idMap_.subject ~ "_"` field
 *   equals `k` (only allowed when the array came from a field annotated with
 *   an `@idMap` UDA).
 *
 * `default_` is returned whenever the path cannot be followed.
 *
 * Params:
 *   K        = a single key (`string`) or a key path (`string[]`)
 *   U        = type of the value to extract
 *   T        = type of `t`
 *   idMap_   = `@idMap` UDA of the field `t` was read from; used for array lookup
 *   t        = object to dig into
 *   default_ = value returned when the path cannot be followed
 * Returns: the value at the end of the path, or `default_`
 */
auto dig(alias K, U, T, idMap idMap_ = idMap.init)(T t, U default_ = U.init)
if (!is(T: Node))
{
    static assert(is(typeof(K) == string) || is(typeof(K) == string[]));
    static if (is(typeof(K) == string))
    {
        // Forward every template argument so that an explicit `idMap_` is kept.
        return dig!([K], U, T, idMap_)(t, default_);
    }
    else
    {
        import std.traits : getUDAs, hasMember, hasUDA, isArray, isSomeString;
        import salad.type : isSumType, match;

        static if (K.length == 0)
        {
            // End of the path: unwrap a sum type to `U` if possible.
            static if (isSumType!T)
            {
                return t.match!(
                    (U u) => u,
                    _ => default_,
                );
            }
            else
            {
                return t;
            }
        }
        else static if (hasMember!(T, K[0]~"_"))
        {
            auto field = mixin("t."~K[0]~"_");
            // Carry the field's `@idMap` UDA so that a following array lookup
            // knows which element field to match the next key against.
            static if (hasUDA!(mixin("t."~K[0]~"_"), idMap))
            {
                enum nextIDMap = getUDAs!(mixin("t."~K[0]~"_"), idMap)[0];
            }
            else
            {
                enum nextIDMap = idMap.init;
            }
            return dig!(K[1..$], U, typeof(field), nextIDMap)(field, default_);
        }
        else static if (isSumType!T)
        {
            // The sum type itself has no field `K[0]_`, so dispatch on the held
            // value with the same path, the same default and the same `idMap_`.
            // The recursive call then resolves fields (including their
            // `@idMap` UDAs) or performs the idMap array lookup. Held types
            // that cannot be dug into yield `default_`. A single generic
            // handler always matches, so no handler is ever unused.
            return t.match!(
                (v) {
                    alias V = typeof(v);
                    static if (hasMember!(V, K[0]~"_") ||
                               (isArray!V && !isSomeString!V && idMap_ != idMap.init))
                    {
                        return dig!(K, U, V, idMap_)(v, default_);
                    }
                    else
                    {
                        return default_;
                    }
                },
            );
        }
        else static if (isArray!T)
        {
            import std.algorithm : map;
            import std.array : assocArray;

            static assert(idMap_ != idMap.init, "dig does not support index access");
            // Index the elements by their idMap subject field.
            auto aa = t.map!((e) {
                import std.typecons : tuple;
                auto f = mixin("e."~idMap_.subject~"_");
                static if (isSumType!(typeof(f)))
                {
                    // `tryMatch` is not imported at module level, so import it
                    // only where it is needed.
                    // NOTE(review): assumes salad.type exposes `tryMatch`
                    // alongside `match` — confirm.
                    import salad.type : tryMatch;
                    auto k = f.tryMatch!((string s) => s);
                }
                else
                {
                    auto k = f;
                }
                return tuple(k, e);
            }).assocArray;

            if (auto v = K[0] in aa)
            {
                return dig!(K[1..$], U)(*v, default_);
            }
            else
            {
                return default_;
            }
        }
        else
        {
            return default_;
        }
    }
}
136 
/**
 * Enforcing lookup of a single mapping key in `node`.
 *
 * Shorthand for the key-path overload called with the one-element path `[key]`.
 *
 * Params:
 *   Ex   = exception type thrown when the key is missing
 *   node = node to search
 *   key  = mapping key to look up
 *   msg  = exception message; a default message is generated when empty
 * Returns: the node stored under `key`
 * Throws: `Ex` when `key` is absent
 */
auto edig(Ex = Exception)(in Node node, string key, string msg = "")
{
    string[] path = [key];
    return edig!Ex(node, path, msg);
}
142 
/**
 * Enforcing lookup of a key path through nested mappings of `node`.
 *
 * Params:
 *   Ex   = exception type thrown when a key is missing
 *   node = root node to search
 *   keys = mapping keys to follow, outermost first
 *   msg  = exception message; when empty, "No such field: <key>" is used
 *          with the first missing key
 * Returns: the node at the end of `keys`
 * Throws: `Ex` as soon as a key along the path is absent
 */
auto edig(Ex = Exception)(in Node node, string[] keys, string msg = "")
{
    import std.format : format;
    import std.range : empty;

    Node current = node;
    foreach (key; keys)
    {
        auto child = key in current;
        if (child is null)
        {
            if (msg.empty)
            {
                msg = format!"No such field: %s"(key);
            }
            throw new Ex(msg);
        }
        current = *child;
    }
    return current;
}
163 
/**
 * Structurally compares two YAML nodes and reports every difference found.
 *
 * If the node types differ, only that difference is reported. Otherwise the
 * tags are compared, and then, depending on the node type:
 * - mappings: key sets are compared; when they match, the values are compared
 *   recursively key by key;
 * - sequences: lengths are compared; when they match, the elements are
 *   compared recursively in order;
 * - booleans, integers and strings: values are compared for equality;
 * - decimals: values are compared with `isClose` (default tolerance); two NaNs
 *   are treated as equal.
 *
 * All other node types are compared by type and tag only.
 *
 * Params:
 *   lhs = left-hand node
 *   rhs = right-hand node
 * Returns: a `Diff` with one `Diff.Entry` per difference; it has no entries
 *          when no difference was detected
 */
auto diff(Node lhs, Node rhs)
{
    import dyaml : NodeType;

    import std.format : format;

    alias Entry = Diff.Entry;

    if (lhs.type != rhs.type)
    {
        return Diff([Entry(lhs, rhs, format!"Different node type: (%s, %s)"(lhs.type, rhs.type))]);
    }

    Entry[] result;
    if (lhs.tag != rhs.tag)
    {
        result ~= Entry(lhs, rhs, format!"Different node tag: (%s, %s)"(lhs.tag, rhs.tag));
    }

    // Each case body has its own scope, so later case labels never jump over
    // the local declarations of earlier cases.
    switch (lhs.type)
    {
    case NodeType.mapping:
    {
        import std.algorithm : joiner, map, schwartzSort, setDifference;
        import std.array : array;
        import std.range : zip;

        // Sort both mappings by key so they can be compared with set operations.
        auto lmap = lhs.mapping.array.schwartzSort!"a.key";
        auto rmap = rhs.mapping.array.schwartzSort!"a.key";
        if (lmap.length != rmap.length)
        {
            result ~= Entry(lhs, rhs, format!"Different #node mapping entries: (%s, %s)"(lmap.length, rmap.length));
        }
        auto l_r = setDifference!"a.key < b.key"(lmap, rmap);
        if (!l_r.empty)
        {
            result ~= Entry(lhs, rhs, format!"lhs has extra mapping entries: (%s)"(l_r.map!"a.key".array));
        }

        auto r_l = setDifference!"a.key < b.key"(rmap, lmap);
        if (!r_l.empty)
        {
            result ~= Entry(lhs, rhs, format!"rhs has extra mapping entries: (%s)"(r_l.map!"a.key".array));
        }

        // Values are compared pairwise only when both key sets are identical,
        // so the sorted entries line up one-to-one.
        if (l_r.empty && r_l.empty)
        {
            result ~= zip(lmap, rmap).map!(a => diff(a[0].value, a[1].value).entries).joiner.array;
        }
        break;
    }
    case NodeType.sequence:
    {
        import std.algorithm : joiner, map;
        import std.array : array;
        import std.range : zip;

        auto lhsArr = lhs.sequence.array;
        auto rhsArr = rhs.sequence.array;
        if (lhsArr.length != rhsArr.length)
        {
            result ~= Entry(lhs, rhs, format!"Different node length: (%s, %s)"(lhsArr.length, rhsArr.length));
        }
        else
        {
            result ~= zip(lhsArr, rhsArr).map!(a => diff(a[0], a[1]).entries).joiner.array;
        }
        break;
    }
    case NodeType.boolean:
    {
        auto lhsbool = lhs.as!bool;
        auto rhsbool = rhs.as!bool;
        if (lhsbool != rhsbool)
        {
            result ~= Entry(lhs, rhs, format!"Different boolean value: (%s, %s)"(lhsbool, rhsbool));
        }
        break;
    }
    case NodeType.integer:
    {
        // Read as long: YAML integers may exceed the int range.
        auto lhsint = lhs.as!long;
        auto rhsint = rhs.as!long;
        if (lhsint != rhsint)
        {
            result ~= Entry(lhs, rhs, format!"Different integer value: (%s, %s)"(lhsint, rhsint));
        }
        break;
    }
    case NodeType.decimal:
    {
        import std.math : isClose, isNaN;

        auto lhsreal = lhs.as!real;
        auto rhsreal = rhs.as!real;
        // NaN is never close to anything, itself included; treat NaN vs NaN as equal.
        immutable bothNaN = lhsreal.isNaN && rhsreal.isNaN;
        if (!bothNaN && !lhsreal.isClose(rhsreal))
        {
            result ~= Entry(lhs, rhs, format!"Different decimal value: (%s, %s)"(lhsreal, rhsreal));
        }
        break;
    }
    case NodeType.string:
    {
        auto lhsstr = lhs.as!string;
        auto rhsstr = rhs.as!string;
        if (lhsstr != rhsstr)
        {
            result ~= Entry(lhs, rhs, format!"Different string value: (\"%s\", \"%s\")"(lhsstr, rhsstr));
        }
        break;
    }
    default:
        // Other node types: only type and tag are compared (done above).
        break;
    }
    return Diff(result);
}
269 
/// Result of `diff`: the differences found between two nodes.
struct Diff
{
    /// A single difference between two nodes.
    struct Entry
    {
        /// The nodes that differ.
        Node lhs, rhs;
        /// Human-readable description of the difference.
        string message;

        /**
         * Formats the entry as `name:line:column: message`.
         *
         * Uses the start mark of `lhs`, or of `rhs` when the mark of `lhs` has
         * an empty name. Both line and column are printed one-based (the marks
         * store them zero-based), consistent with D-YAML's own `Mark` output.
         */
        string toString() const pure @safe
        {
            import std.format : format;
            import std.range : empty;
            auto mark = lhs.startMark.name.empty ? rhs.startMark : lhs.startMark;
            return format!"%s:%s:%s: %s"(mark.name, mark.line+1, mark.column+1,
                                         message);
        }
    }

    /// All differences, in the order they were found.
    Entry[] entries;

    /// Formats every entry on its own line, separated by "\n".
    string toString() const pure @safe
    {
        import std.algorithm : map;
        import std.array : join;
        import std.conv : to;

        return entries.map!(to!string).join("\n");
    }
}