mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 18:20:35 +08:00
[v1] add v1 launcher (#9236)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -12,22 +12,55 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
def run_train():
|
||||
raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.")
|
||||
from ..extras.env import VERSION, print_env
|
||||
|
||||
|
||||
def run_chat():
|
||||
from llamafactory.v1.core.chat_sampler import Sampler
|
||||
|
||||
Sampler().cli()
|
||||
USAGE = (
|
||||
"-" * 70
|
||||
+ "\n"
|
||||
+ "| Usage: |\n"
|
||||
+ "| llamafactory-cli sft -h: train models |\n"
|
||||
+ "| llamafactory-cli version: show version info |\n"
|
||||
+ "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
|
||||
+ "-" * 70
|
||||
)
|
||||
|
||||
|
||||
def run_sft():
|
||||
from llamafactory.v1.train.sft import SFTTrainer
|
||||
WELCOME = (
|
||||
"-" * 58
|
||||
+ "\n"
|
||||
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||
+ " " * (21 - len(VERSION))
|
||||
+ "|\n|"
|
||||
+ " " * 56
|
||||
+ "|\n"
|
||||
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
|
||||
+ "-" * 58
|
||||
)
|
||||
|
||||
SFTTrainer().run()
|
||||
|
||||
def launch():
|
||||
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
||||
|
||||
if command == "sft":
|
||||
from .trainers.sft_trainer import run_sft
|
||||
|
||||
run_sft()
|
||||
|
||||
elif command == "env":
|
||||
print_env()
|
||||
|
||||
elif command == "version":
|
||||
print(WELCOME)
|
||||
|
||||
elif command == "help":
|
||||
print(USAGE)
|
||||
|
||||
else:
|
||||
print(f"Unknown command: {command}.\n{USAGE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_train()
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user