1 | // See LICENSE file for copyright and license details |
2 | |
// Package env implements the "shrt env" command.
4 | package env |
5 | |
import (
	"bufio"
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"djmo.ch/go-shrt"
	"djmo.ch/go-shrt/cmd/shrt/internal/base"
)
22 | |
// Built-in fallback values for the SHRT_* settings. They apply when a
// variable is set neither in the OS environment nor in the SHRTENV file.
const (
	// Defaults for SHRT_SRVNAME, SHRT_SCMTYPE and SHRT_SUFFIX.
	srvNameDefault = "example.com"
	scmTypeDefault = "git"
	suffixDefault  = ".git"

	// Defaults for SHRT_RDRNAME and SHRT_BARERDR.
	rdrNameDefault = "github.com/user"
	bareRdrDefault = "example.org"

	// SHRT_GOSOURCEDIR and SHRT_GOSOURCEFILE are empty unless configured.
	goSourceDirDefault  = ""
	goSourceFileDefault = ""
)
32 | |
33 | var Cmd = &base.Command{ |
34 | Name: "env", |
35 | Usage: "shrt env [-u] [-w] [var ...]", |
36 | ShortHelp: "print Shrt environment information", |
37 | LongHelp: `Env prints Shrt environment information. |
38 | |
39 | By default env prints information as a shell script. If one or more |
40 | variable names is given as arguments, env prints the value of each |
41 | named variable on its own line. |
42 | |
43 | The -u flag requires one or more arguments and unsets |
44 | the default setting for the named environment variables, |
45 | if one has been set with 'shrt env -w'. |
46 | |
47 | The -w flag requires one or more arguments of the |
48 | form NAME=VALUE and changes the default settings |
49 | of the named environment variables to the given values. If the same |
50 | NAME is provided multiple times, the last one takes effect. |
51 | |
52 | For more about environment variables, see 'shrt help environment'. |
53 | `, |
54 | } |
55 | |
56 | var ( |
57 | envU = Cmd.Flags.Bool("u", false, "") |
58 | envW = Cmd.Flags.Bool("w", false, "") |
59 | ) |
60 | |
// init attaches runEnv to Cmd. This cannot be done in Cmd's composite
// literal: runEnv reads envU/envW, whose initializers reference Cmd.Flags,
// so naming runEnv in Cmd's initializer would form an initialization cycle
// (Cmd -> runEnv -> envU -> Cmd) that the compiler rejects.
func init() {
	// break init cycle
	Cmd.Run = runEnv
}
65 | |
66 | func runEnv(ctx context.Context) { |
67 | var ( |
68 | w = ctx.Value("w").(io.Writer) |
69 | args = ctx.Value("args").([]string) |
70 | ) |
71 | if *envU && *envW { |
72 | log.Fatal("cannot use -w with -u") |
73 | } |
74 | |
75 | if *envU { |
76 | runEnvU(args) |
77 | return |
78 | } |
79 | |
80 | if *envW { |
81 | runEnvW(args) |
82 | return |
83 | } |
84 | |
85 | // Environment is already merged |
86 | if len(args) > 0 { |
87 | for _, arg := range args { |
88 | fmt.Fprintln(w, os.Getenv(arg)) |
89 | } |
90 | return |
91 | } |
92 | for _, key := range strings.Fields(base.KnownEnv) { |
93 | value := os.Getenv(key) |
94 | fmt.Fprintf(w, "%s=\"%s\"\n", key, value) |
95 | } |
96 | } |
97 | |
98 | func runEnvU(args []string) { |
99 | envPath := envOrDefault(base.SHRTENV, envDefault) |
100 | curEnv := readEnvFile(envPath) |
101 | |
102 | for _, arg := range args { |
103 | delete(curEnv, arg) |
104 | } |
105 | |
106 | writeEnvFile(envPath, curEnv) |
107 | } |
108 | |
109 | func runEnvW(args []string) { |
110 | envToWrite := make(map[string]string) |
111 | for _, arg := range args { |
112 | kv := strings.SplitN(arg, "=", 2) |
113 | if len(kv) == 1 { |
114 | log.Fatal("malformed argument: ", arg) |
115 | } |
116 | if !strings.Contains(base.KnownEnv, kv[0]) { |
117 | log.Fatal("unknown env variable: ", kv[0]) |
118 | } |
119 | envToWrite[kv[0]] = kv[1] |
120 | } |
121 | |
122 | envPath := envOrDefault(base.SHRTENV, envDefault) |
123 | curEnv := readEnvFile(envPath) |
124 | |
125 | for k, v := range envToWrite { |
126 | if k == base.SHRTENV { |
127 | log.Println(base.SHRTENV, "can only be set using the OS environment") |
128 | continue |
129 | } |
130 | if k == base.SHRT_DBPATH && !filepath.IsAbs(v) { |
131 | log.Println(base.SHRT_DBPATH, "must be an absolute path ... ignoring") |
132 | continue |
133 | } |
134 | curEnv[k] = v |
135 | } |
136 | |
137 | writeEnvFile(envPath, curEnv) |
138 | } |
139 | |
140 | func readEnvFile(path string) map[string]string { |
141 | envMap := make(map[string]string) |
142 | envFile, err := os.ReadFile(path) |
143 | if err != nil { |
144 | if !errors.Is(err, fs.ErrNotExist) { |
145 | log.Fatalf("error reading %s: %s", path, err) |
146 | } |
147 | return envMap |
148 | } |
149 | |
150 | s := bufio.NewScanner(bytes.NewReader(envFile)) |
151 | for s.Scan() { |
152 | kv := strings.SplitN(s.Text(), "=", 2) |
153 | if len(kv) == 1 { |
154 | log.Fatalf("malformed line in %s: %s", path, s.Text()) |
155 | } |
156 | |
157 | if !strings.Contains(base.KnownEnv, kv[0]) { |
158 | continue |
159 | } |
160 | envMap[kv[0]] = kv[1] |
161 | } |
162 | |
163 | return envMap |
164 | } |
165 | |
// writeEnvFile replaces the file at path with the contents of envMap. It
// writes one KEY=VALUE line per entry, sorted by key so that the output is
// deterministic, and creates missing parent directories first.
//
// Any failure, including a failed write or close, is fatal. A lost write
// would otherwise go unnoticed.
func writeEnvFile(path string, envMap map[string]string) {
	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0777); err != nil {
		log.Fatalf("failed to create directory %s: %s", dir, err)
	}

	keys := make([]string, 0, len(envMap))
	for k := range envMap {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var buf bytes.Buffer
	for _, k := range keys {
		fmt.Fprintf(&buf, "%s=%s\n", k, envMap[k])
	}

	// os.WriteFile reports open, write and close errors alike.
	if err := os.WriteFile(path, buf.Bytes(), 0666); err != nil {
		log.Fatalf("failed to write %s: %s", path, err)
	}
}
181 | |
182 | // ConfigFromEnv returns a Config object matching the current |
183 | // environment. |
184 | func ConfigFromEnv() shrt.Config { |
185 | return shrt.Config{ |
186 | SrvName: envOrDefault(base.SHRT_SRVNAME, srvNameDefault), |
187 | ScmType: envOrDefault(base.SHRT_SCMTYPE, scmTypeDefault), |
188 | Suffix: envOrDefault(base.SHRT_SUFFIX, suffixDefault), |
189 | RdrName: envOrDefault(base.SHRT_RDRNAME, rdrNameDefault), |
190 | BareRdr: envOrDefault(base.SHRT_BARERDR, bareRdrDefault), |
191 | // Trim the leading / to satisfy fs.FS |
192 | DbPath: strings.TrimPrefix(envOrDefault(base.SHRT_DBPATH, dbPathDefault), "/"), |
193 | GoSourceDir: envOrDefault(base.SHRT_GOSOURCEDIR, goSourceDirDefault), |
194 | GoSourceFile: envOrDefault(base.SHRT_GOSOURCEFILE, goSourceFileDefault), |
195 | } |
196 | } |
197 | |
198 | // MergeEnv merges the program's environment with that specified in |
199 | // SHRTENV. Values already specified in the environment take |
200 | // precedence. |
201 | func MergeEnv() { |
202 | envPath := envOrDefault(base.SHRTENV, envDefault) |
203 | envFile, err := os.ReadFile(envPath) |
204 | if err != nil { |
205 | if !errors.Is(err, fs.ErrNotExist) { |
206 | log.Fatalf("error reading %s: %s", envPath, err) |
207 | } |
208 | envFile = []byte{} |
209 | } |
210 | |
211 | // Read envfile into environment |
212 | s := bufio.NewScanner(bytes.NewReader(envFile)) |
213 | for s.Scan() { |
214 | kv := strings.SplitN(s.Text(), "=", 2) |
215 | if len(kv) == 1 { |
216 | log.Fatal("malformed line in SHRTENV: ", s.Text()) |
217 | } |
218 | |
219 | key := kv[0] |
220 | if !strings.Contains(base.KnownEnv, key) { |
221 | log.Fatal("unknown env var: ", key) |
222 | } |
223 | value := kv[1] |
224 | |
225 | if _, ok := os.LookupEnv(key); !ok { |
226 | os.Setenv(key, value) |
227 | } |
228 | } |
229 | |
230 | defaults := map[string]string{ |
231 | base.SHRTENV: envDefault, |
232 | base.SHRT_SRVNAME: srvNameDefault, |
233 | base.SHRT_SCMTYPE: scmTypeDefault, |
234 | base.SHRT_SUFFIX: suffixDefault, |
235 | base.SHRT_RDRNAME: rdrNameDefault, |
236 | base.SHRT_BARERDR: bareRdrDefault, |
237 | base.SHRT_DBPATH: dbPathDefault, |
238 | base.SHRT_GOSOURCEDIR: goSourceDirDefault, |
239 | base.SHRT_GOSOURCEFILE: goSourceFileDefault, |
240 | } |
241 | |
242 | // Populate missing environment variables with defaults |
243 | for _, key := range strings.Fields(base.KnownEnv) { |
244 | if _, ok := os.LookupEnv(key); !ok { |
245 | os.Setenv(key, envOrDefault(key, defaults[key])) |
246 | } |
247 | } |
248 | } |
249 | |
// envOrDefault returns the value of the environment variable key when it is
// present, even if it is set to the empty string, and d when it is absent.
func envOrDefault(key, d string) string {
	if v, present := os.LookupEnv(key); present {
		return v
	}
	return d
}