mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 09:52:14 +08:00 
			
		
		
		
	[v1] support switch v1 backend (#9226)
This commit is contained in:
		
							parent
							
								
									1d96c62df2
								
							
						
					
					
						commit
						7d60b840ef
					
				@ -36,52 +36,16 @@ USAGE = (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_api():
 | 
			
		||||
    from .api.app import run_api
 | 
			
		||||
 | 
			
		||||
    return run_api()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_chat():
 | 
			
		||||
    from .chat.chat_model import run_chat
 | 
			
		||||
 | 
			
		||||
    return run_chat()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_eval():
 | 
			
		||||
    raise NotImplementedError("Evaluation will be deprecated in the future.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _export_model():
 | 
			
		||||
    from .train.tuner import export_model
 | 
			
		||||
 | 
			
		||||
    return export_model()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_exp():
 | 
			
		||||
    from .train.tuner import run_exp
 | 
			
		||||
 | 
			
		||||
    return run_exp()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_web_demo():
 | 
			
		||||
    from .webui.interface import run_web_demo
 | 
			
		||||
 | 
			
		||||
    return run_web_demo()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_web_ui():
 | 
			
		||||
    from .webui.interface import run_web_ui
 | 
			
		||||
 | 
			
		||||
    return run_web_ui()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    from . import launcher
 | 
			
		||||
    from .extras import logging
 | 
			
		||||
    from .extras.env import VERSION, print_env
 | 
			
		||||
    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
 | 
			
		||||
 | 
			
		||||
    if is_env_enabled("USE_V1"):
 | 
			
		||||
        from .v1 import launcher
 | 
			
		||||
    else:
 | 
			
		||||
        from . import launcher
 | 
			
		||||
 | 
			
		||||
    logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
    WELCOME = (
 | 
			
		||||
@ -97,14 +61,14 @@ def main():
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    COMMAND_MAP = {
 | 
			
		||||
        "api": _run_api,
 | 
			
		||||
        "chat": _run_chat,
 | 
			
		||||
        "api": launcher.run_api,
 | 
			
		||||
        "chat": launcher.run_chat,
 | 
			
		||||
        "env": print_env,
 | 
			
		||||
        "eval": _run_eval,
 | 
			
		||||
        "export": _export_model,
 | 
			
		||||
        "train": _run_exp,
 | 
			
		||||
        "webchat": _run_web_demo,
 | 
			
		||||
        "webui": _run_web_ui,
 | 
			
		||||
        "eval": launcher.run_eval,
 | 
			
		||||
        "export": launcher.export_model,
 | 
			
		||||
        "train": launcher.run_exp,
 | 
			
		||||
        "webchat": launcher.run_web_demo,
 | 
			
		||||
        "webui": launcher.run_web_ui,
 | 
			
		||||
        "version": partial(print, WELCOME),
 | 
			
		||||
        "help": partial(print, USAGE),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -13,11 +13,45 @@
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def launch():
 | 
			
		||||
    from llamafactory.train.tuner import run_exp  # use absolute import
 | 
			
		||||
def run_api():
 | 
			
		||||
    from llamafactory.api.app import run_api as _run_api
 | 
			
		||||
 | 
			
		||||
    run_exp()
 | 
			
		||||
    _run_api()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_chat():
 | 
			
		||||
    from llamafactory.chat.chat_model import run_chat as _run_chat
 | 
			
		||||
 | 
			
		||||
    return _run_chat()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_eval():
 | 
			
		||||
    raise NotImplementedError("Evaluation will be deprecated in the future.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def export_model():
 | 
			
		||||
    from llamafactory.train.tuner import export_model as _export_model
 | 
			
		||||
 | 
			
		||||
    return _export_model()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_exp():
 | 
			
		||||
    from llamafactory.train.tuner import run_exp as _run_exp
 | 
			
		||||
 | 
			
		||||
    return _run_exp()  # use absolute import
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_web_demo():
 | 
			
		||||
    from llamafactory.webui.interface import run_web_demo as _run_web_demo
 | 
			
		||||
 | 
			
		||||
    return _run_web_demo()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_web_ui():
 | 
			
		||||
    from llamafactory.webui.interface import run_web_ui as _run_web_ui
 | 
			
		||||
 | 
			
		||||
    return _run_web_ui()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    launch()
 | 
			
		||||
    run_exp()
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,33 @@
 | 
			
		||||
# Copyright 2025 the LlamaFactory team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_train():
 | 
			
		||||
    raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_chat():
 | 
			
		||||
    from llamafactory.v1.core.chat_sampler import Sampler
 | 
			
		||||
 | 
			
		||||
    Sampler().cli()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_sft():
 | 
			
		||||
    from llamafactory.v1.train.sft import SFTTrainer
 | 
			
		||||
 | 
			
		||||
    SFTTrainer().run()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    run_train()
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user